"""FastAPI server exposing the Nari Labs Dia-1.6B TTS model.

Endpoints:
    POST /generate  -- synthesize dialogue audio from a transcript,
                       optionally conditioned on an uploaded audio prompt.
    GET  /health    -- liveness / model-load status.
"""

import io
import os
import tempfile
from typing import Optional

import soundfile as sf
import torch
import uvicorn
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

# We try to import Dia. In a real environment, this would be installed via
# requirements.txt.
try:
    from dia.model import Dia
except ImportError:
    # Fallback for development if not installed.
    Dia = None

# Dia emits audio at a fixed 44.1 kHz sample rate.
SAMPLE_RATE = 44100

app = FastAPI(
    title="Dia-1.6B API Server",
    description="API server for Nari Labs Dia-1.6B TTS model. Supports realistic dialogue generation, speaker tags, and audio prompting.",
    version="1.0.0",
)

# Global model instance, populated by the startup hook below.
model = None


@app.on_event("startup")
async def load_model():
    """Load the model on startup.

    Failures are logged but deliberately do not crash the server:
    /health reports ``model_loaded: false`` and /generate returns 503.
    """
    global model
    if Dia is None:
        print("Warning: 'dia' library not found. Model will not be loaded.")
        return

    model_id = os.getenv("MODEL_ID", "nari-labs/Dia-1.6B")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading model {model_id} on {device}...")
    try:
        model = Dia.from_pretrained(model_id)
        if device == "cuda":
            model = model.to(device)
        print("Model loaded successfully.")
    except Exception as e:
        # Best-effort load: keep serving so /health can report the failure.
        print(f"Error loading model: {e}")


@app.post("/generate", summary="Generate audio from text")
async def generate(
    text: str = Form(..., description="The transcript text to generate audio for. Use speaker tags like [S1], [S2]."),
    audio_prompt: Optional[UploadFile] = File(None, description="Optional audio file for voice cloning/conditioning."),
):
    """
    Generate realistic dialogue audio from a transcript.
    Supports speaker tags [S1], [S2], and non-verbal cues like (laughs).
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded or 'dia' library missing.")

    prompt_path: Optional[str] = None
    try:
        if audio_prompt:
            # Write the prompt to a securely-created temp file. The original
            # interpolated the client-supplied filename into a /tmp path,
            # which is a path-traversal/collision risk on untrusted input.
            # Keep the extension so soundfile/Dia can sniff the format.
            suffix = os.path.splitext(audio_prompt.filename or "")[1] or ".wav"
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
                tmp.write(await audio_prompt.read())
                prompt_path = tmp.name

        # Generate audio using the model.
        # According to documentation, generate returns a numpy array.
        # Signature: model.generate(text, audio_prompt=None, ...)
        output = model.generate(text, audio_prompt=prompt_path)

        # Convert the numpy array to in-memory WAV bytes.
        output_buffer = io.BytesIO()
        sf.write(output_buffer, output, SAMPLE_RATE, format='WAV')
        output_buffer.seek(0)

        return StreamingResponse(
            output_buffer,
            media_type="audio/wav",
            headers={"Content-Disposition": "attachment; filename=output.wav"},
        )
    except HTTPException:
        # Don't re-wrap deliberate HTTP errors in a generic 500.
        raise
    except Exception as e:
        print(f"Generation error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Always remove the temp prompt file — the original leaked it
        # whenever generation raised before the cleanup line.
        if prompt_path and os.path.exists(prompt_path):
            os.remove(prompt_path)


@app.get("/health", summary="Health check")
def health():
    """Check if the server is alive and the model is loaded."""
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "device": "cuda" if torch.cuda.is_available() else "cpu",
    }


if __name__ == "__main__":
    port = int(os.getenv("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=port)