initial
104 main.py Normal file
@@ -0,0 +1,104 @@
import io
import os
import torch
import soundfile as sf
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from typing import Optional
import uvicorn

# We try to import Dia. In a real environment, this would be installed via requirements.txt
try:
    from dia.model import Dia
except ImportError:
    # Fallback for development if the library is not installed
    Dia = None

app = FastAPI(
    title="Dia-1.6B API Server",
    description="API server for the Nari Labs Dia-1.6B TTS model. Supports realistic dialogue generation, speaker tags, and audio prompting.",
    version="1.0.0"
)

# Global model instance
model = None

@app.on_event("startup")
async def load_model():
    """Load the model on startup."""
    global model
    if Dia is None:
        print("Warning: 'dia' library not found. Model will not be loaded.")
        return

    model_id = os.getenv("MODEL_ID", "nari-labs/Dia-1.6B")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading model {model_id} on {device}...")

    try:
        model = Dia.from_pretrained(model_id)
        if device == "cuda":
            model = model.to(device)
        print("Model loaded successfully.")
    except Exception as e:
        print(f"Error loading model: {e}")

@app.post("/generate", summary="Generate audio from text")
async def generate(
    text: str = Form(..., description="The transcript text to generate audio for. Use speaker tags like [S1], [S2]."),
    audio_prompt: Optional[UploadFile] = File(None, description="Optional audio file for voice cloning/conditioning.")
):
    """
    Generate realistic dialogue audio from a transcript.
    Supports speaker tags [S1], [S2], and non-verbal cues like (laughs).
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded or 'dia' library missing.")

    prompt_path = None
    try:
        if audio_prompt:
            # Write the uploaded audio prompt to a temporary file; basename() drops any
            # directory components in the client-supplied filename.
            safe_name = os.path.basename(audio_prompt.filename or "prompt.wav")
            prompt_path = f"/tmp/prompt_{safe_name}"
            with open(prompt_path, "wb") as f:
                f.write(await audio_prompt.read())

        # Generate audio using the model.
        # According to the documentation, generate returns a numpy array.
        # Signature: model.generate(text, audio_prompt=None, ...)
        # We pass prompt_path if an audio prompt was provided.
        output = model.generate(text, audio_prompt=prompt_path)

        # Convert the numpy array to WAV bytes
        output_buffer = io.BytesIO()
        sf.write(output_buffer, output, 44100, format='WAV')
        output_buffer.seek(0)

        return StreamingResponse(
            output_buffer,
            media_type="audio/wav",
            headers={"Content-Disposition": "attachment; filename=output.wav"}
        )

    except Exception as e:
        print(f"Generation error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Clean up the prompt file even if generation failed
        if prompt_path and os.path.exists(prompt_path):
            os.remove(prompt_path)

@app.get("/health", summary="Health check")
def health():
    """Check if the server is alive and the model is loaded."""
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "device": "cuda" if torch.cuda.is_available() else "cpu"
    }

if __name__ == "__main__":
    port = int(os.getenv("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=port)
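For reference, a minimal client sketch for exercising the two endpoints above. It assumes the server is running locally on the default port 8000 and that the requests package is installed; the base URL, the example transcripts, and the voice_prompt.wav file are illustrative placeholders, not part of this commit.

import requests

BASE_URL = "http://localhost:8000"  # assumed local deployment

# Confirm the server is up and the model finished loading
print(requests.get(f"{BASE_URL}/health").json())

# Plain dialogue generation: the transcript goes in the 'text' form field
resp = requests.post(
    f"{BASE_URL}/generate",
    data={"text": "[S1] Hello there! [S2] Hi! (laughs)"},
)
resp.raise_for_status()
with open("output.wav", "wb") as f:
    f.write(resp.content)

# Voice cloning/conditioning: attach an audio prompt as a multipart file upload
with open("voice_prompt.wav", "rb") as prompt:  # placeholder prompt file
    resp = requests.post(
        f"{BASE_URL}/generate",
        data={"text": "[S1] This line should follow the prompt's voice."},
        files={"audio_prompt": ("voice_prompt.wav", prompt, "audio/wav")},
    )
resp.raise_for_status()
with open("cloned.wav", "wb") as f:
    f.write(resp.content)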