-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbot.py
More file actions
227 lines (186 loc) · 8.25 KB
/
bot.py
File metadata and controls
227 lines (186 loc) · 8.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
import asyncio
import os
import sys
from dotenv import load_dotenv
from loguru import logger
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import BotInterruptionFrame, EndFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.serializers.protobuf import ProtobufFrameSerializer
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.transports.network.websocket_server import (
WebsocketServerParams,
WebsocketServerTransport,
)
import datetime
from processors import TextTranscriptionProcessor
from pipecat.frames.frames import TextFrame, TranscriptionFrame
from pipecat.processors.frame_processor import FrameProcessor, FrameDirection
# Custom AIResponseProcessor class for compatibility with tests
class AIResponseProcessor(FrameProcessor):
    """Mirror assistant text as ``TranscriptionFrame``s tagged ``"ai_assistant"``.

    This is used for test compatibility: ``process_frame`` returns the frames
    it produced instead of pushing them, so callers can inspect the result.
    NOTE(review): nothing here calls ``push_frame``; confirm this processor is
    only exercised directly (e.g. by tests) and never placed in a live Pipeline.
    """

    def __init__(self):
        super().__init__()

    async def process_frame(self, frame, direction=None):
        """Return ``frame`` plus, for assistant text, a matching transcription.

        Args:
            frame: The incoming pipecat frame.
            direction: Frame direction, forwarded unchanged to the base class.

        Returns:
            A list whose first element is ``frame``. If ``frame`` is a
            ``TextFrame`` that is not itself a ``TranscriptionFrame``, a second
            element is appended: a ``TranscriptionFrame`` carrying the same
            text, ``user_id="ai_assistant"`` and an ISO-8601 UTC timestamp.
        """
        # Let the base class run its own handling first.
        await super().process_frame(frame, direction)
        result = [frame]
        # TranscriptionFrame derives from TextFrame in pipecat (verify against
        # the installed version); excluding it stops user transcriptions from
        # being re-labelled as assistant speech. If it were not a subclass,
        # this extra check is simply a no-op.
        if isinstance(frame, TextFrame) and not isinstance(frame, TranscriptionFrame):
            result.append(
                TranscriptionFrame(
                    text=getattr(frame, "text", ""),
                    user_id="ai_assistant",
                    # Timezone-aware UTC so the timestamp is unambiguous.
                    timestamp=datetime.datetime.now(datetime.timezone.utc).isoformat(),
                )
            )
        return result
# Pull settings from a .env file; its values win over variables already
# present in the process environment.
load_dotenv(override=True)

# Swap loguru's default handler (id 0) for a DEBUG-level stderr sink.
logger.remove(0)
logger.add(sink=sys.stderr, level="DEBUG")
class SessionTimeoutHandler:
    """Handles actions to be performed when a session times out.

    Inputs:
    - task: Pipeline task (used to queue frames).
    - tts: TTS service (used to generate speech output).
    - end_call_delay: Seconds to wait after queuing the goodbye message
      before the session is ended (default 15).
    """

    def __init__(self, task, tts, end_call_delay=15):
        self.task = task
        self.tts = tts
        self.end_call_delay = end_call_delay
        # Strong references to in-flight background tasks: asyncio keeps only
        # weak references to tasks, so an otherwise-unreferenced task could be
        # garbage-collected before it finishes.
        self.background_tasks = set()

    async def handle_timeout(self, client_address):
        """Interrupt the bot, say goodbye, and schedule the session to end.

        Args:
            client_address: Identifier of the timed-out client; only logged.

        Any exception is logged with its traceback instead of being raised.
        """
        try:
            logger.info(f"Connection timeout for {client_address}")
            # Queue a BotInterruptionFrame to notify the user
            await self.task.queue_frames([BotInterruptionFrame()])
            # Send the TTS message to inform the user about the timeout
            await self.tts.say(
                "I'm sorry, we are ending the call now. Please feel free to reach out again if you need assistance."
            )
            # End the call in the background so this handler returns now.
            end_call_task = asyncio.create_task(self._end_call())
            self.background_tasks.add(end_call_task)
            end_call_task.add_done_callback(self.background_tasks.discard)
        except Exception as e:
            logger.exception(f"Error during session timeout handling: {e}")

    async def _end_call(self):
        """Wait ``end_call_delay`` seconds, then queue interruption + EndFrame."""
        try:
            # A fixed delay is a heuristic for the goodbye message finishing;
            # nothing here observes actual TTS completion.
            await asyncio.sleep(self.end_call_delay)
            # Queue both BotInterruptionFrame and EndFrame to conclude the session
            await self.task.queue_frames([BotInterruptionFrame(), EndFrame()])
            logger.info("TTS completed and EndFrame pushed successfully.")
        except Exception as e:
            logger.exception(f"Error during call termination: {e}")
class Bot:
    """Main bot class that sets up and runs the conversation pipeline.

    Call ``initialize()`` once to build every component, then ``await run()``.
    The ``setup_*`` methods depend on each other's results and must run in the
    order ``initialize()`` uses.
    """

    def __init__(self):
        # Every component starts unset; initialize() populates them.
        self.transport = None
        self.llm = None
        self.stt = None
        self.tts = None
        self.context = None
        self.context_aggregator = None
        self.pipeline = None
        self.task = None
        self.runner = None
        self.ai_response_processor = None
        # Seed conversation handed to OpenAILLMContext in setup_context().
        self.messages = [
            {
                "role": "system",
                "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
            },
        ]

    def setup_transport(self):
        """Create the WebSocket server transport, store it, and return it.

        Audio is sent with a WAV header, input is gated by Silero VAD, and the
        session timeout is 3 minutes.
        """
        self.transport = WebsocketServerTransport(
            params=WebsocketServerParams(
                serializer=ProtobufFrameSerializer(),
                audio_out_enabled=True,
                add_wav_header=True,
                vad_enabled=True,
                vad_analyzer=SileroVADAnalyzer(),
                vad_audio_passthrough=True,
                session_timeout=60 * 3,  # 3 minutes
            )
        )
        return self.transport

    def setup_services(self):
        """Create the LLM (OpenAI gpt-4o), STT (Deepgram) and TTS (Cartesia) services.

        API keys come from OPENAI_API_KEY, DEEPGRAM_API_KEY and
        CARTESIA_API_KEY; an unset variable is passed through as ``None``.
        """
        self.llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
        self.stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
        # TODO: evaluate whether another TTS provider (e.g. OpenAI) would be a better fit.
        self.tts = CartesiaTTSService(
            api_key=os.getenv("CARTESIA_API_KEY"),
            # "British Reading Lady".
            # Alternative: "694f9389-aac1-45b6-b726-9d9369183238" (Sarah, USA).
            voice_id="71a7ad14-091c-4e8e-a314-022ece01c121",
        )

    def setup_context(self):
        """Create the LLM context from ``self.messages`` and its aggregators.

        Requires ``setup_services()`` to have created ``self.llm``.
        """
        self.context = OpenAILLMContext(self.messages)
        self.context_aggregator = self.llm.create_context_aggregator(self.context)

    def setup_pipeline(self):
        """Assemble the processing pipeline and wrap it in a PipelineTask.

        Order: websocket in -> STT -> user context -> LLM -> TTS ->
        websocket out -> assistant context. Audio is 16 kHz in both
        directions and interruptions are allowed.
        """
        # Created for test compatibility only; it is NOT part of the pipeline below.
        self.ai_response_processor = AIResponseProcessor()
        self.pipeline = Pipeline(
            [
                self.transport.input(),  # Websocket input from client
                self.stt,  # Speech-To-Text
                self.context_aggregator.user(),  # User turns -> context
                self.llm,  # LLM
                self.tts,  # Text-To-Speech
                self.transport.output(),  # Websocket output to client
                self.context_aggregator.assistant(),  # Assistant turns -> context
            ]
        )
        self.task = PipelineTask(
            self.pipeline,
            params=PipelineParams(
                audio_in_sample_rate=16000,
                audio_out_sample_rate=16000,
                allow_interruptions=True,
            ),
        )

    def setup_event_handlers(self):
        """Register client-connected and session-timeout handlers on the transport."""

        @self.transport.event_handler("on_client_connected")
        async def on_client_connected(transport, client):
            # Kick off the conversation: add an instruction through the
            # context's own API (rather than relying on it sharing
            # self.messages) and queue a context frame so the LLM speaks first.
            # NOTE(review): this runs on every connection, so a reconnecting
            # client adds another copy of this message and inherits the prior
            # conversation; confirm that is intended.
            self.context.add_message(
                {"role": "system", "content": "Please introduce yourself to the user."}
            )
            await self.task.queue_frames([self.context_aggregator.user().get_context_frame()])

        @self.transport.event_handler("on_session_timeout")
        async def on_session_timeout(transport, client):
            logger.info(f"Entering in timeout for {client.remote_address}")
            timeout_handler = SessionTimeoutHandler(self.task, self.tts)
            # handle_timeout's parameter is a client *address* (it only logs
            # it), so pass the address rather than the connection object.
            await timeout_handler.handle_timeout(client.remote_address)

    def initialize(self):
        """Build all components in dependency order and create the runner."""
        self.setup_transport()
        self.setup_services()
        self.setup_context()
        self.setup_pipeline()
        self.setup_event_handlers()
        self.runner = PipelineRunner()

    async def run(self):
        """Run the pipeline task until it finishes; call ``initialize()`` first."""
        await self.runner.run(self.task)
async def main():
    """Entry coroutine: construct the bot, wire its components, then serve."""
    voice_bot = Bot()
    voice_bot.initialize()
    await voice_bot.run()


if __name__ == "__main__":
    # Only start the event loop when executed as a script, not on import.
    asyncio.run(main())