This commit is contained in:
matst80
2026-03-04 22:21:47 +01:00
commit a86357c190
6 changed files with 312 additions and 0 deletions

104
main.py Normal file
View File

@@ -0,0 +1,104 @@
import io
import os
import tempfile
from typing import Optional

import soundfile as sf
import torch
import uvicorn
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
# Optional dependency: 'dia' is installed via requirements.txt in a real
# deployment; importing it lazily lets this module load without it.
try:
    from dia.model import Dia
except ImportError:
    # Sentinel: Dia = None makes load_model() skip loading and makes
    # /generate return 503 (model stays None).
    Dia = None
app = FastAPI(
    title="Dia-1.6B API Server",
    description="API server for Nari Labs Dia-1.6B TTS model. Supports realistic dialogue generation, speaker tags, and audio prompting.",
    version="1.0.0"
)
# Global model instance. Populated by load_model() at startup; remains None
# if the 'dia' library is missing or loading fails, which /generate reports
# as HTTP 503 and /health exposes via "model_loaded".
model = None
@app.on_event("startup")
async def load_model():
    """Load the Dia TTS model once when the application boots.

    The model id comes from the MODEL_ID environment variable (defaulting
    to "nari-labs/Dia-1.6B") and the model is moved to CUDA when a GPU is
    available. Any failure is logged and leaves the module-level `model`
    unset so the endpoints can report the degraded state.
    """
    global model
    if Dia is None:
        print("Warning: 'dia' library not found. Model will not be loaded.")
        return
    model_id = os.getenv("MODEL_ID", "nari-labs/Dia-1.6B")
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    print(f"Loading model {model_id} on {device}...")
    try:
        model = Dia.from_pretrained(model_id)
        if device == "cuda":
            model = model.to(device)
        print("Model loaded successfully.")
    except Exception as e:
        # Best-effort: keep the server up; /health exposes model_loaded=False.
        print(f"Error loading model: {e}")
@app.post("/generate", summary="Generate audio from text")
async def generate(
    text: str = Form(..., description="The transcript text to generate audio for. Use speaker tags like [S1], [S2]."),
    audio_prompt: Optional[UploadFile] = File(None, description="Optional audio file for voice cloning/conditioning.")
):
    """
    Generate realistic dialogue audio from a transcript.
    Supports speaker tags [S1], [S2], and non-verbal cues like (laughs).

    Returns a streamed WAV file (44.1 kHz). Raises HTTP 503 when the model
    is not loaded and HTTP 500 on any generation failure.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded or 'dia' library missing.")
    prompt_path = None
    try:
        if audio_prompt:
            # Write the upload to a secure, uniquely-named temp file.
            # NOTE: the previous code interpolated the client-supplied
            # filename into /tmp, which allowed path traversal and
            # collisions between concurrent requests.
            suffix = os.path.splitext(audio_prompt.filename or "")[1]
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
                tmp.write(await audio_prompt.read())
                prompt_path = tmp.name
        # Generate audio using the model; per the Dia docs, generate()
        # returns a numpy array of samples. prompt_path conditions the voice.
        output = model.generate(text, audio_prompt=prompt_path)
        # Serialize the samples to WAV bytes in memory.
        output_buffer = io.BytesIO()
        sf.write(output_buffer, output, 44100, format='WAV')
        output_buffer.seek(0)
        return StreamingResponse(
            output_buffer,
            media_type="audio/wav",
            headers={"Content-Disposition": "attachment; filename=output.wav"}
        )
    except Exception as e:
        print(f"Generation error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Always remove the temp prompt file, even when generation raised —
        # the old version leaked it on any exception.
        if prompt_path and os.path.exists(prompt_path):
            os.remove(prompt_path)
@app.get("/health", summary="Health check")
def health():
    """Report liveness, whether the model is loaded, and the active device."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    is_loaded = model is not None
    return {
        "status": "healthy",
        "model_loaded": is_loaded,
        "device": device
    }
if __name__ == "__main__":
    # Listen on all interfaces; port comes from $PORT (default 8000).
    listen_port = int(os.getenv("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)