diff --git a/sdk/batch/speechmatics/batch/__init__.py b/sdk/batch/speechmatics/batch/__init__.py index cf57d4f..1e8e516 100644 --- a/sdk/batch/speechmatics/batch/__init__.py +++ b/sdk/batch/speechmatics/batch/__init__.py @@ -32,9 +32,11 @@ from ._models import TranscriptFilteringConfig from ._models import TranscriptionConfig from ._models import TranslationConfig +from ._transport import PROCESSING_DATA_HEADER __all__ = [ "AsyncClient", + "PROCESSING_DATA_HEADER", "AuthBase", "AuthenticationError", "AutoChaptersConfig", diff --git a/sdk/batch/speechmatics/batch/_async_client.py b/sdk/batch/speechmatics/batch/_async_client.py index e286353..b8fbe64 100644 --- a/sdk/batch/speechmatics/batch/_async_client.py +++ b/sdk/batch/speechmatics/batch/_async_client.py @@ -31,6 +31,7 @@ from ._models import JobType from ._models import Transcript from ._models import TranscriptionConfig +from ._transport import PROCESSING_DATA_HEADER from ._transport import Transport @@ -139,6 +140,7 @@ async def submit_job( *, config: Optional[JobConfig] = None, transcription_config: Optional[TranscriptionConfig] = None, + requested_parallel: Optional[int] = None, ) -> JobDetails: """ Submit a new transcription job. @@ -154,6 +156,9 @@ async def submit_job( to build a basic job configuration. transcription_config: Transcription-specific configuration. Used if config is not provided. + requested_parallel: Optional number of parallel engines to request for this job. + Sent as ``{"requested_parallel": N}`` in the ``X-SM-Processing-Data`` header. + This only applies when using the on-prem container in HTTP batch mode. Returns: JobDetails object containing the job ID and initial status. 
@@ -200,7 +205,7 @@ async def submit_job( assert audio_file is not None # for type checker; validated above multipart_data, filename = await self._prepare_file_submission(audio_file, config_dict) - return await self._submit_and_create_job_details(multipart_data, filename, config) + return await self._submit_and_create_job_details(multipart_data, filename, config, requested_parallel) except Exception as e: if isinstance(e, (AuthenticationError, BatchError)): raise @@ -528,10 +533,13 @@ async def _prepare_file_submission(self, audio_file: Union[str, BinaryIO], confi return multipart_data, filename async def _submit_and_create_job_details( - self, multipart_data: dict, filename: str, config: JobConfig + self, multipart_data: dict, filename: str, config: JobConfig, requested_parallel: Optional[int] = None ) -> JobDetails: """Submit job and create JobDetails response.""" - response = await self._transport.post("/jobs", multipart_data=multipart_data) + extra_headers: Optional[dict[str, dict[str, int]]] = None + if requested_parallel is not None: + extra_headers = {PROCESSING_DATA_HEADER: {"requested_parallel": requested_parallel}} + response = await self._transport.post("/jobs", multipart_data=multipart_data, extra_headers=extra_headers) job_id = response.get("id") if not job_id: raise BatchError("No job ID returned from server") diff --git a/sdk/batch/speechmatics/batch/_transport.py b/sdk/batch/speechmatics/batch/_transport.py index bab8275..149ae29 100644 --- a/sdk/batch/speechmatics/batch/_transport.py +++ b/sdk/batch/speechmatics/batch/_transport.py @@ -25,6 +25,8 @@ from ._logging import get_logger from ._models import ConnectionConfig +PROCESSING_DATA_HEADER = "X-SM-Processing-Data" + class Transport: """ @@ -116,6 +118,7 @@ async def post( json_data: Optional[dict[str, Any]] = None, multipart_data: Optional[dict[str, Any]] = None, timeout: Optional[float] = None, + extra_headers: Optional[dict[str, dict[str, int]]] = None, ) -> dict[str, Any]: """ Send POST 
request to the API. @@ -125,6 +128,7 @@ async def post( json_data: Optional JSON data for request body multipart_data: Optional multipart form data timeout: Optional request timeout + extra_headers: Optional additional headers to include in the request Returns: JSON response as dictionary @@ -133,7 +137,14 @@ async def post( AuthenticationError: If authentication fails TransportError: If request fails """ - return await self._request("POST", path, json_data=json_data, multipart_data=multipart_data, timeout=timeout) + return await self._request( + "POST", + path, + json_data=json_data, + multipart_data=multipart_data, + timeout=timeout, + extra_headers=extra_headers, + ) async def delete(self, path: str, timeout: Optional[float] = None) -> dict[str, Any]: """ @@ -200,6 +211,7 @@ async def _request( json_data: Optional[dict[str, Any]] = None, multipart_data: Optional[dict[str, Any]] = None, timeout: Optional[float] = None, + extra_headers: Optional[dict[str, dict[str, int]]] = None, ) -> dict[str, Any]: """ Send HTTP request to the API. 
@@ -227,6 +239,8 @@ async def _request( url = f"{self._url.rstrip('/')}{path}" headers = await self._prepare_headers() + if extra_headers: + headers.update(extra_headers)  # NOTE(review): extra_headers values are dicts per the annotation — HTTP header values must be str; serialize with json.dumps before sending (confirm against the HTTP client in use) self._logger.debug( "Sending HTTP request %s %s (json=%s, multipart=%s)", diff --git a/tests/batch/test_submit_job.py b/tests/batch/test_submit_job.py new file mode 100644 index 0000000..47f835c --- /dev/null +++ b/tests/batch/test_submit_job.py @@ -0,0 +1,155 @@ +"""Unit tests for AsyncClient.submit_job, focusing on the requested_parallel feature.""" + +import json +from io import BytesIO +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch + +from typing import Optional + +import pytest + +from speechmatics.batch import AsyncClient +from speechmatics.batch import JobConfig +from speechmatics.batch import JobStatus +from speechmatics.batch import JobType +from speechmatics.batch import TranscriptionConfig +from speechmatics.batch import PROCESSING_DATA_HEADER + + +def _make_client(api_key: str = "test-key") -> AsyncClient: + return AsyncClient(api_key=api_key) + + +def _job_response(job_id: str = "job-123") -> dict: + return {"id": job_id, "created_at": "2024-01-01T00:00:00Z"} + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _captured_extra_headers(mock_post: AsyncMock) -> Optional[dict]: + """Return the extra_headers kwarg from the first call to transport.post.""" + _, kwargs = mock_post.call_args + return kwargs.get("extra_headers") + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestRequestedParallelHeader: + """X-SM-Processing-Data header is set correctly based on requested_parallel.""" + + @pytest.mark.asyncio + async def 
test_header_sent_when_requested_parallel_provided(self): + client = _make_client() + audio = BytesIO(b"fake-audio") + + with patch.object(client._transport, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = _job_response() + await client.submit_job(audio, requested_parallel=4) + + extra_headers = _captured_extra_headers(mock_post) + assert extra_headers is not None + assert PROCESSING_DATA_HEADER in extra_headers + payload = extra_headers[PROCESSING_DATA_HEADER] + assert payload == {"requested_parallel": 4} + + @pytest.mark.asyncio + async def test_header_not_sent_when_requested_parallel_is_none(self): + client = _make_client() + audio = BytesIO(b"fake-audio") + + with patch.object(client._transport, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = _job_response() + await client.submit_job(audio) + + extra_headers = _captured_extra_headers(mock_post) + assert extra_headers is None + + @pytest.mark.asyncio + async def test_header_value_is_valid_json(self): + client = _make_client() + audio = BytesIO(b"fake-audio") + + with patch.object(client._transport, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = _job_response() + await client.submit_job(audio, requested_parallel=8) + + extra_headers = _captured_extra_headers(mock_post) + # Must be parseable JSON + assert extra_headers is not None + parsed = extra_headers[PROCESSING_DATA_HEADER] + assert parsed["requested_parallel"] == 8 + + @pytest.mark.asyncio + async def test_requested_parallel_one(self): + client = _make_client() + audio = BytesIO(b"fake-audio") + + with patch.object(client._transport, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = _job_response() + await client.submit_job(audio, requested_parallel=1) + + extra_headers = _captured_extra_headers(mock_post) + assert extra_headers is not None + payload = extra_headers[PROCESSING_DATA_HEADER] + assert payload["requested_parallel"] == 1 + + @pytest.mark.asyncio + 
async def test_header_sent_with_fetch_data_config(self): + """requested_parallel works with fetch_data submissions too.""" + client = _make_client() + config = JobConfig( + type=JobType.TRANSCRIPTION, + fetch_data=MagicMock(url="https://example.com/audio.wav"), + transcription_config=TranscriptionConfig(language="en"), + ) + # Patch to_dict so fetch_data key is present + config_dict = { + "type": "transcription", + "fetch_data": {"url": "https://example.com/audio.wav"}, + "transcription_config": {"language": "en"}, + } + with patch.object(config, "to_dict", return_value=config_dict): + with patch.object(client._transport, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = _job_response() + await client.submit_job(None, config=config, requested_parallel=2) + + extra_headers = _captured_extra_headers(mock_post) + assert extra_headers is not None + payload = extra_headers[PROCESSING_DATA_HEADER] + assert payload == {"requested_parallel": 2} + + +class TestSubmitJobReturnValue: + """submit_job still returns the correct JobDetails regardless of requested_parallel.""" + + @pytest.mark.asyncio + async def test_returns_job_details_with_correct_id(self): + client = _make_client() + audio = BytesIO(b"fake-audio") + + with patch.object(client._transport, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = _job_response("abc-456") + job = await client.submit_job(audio, requested_parallel=3) + + assert job.id == "abc-456" + assert job.status == JobStatus.RUNNING + + @pytest.mark.asyncio + async def test_post_called_with_jobs_path(self): + client = _make_client() + audio = BytesIO(b"fake-audio") + + with patch.object(client._transport, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = _job_response() + await client.submit_job(audio, requested_parallel=2) + + args, _ = mock_post.call_args + assert args[0] == "/jobs"