From 716097630fe0b966a0c28d2ba227be3149af3a77 Mon Sep 17 00:00:00 2001 From: Georgios Hadjiharalambous Date: Thu, 12 Mar 2026 16:27:21 +0000 Subject: [PATCH 1/6] add support for requesting parallel engines in a http batch job --- sdk/batch/speechmatics/batch/_async_client.py | 13 +- sdk/batch/speechmatics/batch/_transport.py | 14 +- tests/batch/test_submit_job.py | 152 ++++++++++++++++++ 3 files changed, 175 insertions(+), 4 deletions(-) create mode 100644 tests/batch/test_submit_job.py diff --git a/sdk/batch/speechmatics/batch/_async_client.py b/sdk/batch/speechmatics/batch/_async_client.py index e2863530..5c4622e4 100644 --- a/sdk/batch/speechmatics/batch/_async_client.py +++ b/sdk/batch/speechmatics/batch/_async_client.py @@ -139,6 +139,7 @@ async def submit_job( *, config: Optional[JobConfig] = None, transcription_config: Optional[TranscriptionConfig] = None, + requested_parallel: Optional[int] = None, ) -> JobDetails: """ Submit a new transcription job. @@ -154,6 +155,9 @@ async def submit_job( to build a basic job configuration. transcription_config: Transcription-specific configuration. Used if config is not provided. + requested_parallel: Optional number of parallel engines to request for this job. + Sent as ``{"requested_parallel": N}`` in the ``X-SM-Processing-Data`` header. + This only applies when using the container onPrem on http batch mode. Returns: JobDetails object containing the job ID and initial status. 
@@ -200,7 +204,7 @@ async def submit_job( assert audio_file is not None # for type checker; validated above multipart_data, filename = await self._prepare_file_submission(audio_file, config_dict) - return await self._submit_and_create_job_details(multipart_data, filename, config) + return await self._submit_and_create_job_details(multipart_data, filename, config, requested_parallel) except Exception as e: if isinstance(e, (AuthenticationError, BatchError)): raise @@ -528,10 +532,13 @@ async def _prepare_file_submission(self, audio_file: Union[str, BinaryIO], confi return multipart_data, filename async def _submit_and_create_job_details( - self, multipart_data: dict, filename: str, config: JobConfig + self, multipart_data: dict, filename: str, config: JobConfig, requested_parallel: Optional[int] = None ) -> JobDetails: """Submit job and create JobDetails response.""" - response = await self._transport.post("/jobs", multipart_data=multipart_data) + extra_headers: Optional[dict[str, dict[str, int]]] = None + if requested_parallel is not None: + extra_headers = {"X-SM-Processing-Data": {"requested_parallel": requested_parallel}} + response = await self._transport.post("/jobs", multipart_data=multipart_data, extra_headers=extra_headers) job_id = response.get("id") if not job_id: raise BatchError("No job ID returned from server") diff --git a/sdk/batch/speechmatics/batch/_transport.py b/sdk/batch/speechmatics/batch/_transport.py index bab82750..87422d86 100644 --- a/sdk/batch/speechmatics/batch/_transport.py +++ b/sdk/batch/speechmatics/batch/_transport.py @@ -116,6 +116,7 @@ async def post( json_data: Optional[dict[str, Any]] = None, multipart_data: Optional[dict[str, Any]] = None, timeout: Optional[float] = None, + extra_headers: Optional[dict[str, dict[str, int]]] = None, ) -> dict[str, Any]: """ Send POST request to the API. 
@@ -125,6 +126,7 @@ async def post( json_data: Optional JSON data for request body multipart_data: Optional multipart form data timeout: Optional request timeout + extra_headers: Optional additional headers to include in the request Returns: JSON response as dictionary @@ -133,7 +135,14 @@ async def post( AuthenticationError: If authentication fails TransportError: If request fails """ - return await self._request("POST", path, json_data=json_data, multipart_data=multipart_data, timeout=timeout) + return await self._request( + "POST", + path, + json_data=json_data, + multipart_data=multipart_data, + timeout=timeout, + extra_headers=extra_headers, + ) async def delete(self, path: str, timeout: Optional[float] = None) -> dict[str, Any]: """ @@ -200,6 +209,7 @@ async def _request( json_data: Optional[dict[str, Any]] = None, multipart_data: Optional[dict[str, Any]] = None, timeout: Optional[float] = None, + extra_headers: Optional[dict[str, dict[str, int]]] = None, ) -> dict[str, Any]: """ Send HTTP request to the API. 
@@ -227,6 +237,8 @@ async def _request( url = f"{self._url.rstrip('/')}{path}" headers = await self._prepare_headers() + if extra_headers: + headers.update(extra_headers) self._logger.debug( "Sending HTTP request %s %s (json=%s, multipart=%s)", diff --git a/tests/batch/test_submit_job.py b/tests/batch/test_submit_job.py new file mode 100644 index 00000000..a23318bf --- /dev/null +++ b/tests/batch/test_submit_job.py @@ -0,0 +1,152 @@ +"""Unit tests for AsyncClient.submit_job, focusing on the requested_parallel feature.""" + +import json +from io import BytesIO +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest + +from speechmatics.batch import AsyncClient +from speechmatics.batch import JobConfig +from speechmatics.batch import JobStatus +from speechmatics.batch import JobType +from speechmatics.batch import TranscriptionConfig + + +def _make_client(api_key: str = "test-key") -> AsyncClient: + return AsyncClient(api_key=api_key) + + +def _job_response(job_id: str = "job-123") -> dict: + return {"id": job_id, "created_at": "2024-01-01T00:00:00Z"} + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _captured_extra_headers(mock_post: AsyncMock) -> dict | None: + """Return the extra_headers kwarg from the first call to transport.post.""" + _, kwargs = mock_post.call_args + return kwargs.get("extra_headers") + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestRequestedParallelHeader: + """X-SM-Processing-Data header is set correctly based on requested_parallel.""" + + @pytest.mark.asyncio + async def test_header_sent_when_requested_parallel_provided(self): + client = _make_client() + audio = BytesIO(b"fake-audio") + + with 
patch.object(client._transport, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = _job_response() + await client.submit_job(audio, requested_parallel=4) + + extra_headers = _captured_extra_headers(mock_post) + assert extra_headers is not None + assert "X-SM-Processing-Data" in extra_headers + payload = extra_headers["X-SM-Processing-Data"] + assert payload == {"requested_parallel": 4} + + @pytest.mark.asyncio + async def test_header_not_sent_when_requested_parallel_is_none(self): + client = _make_client() + audio = BytesIO(b"fake-audio") + + with patch.object(client._transport, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = _job_response() + await client.submit_job(audio) + + extra_headers = _captured_extra_headers(mock_post) + assert extra_headers is None + + @pytest.mark.asyncio + async def test_header_value_is_valid_json(self): + client = _make_client() + audio = BytesIO(b"fake-audio") + + with patch.object(client._transport, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = _job_response() + await client.submit_job(audio, requested_parallel=8) + + extra_headers = _captured_extra_headers(mock_post) + # Must be parseable JSON + assert extra_headers is not None + parsed = extra_headers["X-SM-Processing-Data"] + assert parsed["requested_parallel"] == 8 + + @pytest.mark.asyncio + async def test_requested_parallel_one(self): + client = _make_client() + audio = BytesIO(b"fake-audio") + + with patch.object(client._transport, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = _job_response() + await client.submit_job(audio, requested_parallel=1) + + extra_headers = _captured_extra_headers(mock_post) + assert extra_headers is not None + payload = extra_headers["X-SM-Processing-Data"] + assert payload["requested_parallel"] == 1 + + @pytest.mark.asyncio + async def test_header_sent_with_fetch_data_config(self): + """requested_parallel works with fetch_data submissions too.""" + 
client = _make_client() + config = JobConfig( + type=JobType.TRANSCRIPTION, + fetch_data=MagicMock(url="https://example.com/audio.wav"), + transcription_config=TranscriptionConfig(language="en"), + ) + # Patch to_dict so fetch_data key is present + config_dict = { + "type": "transcription", + "fetch_data": {"url": "https://example.com/audio.wav"}, + "transcription_config": {"language": "en"}, + } + with patch.object(config, "to_dict", return_value=config_dict): + with patch.object(client._transport, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = _job_response() + await client.submit_job(None, config=config, requested_parallel=2) + + extra_headers = _captured_extra_headers(mock_post) + assert extra_headers is not None + payload = extra_headers["X-SM-Processing-Data"] + assert payload == {"requested_parallel": 2} + + +class TestSubmitJobReturnValue: + """submit_job still returns the correct JobDetails regardless of requested_parallel.""" + + @pytest.mark.asyncio + async def test_returns_job_details_with_correct_id(self): + client = _make_client() + audio = BytesIO(b"fake-audio") + + with patch.object(client._transport, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = _job_response("abc-456") + job = await client.submit_job(audio, requested_parallel=3) + + assert job.id == "abc-456" + assert job.status == JobStatus.RUNNING + + @pytest.mark.asyncio + async def test_post_called_with_jobs_path(self): + client = _make_client() + audio = BytesIO(b"fake-audio") + + with patch.object(client._transport, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = _job_response() + await client.submit_job(audio, requested_parallel=2) + + args, _ = mock_post.call_args + assert args[0] == "/jobs" From f8d80db0785243de4324c742cc8ec7c0ea3533d6 Mon Sep 17 00:00:00 2001 From: Georgios Hadjiharalambous Date: Thu, 12 Mar 2026 16:55:44 +0000 Subject: [PATCH 2/6] make header name const --- 
sdk/batch/speechmatics/batch/__init__.py | 2 ++ sdk/batch/speechmatics/batch/_async_client.py | 3 ++- sdk/batch/speechmatics/batch/_transport.py | 2 ++ tests/batch/test_submit_job.py | 11 ++++++----- 4 files changed, 12 insertions(+), 6 deletions(-) diff --git a/sdk/batch/speechmatics/batch/__init__.py b/sdk/batch/speechmatics/batch/__init__.py index cf57d4f0..1e8e5165 100644 --- a/sdk/batch/speechmatics/batch/__init__.py +++ b/sdk/batch/speechmatics/batch/__init__.py @@ -32,9 +32,11 @@ from ._models import TranscriptFilteringConfig from ._models import TranscriptionConfig from ._models import TranslationConfig +from ._transport import PROCESSING_DATA_HEADER __all__ = [ "AsyncClient", + "PROCESSING_DATA_HEADER", "AuthBase", "AuthenticationError", "AutoChaptersConfig", diff --git a/sdk/batch/speechmatics/batch/_async_client.py b/sdk/batch/speechmatics/batch/_async_client.py index 5c4622e4..b8fbe646 100644 --- a/sdk/batch/speechmatics/batch/_async_client.py +++ b/sdk/batch/speechmatics/batch/_async_client.py @@ -31,6 +31,7 @@ from ._models import JobType from ._models import Transcript from ._models import TranscriptionConfig +from ._transport import PROCESSING_DATA_HEADER from ._transport import Transport @@ -537,7 +538,7 @@ async def _submit_and_create_job_details( """Submit job and create JobDetails response.""" extra_headers: Optional[dict[str, dict[str, int]]] = None if requested_parallel is not None: - extra_headers = {"X-SM-Processing-Data": {"requested_parallel": requested_parallel}} + extra_headers = {PROCESSING_DATA_HEADER: {"requested_parallel": requested_parallel}} response = await self._transport.post("/jobs", multipart_data=multipart_data, extra_headers=extra_headers) job_id = response.get("id") if not job_id: diff --git a/sdk/batch/speechmatics/batch/_transport.py b/sdk/batch/speechmatics/batch/_transport.py index 87422d86..149ae29c 100644 --- a/sdk/batch/speechmatics/batch/_transport.py +++ b/sdk/batch/speechmatics/batch/_transport.py @@ -25,6 +25,8 
@@ from ._logging import get_logger from ._models import ConnectionConfig +PROCESSING_DATA_HEADER = "X-SM-Processing-Data" + class Transport: """ diff --git a/tests/batch/test_submit_job.py b/tests/batch/test_submit_job.py index a23318bf..ccb9ff4e 100644 --- a/tests/batch/test_submit_job.py +++ b/tests/batch/test_submit_job.py @@ -13,6 +13,7 @@ from speechmatics.batch import JobStatus from speechmatics.batch import JobType from speechmatics.batch import TranscriptionConfig +from speechmatics.batch import PROCESSING_DATA_HEADER def _make_client(api_key: str = "test-key") -> AsyncClient: @@ -53,8 +54,8 @@ async def test_header_sent_when_requested_parallel_provided(self): extra_headers = _captured_extra_headers(mock_post) assert extra_headers is not None - assert "X-SM-Processing-Data" in extra_headers - payload = extra_headers["X-SM-Processing-Data"] + assert PROCESSING_DATA_HEADER in extra_headers + payload = extra_headers[PROCESSING_DATA_HEADER] assert payload == {"requested_parallel": 4} @pytest.mark.asyncio @@ -81,7 +82,7 @@ async def test_header_value_is_valid_json(self): extra_headers = _captured_extra_headers(mock_post) # Must be parseable JSON assert extra_headers is not None - parsed = extra_headers["X-SM-Processing-Data"] + parsed = extra_headers[PROCESSING_DATA_HEADER] assert parsed["requested_parallel"] == 8 @pytest.mark.asyncio @@ -95,7 +96,7 @@ async def test_requested_parallel_one(self): extra_headers = _captured_extra_headers(mock_post) assert extra_headers is not None - payload = extra_headers["X-SM-Processing-Data"] + payload = extra_headers[PROCESSING_DATA_HEADER] assert payload["requested_parallel"] == 1 @pytest.mark.asyncio @@ -120,7 +121,7 @@ async def test_header_sent_with_fetch_data_config(self): extra_headers = _captured_extra_headers(mock_post) assert extra_headers is not None - payload = extra_headers["X-SM-Processing-Data"] + payload = extra_headers[PROCESSING_DATA_HEADER] assert payload == {"requested_parallel": 2} From 
bb8b221805bcdcd44c8e805e390e1022ee4a814c Mon Sep 17 00:00:00 2001 From: Georgios Hadjiharalambous Date: Thu, 12 Mar 2026 16:56:51 +0000 Subject: [PATCH 3/6] fix test helper TypeError on Python 3.9 --- tests/batch/test_submit_job.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/batch/test_submit_job.py b/tests/batch/test_submit_job.py index ccb9ff4e..47f835c0 100644 --- a/tests/batch/test_submit_job.py +++ b/tests/batch/test_submit_job.py @@ -6,6 +6,8 @@ from unittest.mock import MagicMock from unittest.mock import patch +from typing import Optional + import pytest from speechmatics.batch import AsyncClient @@ -29,7 +31,7 @@ def _job_response(job_id: str = "job-123") -> dict: # --------------------------------------------------------------------------- -def _captured_extra_headers(mock_post: AsyncMock) -> dict | None: +def _captured_extra_headers(mock_post: AsyncMock) -> Optional[dict]: """Return the extra_headers kwarg from the first call to transport.post.""" _, kwargs = mock_post.call_args return kwargs.get("extra_headers") From 15266cee7a751ffc63b1a5b95dfb369f891c14e9 Mon Sep 17 00:00:00 2001 From: Georgios Hadjiharalambous Date: Thu, 12 Mar 2026 17:23:21 +0000 Subject: [PATCH 4/6] add flag on transcribe function --- sdk/batch/speechmatics/batch/_async_client.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sdk/batch/speechmatics/batch/_async_client.py b/sdk/batch/speechmatics/batch/_async_client.py index b8fbe646..1399bff5 100644 --- a/sdk/batch/speechmatics/batch/_async_client.py +++ b/sdk/batch/speechmatics/batch/_async_client.py @@ -440,6 +440,7 @@ async def transcribe( transcription_config: Optional[TranscriptionConfig] = None, polling_interval: float = 5.0, timeout: Optional[float] = None, + requested_parallel: Optional[int] = None, ) -> Union[Transcript, str]: """ Complete transcription workflow: submit job and wait for completion. 
@@ -453,6 +454,9 @@ async def transcribe( transcription_config: Transcription-specific configuration. polling_interval: Time in seconds between status checks. timeout: Maximum time in seconds to wait for completion. + requested_parallel: Optional number of parallel engines to request for this job. + Sent as ``{"requested_parallel": N}`` in the ``X-SM-Processing-Data`` header. + This only applies when using the container onPrem on http batch mode. Returns: Transcript object containing the transcript and metadata. @@ -480,6 +484,7 @@ async def transcribe( audio_file, config=config, transcription_config=transcription_config, + requested_parallel=requested_parallel, ) # Wait for completion and return result From f36d1677fbeec0d5bfd23d4438b1949a153a6276 Mon Sep 17 00:00:00 2001 From: Georgios Hadjiharalambous Date: Thu, 12 Mar 2026 17:52:55 +0000 Subject: [PATCH 5/6] proper working, after json serialization issue --- sdk/batch/speechmatics/batch/_async_client.py | 2 +- sdk/batch/speechmatics/batch/_transport.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/sdk/batch/speechmatics/batch/_async_client.py b/sdk/batch/speechmatics/batch/_async_client.py index 1399bff5..a2ded430 100644 --- a/sdk/batch/speechmatics/batch/_async_client.py +++ b/sdk/batch/speechmatics/batch/_async_client.py @@ -541,7 +541,7 @@ async def _submit_and_create_job_details( self, multipart_data: dict, filename: str, config: JobConfig, requested_parallel: Optional[int] = None ) -> JobDetails: """Submit job and create JobDetails response.""" - extra_headers: Optional[dict[str, dict[str, int]]] = None + extra_headers: Optional[dict[str, Any]] = None if requested_parallel is not None: extra_headers = {PROCESSING_DATA_HEADER: {"requested_parallel": requested_parallel}} response = await self._transport.post("/jobs", multipart_data=multipart_data, extra_headers=extra_headers) diff --git a/sdk/batch/speechmatics/batch/_transport.py b/sdk/batch/speechmatics/batch/_transport.py 
index 149ae29c..4beb3d69 100644 --- a/sdk/batch/speechmatics/batch/_transport.py +++ b/sdk/batch/speechmatics/batch/_transport.py @@ -10,6 +10,7 @@ import asyncio import io +import json as _json import sys import uuid from typing import Any @@ -118,7 +119,7 @@ async def post( json_data: Optional[dict[str, Any]] = None, multipart_data: Optional[dict[str, Any]] = None, timeout: Optional[float] = None, - extra_headers: Optional[dict[str, dict[str, int]]] = None, + extra_headers: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: """ Send POST request to the API. @@ -211,7 +212,7 @@ async def _request( json_data: Optional[dict[str, Any]] = None, multipart_data: Optional[dict[str, Any]] = None, timeout: Optional[float] = None, - extra_headers: Optional[dict[str, dict[str, int]]] = None, + extra_headers: Optional[dict[str, Any]] = None, ) -> dict[str, Any]: """ Send HTTP request to the API. @@ -240,7 +241,8 @@ async def _request( url = f"{self._url.rstrip('/')}{path}" headers = await self._prepare_headers() if extra_headers: - headers.update(extra_headers) + for k, v in extra_headers.items(): + headers[k] = _json.dumps(v) if isinstance(v, dict) else v self._logger.debug( "Sending HTTP request %s %s (json=%s, multipart=%s)", From 664ab1d4a7feb8063177887e2610f1e7da9df99c Mon Sep 17 00:00:00 2001 From: Georgios Hadjiharalambous Date: Fri, 13 Mar 2026 18:08:16 +0000 Subject: [PATCH 6/6] add user_id as input to header too --- sdk/batch/speechmatics/batch/_async_client.py | 25 +++++- tests/batch/test_submit_job.py | 86 ++++++++++++++++++- 2 files changed, 107 insertions(+), 4 deletions(-) diff --git a/sdk/batch/speechmatics/batch/_async_client.py b/sdk/batch/speechmatics/batch/_async_client.py index a2ded430..1192c21f 100644 --- a/sdk/batch/speechmatics/batch/_async_client.py +++ b/sdk/batch/speechmatics/batch/_async_client.py @@ -141,6 +141,7 @@ async def submit_job( config: Optional[JobConfig] = None, transcription_config: Optional[TranscriptionConfig] = None, 
requested_parallel: Optional[int] = None, + user_id: Optional[str] = None, ) -> JobDetails: """ Submit a new transcription job. @@ -159,6 +160,9 @@ async def submit_job( requested_parallel: Optional number of parallel engines to request for this job. Sent as ``{"requested_parallel": N}`` in the ``X-SM-Processing-Data`` header. This only applies when using the container onPrem on http batch mode. + user_id: Optional user identifier to associate with this job. + Sent as ``{"user_id": "..."}`` in the ``X-SM-Processing-Data`` header. + This only applies when using the container onPrem on http batch mode. Returns: JobDetails object containing the job ID and initial status. @@ -205,7 +209,7 @@ async def submit_job( assert audio_file is not None # for type checker; validated above multipart_data, filename = await self._prepare_file_submission(audio_file, config_dict) - return await self._submit_and_create_job_details(multipart_data, filename, config, requested_parallel) + return await self._submit_and_create_job_details(multipart_data, filename, config, requested_parallel, user_id) except Exception as e: if isinstance(e, (AuthenticationError, BatchError)): raise @@ -441,6 +445,7 @@ async def transcribe( polling_interval: float = 5.0, timeout: Optional[float] = None, requested_parallel: Optional[int] = None, + user_id: Optional[str] = None, ) -> Union[Transcript, str]: """ Complete transcription workflow: submit job and wait for completion. @@ -457,6 +462,9 @@ async def transcribe( requested_parallel: Optional number of parallel engines to request for this job. Sent as ``{"requested_parallel": N}`` in the ``X-SM-Processing-Data`` header. This only applies when using the container onPrem on http batch mode. + user_id: Optional user identifier to associate with this job. + Sent as ``{"user_id": "..."}`` in the ``X-SM-Processing-Data`` header. + This only applies when using the container onPrem on http batch mode. 
Returns: Transcript object containing the transcript and metadata. @@ -485,6 +493,7 @@ async def transcribe( config=config, transcription_config=transcription_config, requested_parallel=requested_parallel, + user_id=user_id, ) # Wait for completion and return result @@ -538,12 +547,22 @@ async def _prepare_file_submission(self, audio_file: Union[str, BinaryIO], confi return multipart_data, filename async def _submit_and_create_job_details( - self, multipart_data: dict, filename: str, config: JobConfig, requested_parallel: Optional[int] = None + self, + multipart_data: dict, + filename: str, + config: JobConfig, + requested_parallel: Optional[int] = None, + user_id: Optional[str] = None, ) -> JobDetails: """Submit job and create JobDetails response.""" extra_headers: Optional[dict[str, Any]] = None + processing_data: dict[str, Any] = {} if requested_parallel is not None: - extra_headers = {PROCESSING_DATA_HEADER: {"requested_parallel": requested_parallel}} + processing_data["requested_parallel"] = requested_parallel + if user_id is not None: + processing_data["user_id"] = user_id + if processing_data: + extra_headers = {PROCESSING_DATA_HEADER: processing_data} response = await self._transport.post("/jobs", multipart_data=multipart_data, extra_headers=extra_headers) job_id = response.get("id") if not job_id: diff --git a/tests/batch/test_submit_job.py b/tests/batch/test_submit_job.py index 47f835c0..bf6ead46 100644 --- a/tests/batch/test_submit_job.py +++ b/tests/batch/test_submit_job.py @@ -1,4 +1,4 @@ -"""Unit tests for AsyncClient.submit_job, focusing on the requested_parallel feature.""" +"""Unit tests for AsyncClient.submit_job, focusing on the requested_parallel and user_id features.""" import json from io import BytesIO @@ -127,6 +127,90 @@ async def test_header_sent_with_fetch_data_config(self): assert payload == {"requested_parallel": 2} +class TestUserIdHeader: + """X-SM-Processing-Data header is set correctly based on user_id.""" + + @pytest.mark.asyncio + 
async def test_header_sent_when_user_id_provided(self): + client = _make_client() + audio = BytesIO(b"fake-audio") + + with patch.object(client._transport, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = _job_response() + await client.submit_job(audio, user_id="user-abc") + + extra_headers = _captured_extra_headers(mock_post) + assert extra_headers is not None + assert PROCESSING_DATA_HEADER in extra_headers + payload = extra_headers[PROCESSING_DATA_HEADER] + assert payload == {"user_id": "user-abc"} + + @pytest.mark.asyncio + async def test_header_not_sent_when_user_id_is_none(self): + client = _make_client() + audio = BytesIO(b"fake-audio") + + with patch.object(client._transport, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = _job_response() + await client.submit_job(audio) + + extra_headers = _captured_extra_headers(mock_post) + assert extra_headers is None + + @pytest.mark.asyncio + async def test_user_id_and_requested_parallel_sent_together(self): + client = _make_client() + audio = BytesIO(b"fake-audio") + + with patch.object(client._transport, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = _job_response() + await client.submit_job(audio, requested_parallel=4, user_id="user-xyz") + + extra_headers = _captured_extra_headers(mock_post) + assert extra_headers is not None + payload = extra_headers[PROCESSING_DATA_HEADER] + assert payload == {"requested_parallel": 4, "user_id": "user-xyz"} + + @pytest.mark.asyncio + async def test_user_id_does_not_appear_when_only_requested_parallel_set(self): + client = _make_client() + audio = BytesIO(b"fake-audio") + + with patch.object(client._transport, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = _job_response() + await client.submit_job(audio, requested_parallel=2) + + payload = _captured_extra_headers(mock_post)[PROCESSING_DATA_HEADER] + assert "user_id" not in payload + + @pytest.mark.asyncio + async def 
test_requested_parallel_does_not_appear_when_only_user_id_set(self): + client = _make_client() + audio = BytesIO(b"fake-audio") + + with patch.object(client._transport, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = _job_response() + await client.submit_job(audio, user_id="u1") + + payload = _captured_extra_headers(mock_post)[PROCESSING_DATA_HEADER] + assert "requested_parallel" not in payload + + @pytest.mark.asyncio + async def test_user_id_forwarded_from_transcribe(self): + client = _make_client() + audio = BytesIO(b"fake-audio") + + with patch.object(client._transport, "post", new_callable=AsyncMock) as mock_post: + mock_post.return_value = _job_response() + with patch.object(client, "wait_for_completion", new_callable=AsyncMock) as mock_wait: + mock_wait.return_value = MagicMock() + await client.transcribe(audio, user_id="transcribe-user") + + extra_headers = _captured_extra_headers(mock_post) + assert extra_headers is not None + assert extra_headers[PROCESSING_DATA_HEADER]["user_id"] == "transcribe-user" + + class TestSubmitJobReturnValue: """submit_job still returns the correct JobDetails regardless of requested_parallel."""