105 lines
3.4 KiB
Python
105 lines
3.4 KiB
Python
import io
import os
import tempfile
from typing import Optional

import soundfile as sf
import torch
import uvicorn
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
|
|
|
|
# We try to import Dia. In a real environment, this would be installed via requirements.txt
try:
    from dia.model import Dia
except ImportError:
    # Fallback for development if not installed: Dia stays None so the
    # server can still start; the startup handler and endpoints check for
    # this and degrade gracefully (503 on /generate, warning on startup).
    Dia = None
|
|
|
|
# FastAPI application object; endpoint handlers below are registered on it.
app = FastAPI(
    title="Dia-1.6B API Server",
    description="API server for Nari Labs Dia-1.6B TTS model. Supports realistic dialogue generation, speaker tags, and audio prompting.",
    version="1.0.0"
)

# Global model instance. Stays None until the startup handler loads it
# successfully; /generate returns 503 and /health reports False while None.
model = None
|
|
|
|
@app.on_event("startup")
async def load_model():
    """Fetch and initialize the Dia TTS model when the server boots.

    Reads the checkpoint id from the MODEL_ID env var (default
    nari-labs/Dia-1.6B), picks CUDA when available, and stores the loaded
    model in the module-level ``model`` global. Failures are logged but do
    not crash the server; ``model`` simply stays None.
    """
    global model

    # Without the dia package there is nothing to load; keep serving so
    # /health can still answer and report the missing model.
    if Dia is None:
        print("Warning: 'dia' library not found. Model will not be loaded.")
        return

    repo_id = os.getenv("MODEL_ID", "nari-labs/Dia-1.6B")
    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading model {repo_id} on {target_device}...")

    try:
        model = Dia.from_pretrained(repo_id)
        # Only move the weights when a GPU is present; CPU is the default
        # placement after from_pretrained.
        if target_device == "cuda":
            model = model.to(target_device)
        print("Model loaded successfully.")
    except Exception as e:
        # Best-effort startup: log and continue with model == None.
        print(f"Error loading model: {e}")
|
|
|
|
@app.post("/generate", summary="Generate audio from text")
async def generate(
    text: str = Form(..., description="The transcript text to generate audio for. Use speaker tags like [S1], [S2]."),
    audio_prompt: Optional[UploadFile] = File(None, description="Optional audio file for voice cloning/conditioning.")
):
    """
    Generate realistic dialogue audio from a transcript.

    Supports speaker tags [S1], [S2], and non-verbal cues like (laughs).

    Returns a streaming WAV download. Raises HTTP 503 when the model is
    not loaded and HTTP 500 on any generation failure.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded or 'dia' library missing.")

    prompt_path = None
    try:
        if audio_prompt:
            # Persist the upload to a secure temporary file. The original
            # code interpolated the client-supplied filename into /tmp,
            # which allows path traversal (e.g. "../../etc/x") and collides
            # across concurrent requests; NamedTemporaryFile avoids both.
            # Keep the extension so downstream format sniffing still works.
            suffix = os.path.splitext(audio_prompt.filename or "")[1]
            with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
                tmp.write(await audio_prompt.read())
                prompt_path = tmp.name

        # Generate audio using the model.
        # According to documentation, generate returns a numpy array.
        # Signature: model.generate(text, audio_prompt=None, ...)
        output = model.generate(text, audio_prompt=prompt_path)

        # Serialize the samples to an in-memory WAV at 44.1 kHz.
        output_buffer = io.BytesIO()
        sf.write(output_buffer, output, 44100, format='WAV')
        output_buffer.seek(0)

        return StreamingResponse(
            output_buffer,
            media_type="audio/wav",
            headers={"Content-Disposition": "attachment; filename=output.wav"}
        )
    except Exception as e:
        print(f"Generation error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Always remove the temp prompt file — the original only cleaned up
        # on success, leaking a file whenever generation raised.
        if prompt_path and os.path.exists(prompt_path):
            os.remove(prompt_path)
|
|
|
|
@app.get("/health", summary="Health check")
def health():
    """Report server liveness, model availability, and the compute device."""
    # Recompute device availability on each call rather than caching it.
    active_device = "cuda" if torch.cuda.is_available() else "cpu"
    report = {
        "status": "healthy",
        "model_loaded": model is not None,
        "device": active_device,
    }
    return report
|
|
|
|
if __name__ == "__main__":
    # Listening port is overridable via the PORT env var (default 8000).
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 8000)))
|