Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sdk/batch/speechmatics/batch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,11 @@
from ._models import TranscriptFilteringConfig
from ._models import TranscriptionConfig
from ._models import TranslationConfig
from ._transport import PROCESSING_DATA_HEADER

__all__ = [
"AsyncClient",
"PROCESSING_DATA_HEADER",
"AuthBase",
"AuthenticationError",
"AutoChaptersConfig",
Expand Down
38 changes: 35 additions & 3 deletions sdk/batch/speechmatics/batch/_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from ._models import JobType
from ._models import Transcript
from ._models import TranscriptionConfig
from ._transport import PROCESSING_DATA_HEADER
from ._transport import Transport


Expand Down Expand Up @@ -139,6 +140,8 @@ async def submit_job(
*,
config: Optional[JobConfig] = None,
transcription_config: Optional[TranscriptionConfig] = None,
requested_parallel: Optional[int] = None,
user_id: Optional[str] = None,
) -> JobDetails:
"""
Submit a new transcription job.
Expand All @@ -154,6 +157,12 @@ async def submit_job(
to build a basic job configuration.
transcription_config: Transcription-specific configuration. Used if config
is not provided.
requested_parallel: Optional number of parallel engines to request for this job.
Sent as ``{"requested_parallel": N}`` in the ``X-SM-Processing-Data`` header.
This only applies when using the on-prem container in HTTP batch mode.
user_id: Optional user identifier to associate with this job.
Sent as ``{"user_id": "..."}`` in the ``X-SM-Processing-Data`` header.
This only applies when using the on-prem container in HTTP batch mode.

Returns:
JobDetails object containing the job ID and initial status.
Expand Down Expand Up @@ -200,7 +209,7 @@ async def submit_job(
assert audio_file is not None # for type checker; validated above
multipart_data, filename = await self._prepare_file_submission(audio_file, config_dict)

return await self._submit_and_create_job_details(multipart_data, filename, config)
return await self._submit_and_create_job_details(multipart_data, filename, config, requested_parallel, user_id)
except Exception as e:
if isinstance(e, (AuthenticationError, BatchError)):
raise
Expand Down Expand Up @@ -435,6 +444,8 @@ async def transcribe(
transcription_config: Optional[TranscriptionConfig] = None,
polling_interval: float = 5.0,
timeout: Optional[float] = None,
requested_parallel: Optional[int] = None,
user_id: Optional[str] = None,
) -> Union[Transcript, str]:
"""
Complete transcription workflow: submit job and wait for completion.
Expand All @@ -448,6 +459,12 @@ async def transcribe(
transcription_config: Transcription-specific configuration.
polling_interval: Time in seconds between status checks.
timeout: Maximum time in seconds to wait for completion.
requested_parallel: Optional number of parallel engines to request for this job.
Sent as ``{"requested_parallel": N}`` in the ``X-SM-Processing-Data`` header.
This only applies when using the on-prem container in HTTP batch mode.
user_id: Optional user identifier to associate with this job.
Sent as ``{"user_id": "..."}`` in the ``X-SM-Processing-Data`` header.
This only applies when using the on-prem container in HTTP batch mode.

Returns:
Transcript object containing the transcript and metadata.
Expand Down Expand Up @@ -475,6 +492,8 @@ async def transcribe(
audio_file,
config=config,
transcription_config=transcription_config,
requested_parallel=requested_parallel,
user_id=user_id,
)

# Wait for completion and return result
Expand Down Expand Up @@ -528,10 +547,23 @@ async def _prepare_file_submission(self, audio_file: Union[str, BinaryIO], confi
return multipart_data, filename

async def _submit_and_create_job_details(
self, multipart_data: dict, filename: str, config: JobConfig
self,
multipart_data: dict,
filename: str,
config: JobConfig,
requested_parallel: Optional[int] = None,
user_id: Optional[str] = None,
) -> JobDetails:
"""Submit job and create JobDetails response."""
response = await self._transport.post("/jobs", multipart_data=multipart_data)
extra_headers: Optional[dict[str, Any]] = None
processing_data: dict[str, Any] = {}
if requested_parallel is not None:
processing_data["requested_parallel"] = requested_parallel
if user_id is not None:
processing_data["user_id"] = user_id
if processing_data:
extra_headers = {PROCESSING_DATA_HEADER: processing_data}
response = await self._transport.post("/jobs", multipart_data=multipart_data, extra_headers=extra_headers)
job_id = response.get("id")
if not job_id:
raise BatchError("No job ID returned from server")
Expand Down
18 changes: 17 additions & 1 deletion sdk/batch/speechmatics/batch/_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import asyncio
import io
import json as _json
import sys
import uuid
from typing import Any
Expand All @@ -25,6 +26,8 @@
from ._logging import get_logger
from ._models import ConnectionConfig

PROCESSING_DATA_HEADER = "X-SM-Processing-Data"


class Transport:
"""
Expand Down Expand Up @@ -116,6 +119,7 @@ async def post(
json_data: Optional[dict[str, Any]] = None,
multipart_data: Optional[dict[str, Any]] = None,
timeout: Optional[float] = None,
extra_headers: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
"""
Send POST request to the API.
Expand All @@ -125,6 +129,7 @@ async def post(
json_data: Optional JSON data for request body
multipart_data: Optional multipart form data
timeout: Optional request timeout
extra_headers: Optional additional headers to include in the request

Returns:
JSON response as dictionary
Expand All @@ -133,7 +138,14 @@ async def post(
AuthenticationError: If authentication fails
TransportError: If request fails
"""
return await self._request("POST", path, json_data=json_data, multipart_data=multipart_data, timeout=timeout)
return await self._request(
"POST",
path,
json_data=json_data,
multipart_data=multipart_data,
timeout=timeout,
extra_headers=extra_headers,
)

async def delete(self, path: str, timeout: Optional[float] = None) -> dict[str, Any]:
"""
Expand Down Expand Up @@ -200,6 +212,7 @@ async def _request(
json_data: Optional[dict[str, Any]] = None,
multipart_data: Optional[dict[str, Any]] = None,
timeout: Optional[float] = None,
extra_headers: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
"""
Send HTTP request to the API.
Expand Down Expand Up @@ -227,6 +240,9 @@ async def _request(

url = f"{self._url.rstrip('/')}{path}"
headers = await self._prepare_headers()
if extra_headers:
for k, v in extra_headers.items():
headers[k] = _json.dumps(v) if isinstance(v, dict) else v

self._logger.debug(
"Sending HTTP request %s %s (json=%s, multipart=%s)",
Expand Down
Loading
Loading