diff --git a/aiola/clients/stt/client.py b/aiola/clients/stt/client.py index a88c978..19acbce 100644 --- a/aiola/clients/stt/client.py +++ b/aiola/clients/stt/client.py @@ -63,13 +63,17 @@ def _build_query_and_headers( query = { "execution_id": execution_id, "flow_id": resolved_workflow_id, - "lang_code": lang_code or "en", "time_zone": time_zone or "UTC", - "keywords": json.dumps(keywords or {}), - "tasks_config": json.dumps(tasks_config or {}), "x-aiola-api-token": access_token, } + if lang_code is not None: + query["lang_code"] = lang_code + if keywords is not None: + query["keywords"] = json.dumps(keywords) + if tasks_config is not None: + query["tasks_config"] = json.dumps(tasks_config) + headers = { "Authorization": f"Bearer {access_token}", } diff --git a/aiola/types.py b/aiola/types.py index 9d8b883..0978fc3 100644 --- a/aiola/types.py +++ b/aiola/types.py @@ -3,7 +3,7 @@ import enum from collections.abc import Mapping from dataclasses import dataclass -from typing import IO, Union +from typing import IO, Any, Union from .constants import DEFAULT_AUTH_BASE_URL, DEFAULT_BASE_URL, DEFAULT_HTTP_TIMEOUT, DEFAULT_WORKFLOW_ID @@ -46,16 +46,7 @@ def __post_init__(self) -> None: class LiveEvents(str, enum.Enum): Transcript = "transcript" Translation = "translation" - SentimentAnalysis = "sentiment_analysis" - Summarization = "summarization" - TopicDetection = "topic_detection" - ContentModeration = "content_moderation" - AutoChapters = "auto_chapters" - FormFilling = "form_filling" - EntityDetection = "entity_detection" - EntityDetectionFromList = "entity_detection_from_list" - KeyPhrases = "key_phrases" - PiiRedaction = "pii_redaction" + Structured = "structured" Error = "error" Disconnect = "disconnect" Connect = "connect" @@ -115,6 +106,13 @@ def from_dict(cls, data: dict) -> TranscriptionResponse: ) +@dataclass +class StructuredResponse: + """Response from structured API.""" + + results: dict[str, Any] + + @dataclass class SessionCloseResponse: """Response from session close API.""" @@ -137,40 +135,9 @@ class TranslationPayload: dst_lang_code: str -@dataclass -class EntityDetectionFromListPayload: - entity_list: list[str] - - -@dataclass -class _EmptyPayload: - pass - - -EntityDetectionPayload = _EmptyPayload -KeyPhrasesPayload = _EmptyPayload -PiiRedactionPayload = _EmptyPayload -SentimentAnalysisPayload = _EmptyPayload -SummarizationPayload = _EmptyPayload -TopicDetectionPayload = _EmptyPayload -ContentModerationPayload = _EmptyPayload -AutoChaptersPayload = _EmptyPayload -FormFillingPayload = _EmptyPayload - - @dataclass class TasksConfig: - FORM_FILLING: FormFillingPayload | None = None TRANSLATION: TranslationPayload | None = None - ENTITY_DETECTION: EntityDetectionPayload | None = None - ENTITY_DETECTION_FROM_LIST: EntityDetectionFromListPayload | None = None - KEY_PHRASES: KeyPhrasesPayload | None = None - PII_REDACTION: PiiRedactionPayload | None = None - SENTIMENT_ANALYSIS: SentimentAnalysisPayload | None = None - SUMMARIZATION: SummarizationPayload | None = None - TOPIC_DETECTION: TopicDetectionPayload | None = None - CONTENT_MODERATION: ContentModerationPayload | None = None - AUTO_CHAPTERS: AutoChaptersPayload | None = None FileContent = Union[IO[bytes], bytes, str] diff --git a/tests/unit/stt/test_stt_client.py b/tests/unit/stt/test_stt_client.py index 8780cbc..feca5de 100644 --- a/tests/unit/stt/test_stt_client.py +++ b/tests/unit/stt/test_stt_client.py @@ -326,7 +326,7 @@ def test_stt_stream_with_empty_tasks_config(patch_dummy_socket): def test_stt_stream_with_no_tasks_config(patch_dummy_socket): - """``SttClient.stream`` handles None tasks_config properly.""" + """``SttClient.stream`` handles None tasks_config properly by not including it in URL.""" client = AiolaClient(api_key="secret-key", base_url="https://speech.example") @@ -341,15 +341,14 @@ def test_stt_stream_with_no_tasks_config(patch_dummy_socket): # Access the underlying socket to validate connection parameters sio = connection._sio - # Verify None tasks_config is serialized as empty JSON object + # Verify None tasks_config is not included in URL kwargs = sio.connect_kwargs url = kwargs["url"] parsed = urllib.parse.urlparse(url) query = urllib.parse.parse_qs(parsed.query) - tasks_config_json = query["tasks_config"][0] - parsed_tasks_config = json.loads(tasks_config_json) - assert parsed_tasks_config == {} + # tasks_config should not be present when None + assert "tasks_config" not in query def test_stt_stream_with_all_tasks_config(patch_dummy_socket):