Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sdk/batch/speechmatics/batch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,11 @@
from ._models import TranscriptFilteringConfig
from ._models import TranscriptionConfig
from ._models import TranslationConfig
from ._transport import PROCESSING_DATA_HEADER

__all__ = [
"AsyncClient",
"PROCESSING_DATA_HEADER",
"AuthBase",
"AuthenticationError",
"AutoChaptersConfig",
Expand Down
14 changes: 11 additions & 3 deletions sdk/batch/speechmatics/batch/_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from ._models import JobType
from ._models import Transcript
from ._models import TranscriptionConfig
from ._transport import PROCESSING_DATA_HEADER
from ._transport import Transport


Expand Down Expand Up @@ -139,6 +140,7 @@ async def submit_job(
*,
config: Optional[JobConfig] = None,
transcription_config: Optional[TranscriptionConfig] = None,
requested_parallel: Optional[int] = None,
) -> JobDetails:
"""
Submit a new transcription job.
Expand All @@ -154,6 +156,9 @@ async def submit_job(
to build a basic job configuration.
transcription_config: Transcription-specific configuration. Used if config
is not provided.
requested_parallel: Optional number of parallel engines to request for this job.
Sent as ``{"requested_parallel": N}`` in the ``X-SM-Processing-Data`` header.
This only applies when running the on-prem container in HTTP batch mode.

Returns:
JobDetails object containing the job ID and initial status.
Expand Down Expand Up @@ -200,7 +205,7 @@ async def submit_job(
assert audio_file is not None # for type checker; validated above
multipart_data, filename = await self._prepare_file_submission(audio_file, config_dict)

return await self._submit_and_create_job_details(multipart_data, filename, config)
return await self._submit_and_create_job_details(multipart_data, filename, config, requested_parallel)
except Exception as e:
if isinstance(e, (AuthenticationError, BatchError)):
raise
Expand Down Expand Up @@ -528,10 +533,13 @@ async def _prepare_file_submission(self, audio_file: Union[str, BinaryIO], confi
return multipart_data, filename

async def _submit_and_create_job_details(
self, multipart_data: dict, filename: str, config: JobConfig
self, multipart_data: dict, filename: str, config: JobConfig, requested_parallel: Optional[int] = None
) -> JobDetails:
"""Submit job and create JobDetails response."""
response = await self._transport.post("/jobs", multipart_data=multipart_data)
extra_headers: Optional[dict[str, dict[str, int]]] = None
if requested_parallel is not None:
extra_headers = {PROCESSING_DATA_HEADER: {"requested_parallel": requested_parallel}}
response = await self._transport.post("/jobs", multipart_data=multipart_data, extra_headers=extra_headers)
job_id = response.get("id")
if not job_id:
raise BatchError("No job ID returned from server")
Expand Down
16 changes: 15 additions & 1 deletion sdk/batch/speechmatics/batch/_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from ._logging import get_logger
from ._models import ConnectionConfig

PROCESSING_DATA_HEADER = "X-SM-Processing-Data"


class Transport:
"""
Expand Down Expand Up @@ -116,6 +118,7 @@ async def post(
json_data: Optional[dict[str, Any]] = None,
multipart_data: Optional[dict[str, Any]] = None,
timeout: Optional[float] = None,
extra_headers: Optional[dict[str, dict[str, int]]] = None,
) -> dict[str, Any]:
"""
Send POST request to the API.
Expand All @@ -125,6 +128,7 @@ async def post(
json_data: Optional JSON data for request body
multipart_data: Optional multipart form data
timeout: Optional request timeout
extra_headers: Optional additional headers to include in the request

Returns:
JSON response as dictionary
Expand All @@ -133,7 +137,14 @@ async def post(
AuthenticationError: If authentication fails
TransportError: If request fails
"""
return await self._request("POST", path, json_data=json_data, multipart_data=multipart_data, timeout=timeout)
return await self._request(
"POST",
path,
json_data=json_data,
multipart_data=multipart_data,
timeout=timeout,
extra_headers=extra_headers,
)

async def delete(self, path: str, timeout: Optional[float] = None) -> dict[str, Any]:
"""
Expand Down Expand Up @@ -200,6 +211,7 @@ async def _request(
json_data: Optional[dict[str, Any]] = None,
multipart_data: Optional[dict[str, Any]] = None,
timeout: Optional[float] = None,
extra_headers: Optional[dict[str, dict[str, int]]] = None,
) -> dict[str, Any]:
"""
Send HTTP request to the API.
Expand Down Expand Up @@ -227,6 +239,8 @@ async def _request(

url = f"{self._url.rstrip('/')}{path}"
headers = await self._prepare_headers()
if extra_headers:
headers.update(extra_headers)

self._logger.debug(
"Sending HTTP request %s %s (json=%s, multipart=%s)",
Expand Down
155 changes: 155 additions & 0 deletions tests/batch/test_submit_job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
"""Unit tests for AsyncClient.submit_job, focusing on the requested_parallel feature."""

import json
from io import BytesIO
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
from unittest.mock import patch

from typing import Optional

import pytest

from speechmatics.batch import AsyncClient
from speechmatics.batch import JobConfig
from speechmatics.batch import JobStatus
from speechmatics.batch import JobType
from speechmatics.batch import TranscriptionConfig
from speechmatics.batch import PROCESSING_DATA_HEADER


def _make_client(api_key: str = "test-key") -> AsyncClient:
    """Create an ``AsyncClient`` constructed with ``api_key`` for use in the tests."""
    client = AsyncClient(api_key=api_key)
    return client


def _job_response(job_id: str = "job-123") -> dict:
    """Build the minimal response body that the mocked ``transport.post`` returns.

    Args:
        job_id: Value to place under the ``"id"`` key.

    Returns:
        A dict with the job ``id`` and a fixed ``created_at`` timestamp.
    """
    created_at = "2024-01-01T00:00:00Z"
    return dict(id=job_id, created_at=created_at)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _captured_extra_headers(mock_post: AsyncMock) -> Optional[dict]:
    """Return the ``extra_headers`` keyword from the most recent ``transport.post`` call.

    ``Mock.call_args`` records the latest call rather than the first one; the tests
    in this module only trigger a single call, so the two coincide.

    Args:
        mock_post: The mock that replaced ``transport.post``.

    Returns:
        The value passed as ``extra_headers``, or ``None`` when the keyword was
        omitted (or explicitly passed as ``None``).

    Raises:
        AssertionError: If ``mock_post`` was never called.
    """
    recorded_call = mock_post.call_args
    if recorded_call is None:
        # call_args stays None until the mock is invoked; fail with a readable
        # message instead of an opaque tuple-unpacking TypeError.
        raise AssertionError("transport.post was never called")
    _, kwargs = recorded_call
    return kwargs.get("extra_headers")


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


class TestRequestedParallelHeader:
    """Check the ``extra_headers`` that ``submit_job`` passes to ``transport.post``.

    Every test replaces the client's ``transport.post`` with an ``AsyncMock`` and
    inspects the keyword arguments of the mocked call.
    """

    @pytest.mark.asyncio
    async def test_header_sent_when_requested_parallel_provided(self):
        """A given ``requested_parallel`` appears under the processing-data header key."""
        client = _make_client()
        audio = BytesIO(b"fake-audio")

        with patch.object(client._transport, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = _job_response()
            await client.submit_job(audio, requested_parallel=4)

        extra_headers = _captured_extra_headers(mock_post)
        assert extra_headers is not None
        assert PROCESSING_DATA_HEADER in extra_headers
        payload = extra_headers[PROCESSING_DATA_HEADER]
        assert payload == {"requested_parallel": 4}

    @pytest.mark.asyncio
    async def test_header_not_sent_when_requested_parallel_is_none(self):
        """Omitting ``requested_parallel`` leaves ``extra_headers`` as ``None``."""
        client = _make_client()
        audio = BytesIO(b"fake-audio")

        with patch.object(client._transport, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = _job_response()
            await client.submit_job(audio)

        extra_headers = _captured_extra_headers(mock_post)
        assert extra_headers is None

    @pytest.mark.asyncio
    async def test_header_value_is_valid_json(self):
        """The processing-data payload survives a JSON round trip unchanged."""
        client = _make_client()
        audio = BytesIO(b"fake-audio")

        with patch.object(client._transport, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = _job_response()
            await client.submit_job(audio, requested_parallel=8)

        extra_headers = _captured_extra_headers(mock_post)
        assert extra_headers is not None
        payload = extra_headers[PROCESSING_DATA_HEADER]
        # Actually exercise the JSON encoder/decoder: json.dumps raises TypeError
        # for a payload that cannot be serialised, and the decoded result must
        # equal the expected document.
        parsed = json.loads(json.dumps(payload))
        assert parsed == {"requested_parallel": 8}

    @pytest.mark.asyncio
    async def test_requested_parallel_one(self):
        """``requested_parallel=1`` is still forwarded in the payload."""
        client = _make_client()
        audio = BytesIO(b"fake-audio")

        with patch.object(client._transport, "post", new_callable=AsyncMock) as mock_post:
            mock_post.return_value = _job_response()
            await client.submit_job(audio, requested_parallel=1)

        extra_headers = _captured_extra_headers(mock_post)
        assert extra_headers is not None
        payload = extra_headers[PROCESSING_DATA_HEADER]
        assert payload["requested_parallel"] == 1

    @pytest.mark.asyncio
    async def test_header_sent_with_fetch_data_config(self):
        """requested_parallel works with fetch_data submissions too."""
        client = _make_client()
        config = JobConfig(
            type=JobType.TRANSCRIPTION,
            fetch_data=MagicMock(url="https://example.com/audio.wav"),
            transcription_config=TranscriptionConfig(language="en"),
        )
        # Patch to_dict so fetch_data key is present
        config_dict = {
            "type": "transcription",
            "fetch_data": {"url": "https://example.com/audio.wav"},
            "transcription_config": {"language": "en"},
        }
        with patch.object(config, "to_dict", return_value=config_dict):
            with patch.object(client._transport, "post", new_callable=AsyncMock) as mock_post:
                mock_post.return_value = _job_response()
                await client.submit_job(None, config=config, requested_parallel=2)

        extra_headers = _captured_extra_headers(mock_post)
        assert extra_headers is not None
        payload = extra_headers[PROCESSING_DATA_HEADER]
        assert payload == {"requested_parallel": 2}


class TestSubmitJobReturnValue:
    """Verify what ``submit_job`` returns and which path it posts to when ``requested_parallel`` is set."""

    @pytest.mark.asyncio
    async def test_returns_job_details_with_correct_id(self):
        """The returned job carries the id from the response and a RUNNING status."""
        sdk_client = _make_client()
        fake_audio = BytesIO(b"fake-audio")
        canned_response = _job_response("abc-456")

        with patch.object(
            sdk_client._transport,
            "post",
            new_callable=AsyncMock,
            return_value=canned_response,
        ):
            details = await sdk_client.submit_job(fake_audio, requested_parallel=3)

        assert details.id == "abc-456"
        assert details.status == JobStatus.RUNNING

    @pytest.mark.asyncio
    async def test_post_called_with_jobs_path(self):
        """The first positional argument given to ``transport.post`` is ``"/jobs"``."""
        sdk_client = _make_client()
        fake_audio = BytesIO(b"fake-audio")

        with patch.object(
            sdk_client._transport,
            "post",
            new_callable=AsyncMock,
            return_value=_job_response(),
        ) as post_mock:
            await sdk_client.submit_job(fake_audio, requested_parallel=2)

        posted_path = post_mock.call_args.args[0]
        assert posted_path == "/jobs"
Loading