-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtts_server.py
More file actions
111 lines (86 loc) · 2.87 KB
/
tts_server.py
File metadata and controls
111 lines (86 loc) · 2.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""TTS API server using Chatterbox (Resemble AI)."""
import io
import tempfile
from pathlib import Path
import av
import numpy as np
import soundfile as sf
import torchaudio as ta
from fastapi import FastAPI, File, Form, UploadFile
from fastapi.responses import StreamingResponse
from config import (
TTS_CFG_WEIGHT,
TTS_DEVICE,
TTS_EXAGGERATION,
TTS_MODEL,
)
# FastAPI application object that the route decorators below attach to.
app = FastAPI(title="GPU TTS Server", version="1.0.0")
# Lazily-initialized singleton TTS model; populated on first call to get_model().
model = None
def get_model():
    """Return the shared TTS model, loading it on first use.

    Selects the Chatterbox variant from the TTS_MODEL config value and
    caches the loaded instance in the module-level ``model`` global so
    subsequent calls are free.
    """
    global model
    if model is None:
        # Import lazily so the heavyweight model code is only pulled in
        # when a model is actually needed.
        if TTS_MODEL == "chatterbox-turbo":
            from chatterbox.tts_turbo import ChatterboxTurboTTS as _ModelCls
        else:
            from chatterbox.tts import ChatterboxTTS as _ModelCls
        model = _ModelCls.from_pretrained(device=TTS_DEVICE)
    return model
@app.get("/health")
def health():
    """Liveness probe: report the configured model name and device."""
    return dict(status="ok", model=TTS_MODEL, device=TTS_DEVICE)
def _decode_voice_to_wav(content: bytes) -> str:
    """Decode an uploaded voice clip (e.g. browser WebM/Opus) to a mono WAV file.

    Writes the raw bytes to a temporary file, decodes the audio stream with
    PyAV, downmixes multi-channel audio to mono, and saves the result as a
    temporary WAV. Returns the WAV path; the caller must delete it.
    Intermediate resources (raw temp file, AV container) are cleaned up even
    when decoding fails.

    Raises:
        ValueError: if the clip contains no decodable audio frames.
    """
    raw_file = tempfile.NamedTemporaryFile(suffix=".webm", delete=False)
    try:
        raw_file.write(content)
        raw_file.close()
        frames = []
        src_rate = 48000  # browser default; overwritten by the actual frame rate
        container = av.open(raw_file.name)
        try:
            for frame in container.decode(audio=0):
                frames.append(frame.to_ndarray())
                src_rate = frame.sample_rate
        finally:
            container.close()
    finally:
        # Remove the raw upload even if av.open()/decode raised.
        Path(raw_file.name).unlink(missing_ok=True)
    if not frames:
        # Fail with a clear message instead of np.concatenate's opaque error.
        raise ValueError("voice reference contains no decodable audio")
    audio = np.concatenate(frames, axis=1)
    if audio.shape[0] > 1:
        audio = audio.mean(axis=0, keepdims=True)  # downmix to mono
    wav_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    wav_file.close()
    sf.write(wav_file.name, audio.squeeze(), src_rate)
    return wav_file.name


@app.post("/tts")
async def synthesize(
    text: str = Form(...),
    voice: UploadFile | None = File(None),
    exaggeration: float = Form(TTS_EXAGGERATION),
    cfg_weight: float = Form(TTS_CFG_WEIGHT),
):
    """Generate speech from text. Optionally provide a voice reference WAV for cloning.

    Args:
        text: the text to synthesize.
        voice: optional reference recording for voice cloning (browser mic
            uploads typically arrive as WebM/Opus and are transcoded to WAV).
        exaggeration: emotion-exaggeration knob forwarded to the model.
        cfg_weight: classifier-free-guidance weight forwarded to the model.

    Returns:
        StreamingResponse with the synthesized audio as a WAV attachment.
    """
    tts = get_model()
    voice_path: str | None = None
    try:
        if voice is not None:
            voice_path = _decode_voice_to_wav(await voice.read())
        wav = tts.generate(
            text,
            audio_prompt_path=voice_path,
            exaggeration=exaggeration,
            cfg_weight=cfg_weight,
        )
    finally:
        # Always remove the temporary voice WAV, even if generation fails
        # (the original leaked it on any exception from tts.generate).
        if voice_path is not None:
            Path(voice_path).unlink(missing_ok=True)
    buf = io.BytesIO()
    ta.save(buf, wav, tts.sr, format="wav")
    buf.seek(0)
    return StreamingResponse(buf, media_type="audio/wav", headers={
        "Content-Disposition": "attachment; filename=output.wav",
    })
if __name__ == "__main__":
    import uvicorn

    from config import HOST, PORT

    # Load the model eagerly so the first request doesn't pay the init cost.
    get_model()
    uvicorn.run(app, host=HOST, port=int(PORT))