This commit is contained in:
matst80
2026-03-04 22:21:47 +01:00
commit a86357c190
6 changed files with 312 additions and 0 deletions

26
.gitignore vendored Normal file
View File

@@ -0,0 +1,26 @@
# Python environment
__pycache__/
*.py[cod]
*$py.class
venv/
.venv/
env/
.env
# Data and artifacts
*.wav
*.mp3
*.log
/tmp/
# Model and other Large Files
*.bin
*.pt
*.pth
*.h5
*.onnx
# IDEs
.vscode/
.idea/
.DS_Store

31
Dockerfile Normal file
View File

@@ -0,0 +1,31 @@
# Use NVIDIA CUDA 12.6 image for compatibility
# NOTE(review): running this image with GPU access requires the NVIDIA
# Container Toolkit on the host (e.g. `docker run --gpus all ...`).
FROM nvidia/cuda:12.6.2-devel-ubuntu22.04

# Set working directory
WORKDIR /app

# Install system dependencies:
#   git                 - required to pip-install `dia` from GitHub (see requirements.txt)
#   ffmpeg, libsndfile1 - audio decode/encode backends used by soundfile
RUN apt-get update && apt-get install -y \
    python3 \
    python3-pip \
    git \
    ffmpeg \
    libsndfile1 \
    && rm -rf /var/lib/apt/lists/*

# Install Python dependencies first, so this (slow) layer is cached
# independently of application-code changes.
COPY requirements.txt .
RUN pip3 install --no-cache-dir -r requirements.txt

# Copy application files
COPY . .

# Set environment variables (both are read by main.py at startup)
ENV MODEL_ID="nari-labs/Dia-1.6B"
ENV PORT=8000

# Expose API port (matches the PORT default above)
EXPOSE 8000

# Run the FastAPI server
CMD ["python3", "main.py"]

87
README.md Normal file
View File

@@ -0,0 +1,87 @@
# Dia-1.6B API Server
API server for [nari-labs/Dia-1.6B](https://huggingface.co/nari-labs/Dia-1.6B), a 1.6 billion-parameter text-to-speech (TTS) model designed for realistic dialogue generation.
## Features
- 🗣️ **Realistic Dialogue**: Directly generates natural-sounding conversations from transcripts.
- 🎭 **Emotion and Tone**: Supports non-verbal cues like `(laughs)`, `(coughs)`, and `(clears throat)`.
- 👥 **Multi-Speaker Support**: Uses tags like `[S1]` and `[S2]` to alternate between speakers.
- 🎙️ **Audio Prompting**: Supports voice conditioning and cloning via audio prompts.
- 🚀 **FastAPI Implementation**: High-performance, documented API endpoints.
## Prerequisites
- **Python 3.9+**
- **NVIDIA GPU (Recommended)**: 10GB+ VRAM for optimal performance.
- **CUDA 12.6+** (required for GPU inference; the server falls back to CPU if no GPU is available, which is significantly slower).
## Installation
1. **Clone the repository and navigate into the folder:**
```bash
git clone <repo-url>
cd dia-api-server
```
2. **Create a virtual environment:**
```bash
python -m venv .venv
source .venv/bin/activate # On Windows: .venv\Scripts\activate
```
3. **Install dependencies:**
```bash
pip install -r requirements.txt
```
## Usage
### Running the Server
```bash
python main.py
```
The server will be available at `http://localhost:8000`.
### API Documentation
Once the server is running, you can access the interactive documentation at:
- Swagger UI: `http://localhost:8000/docs`
- Redoc: `http://localhost:8000/redoc`
### Example Endpoint: `/generate` (POST)
**Parameters:**
- `text` (Form data): The transcript including speaker tags.
- `audio_prompt` (Form file, optional): An audio file to condition the generation.
**Response:**
Returns a `StreamingResponse` as an `audio/wav` binary stream.
### Test Script
You can use `test_api.py` to verify the server:
```bash
python test_api.py
```
## Docker Deployment (Recommended)
Developing and running locally may be complicated due to CUDA requirements. Here is a sample `Dockerfile` for deployment:
```dockerfile
FROM nvidia/cuda:12.6.2-devel-ubuntu22.04
WORKDIR /app
RUN apt-get update && apt-get install -y \
python3 \
python3-pip \
git \
ffmpeg \
&& rm -rf /var/lib/apt/lists/*
COPY requirements.txt .
RUN pip3 install -r requirements.txt
COPY . .
CMD ["python3", "main.py"]
```
## License
Refer to the [nari-labs/Dia-1.6B](https://huggingface.co/nari-labs/Dia-1.6B#🪪-license) license on Hugging Face.

104
main.py Normal file
View File

@@ -0,0 +1,104 @@
import io
import os
import tempfile
from typing import Optional

import soundfile as sf
import torch
import uvicorn
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
# Import Dia lazily: in production it is installed via requirements.txt
# (git+https://github.com/nari-labs/dia.git); during local development it
# may be absent, in which case endpoints degrade gracefully (503).
try:
    from dia.model import Dia
except ImportError:
    # Fallback for development if not installed
    Dia = None

# FastAPI application; serves /generate, /health and the auto-generated
# docs at /docs and /redoc.
app = FastAPI(
    title="Dia-1.6B API Server",
    description="API server for Nari Labs Dia-1.6B TTS model. Supports realistic dialogue generation, speaker tags, and audio prompting.",
    version="1.0.0"
)

# Global model instance, populated by the startup hook. Stays None when the
# 'dia' library is missing or loading fails; /generate checks for this.
model = None
@app.on_event("startup")
async def load_model():
    """Load the Dia model once when the server starts.

    Reads the checkpoint id from the MODEL_ID env var and moves the model
    to CUDA when a GPU is visible. On failure the server keeps running
    with ``model`` left as None (``/generate`` then answers 503).
    """
    global model

    # Without the 'dia' package there is nothing to load.
    if Dia is None:
        print("Warning: 'dia' library not found. Model will not be loaded.")
        return

    model_id = os.getenv("MODEL_ID", "nari-labs/Dia-1.6B")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading model {model_id} on {device}...")

    try:
        model = Dia.from_pretrained(model_id)
        # NOTE(review): assumes Dia exposes a torch-style .to(device) —
        # confirm against the dia package API.
        if device == "cuda":
            model = model.to(device)
        print("Model loaded successfully.")
    except Exception as e:
        # Deliberately swallow: the process must stay up so /health can
        # report the failure instead of the container crash-looping.
        print(f"Error loading model: {e}")
@app.post("/generate", summary="Generate audio from text")
async def generate(
    text: str = Form(..., description="The transcript text to generate audio for. Use speaker tags like [S1], [S2]."),
    audio_prompt: Optional[UploadFile] = File(None, description="Optional audio file for voice cloning/conditioning.")
):
    """
    Generate realistic dialogue audio from a transcript.

    Supports speaker tags [S1], [S2], and non-verbal cues like (laughs).

    Args:
        text: Transcript form field, including speaker tags.
        audio_prompt: Optional uploaded audio file used to condition the voice.

    Returns:
        StreamingResponse: an ``audio/wav`` stream offered as an attachment.

    Raises:
        HTTPException: 503 when the model is not loaded, 500 on generation errors.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded or 'dia' library missing.")

    prompt_path = None
    try:
        if audio_prompt:
            # Persist the uploaded prompt to a secure temporary file.
            # SECURITY: the original interpolated the client-supplied
            # filename into "/tmp/prompt_{filename}", allowing path
            # traversal via a crafted filename; NamedTemporaryFile picks
            # the path itself and only the extension is reused.
            suffix = os.path.splitext(audio_prompt.filename or "")[1] or ".wav"
            with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
                tmp.write(await audio_prompt.read())
                prompt_path = tmp.name

        # model.generate returns a numpy array of samples (per Dia docs);
        # the prompt path is None when no conditioning audio was sent.
        output = model.generate(text, audio_prompt=prompt_path)

        # Serialize to WAV in memory. 44100 Hz is Dia's output sample rate
        # — NOTE(review): confirm against the model card.
        output_buffer = io.BytesIO()
        sf.write(output_buffer, output, 44100, format='WAV')
        output_buffer.seek(0)

        return StreamingResponse(
            output_buffer,
            media_type="audio/wav",
            headers={"Content-Disposition": "attachment; filename=output.wav"}
        )
    except Exception as e:
        print(f"Generation error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # BUGFIX: cleanup now runs on the error path too — the original
        # only removed the temp file on success, leaking it whenever
        # model.generate() or sf.write() raised.
        if prompt_path and os.path.exists(prompt_path):
            os.remove(prompt_path)
@app.get("/health", summary="Health check")
def health():
    """Check if the server is alive and the model is loaded."""
    # Report the device the server would run inference on right now.
    active_device = "cuda" if torch.cuda.is_available() else "cpu"
    report = {"status": "healthy"}
    report["model_loaded"] = model is not None
    report["device"] = active_device
    return report
if __name__ == "__main__":
    # Listen port is overridable via the PORT env var (defaults to 8000,
    # matching the Dockerfile's EXPOSE).
    listen_port = int(os.environ.get("PORT", "8000"))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)

8
requirements.txt Normal file
View File

@@ -0,0 +1,8 @@
fastapi
uvicorn
torch>=2.0
transformers
soundfile
pydantic
python-multipart
git+https://github.com/nari-labs/dia.git

56
test_api.py Normal file
View File

@@ -0,0 +1,56 @@
import requests
from pathlib import Path
# API endpoint
URL = "http://localhost:8000/generate"
def test_generation(text: str, output_file: str = "output.wav", audio_prompt: str = None):
    """
    Test the Dia API generation endpoint.

    Args:
        text (str): Transcript text with speaker tags.
        output_file (str): Filename to save the output audio.
        audio_prompt (str): Optional path to an audio prompt file.
    """
    # Multipart payload: 'text' as form data, optional prompt as a file part.
    data = {"text": text}
    files = {}
    if audio_prompt and Path(audio_prompt).exists():
        files["audio_prompt"] = open(audio_prompt, "rb")

    print(f"Calling generation endpoint with text: '{text[:100]}...'")

    # BUGFIX: initialize before the try block — if requests.post() raises
    # (e.g. connection refused), the original except handler referenced an
    # unbound 'response' and died with a NameError instead of the real error.
    response = None
    try:
        response = requests.post(URL, data=data, files=files, timeout=60)
        response.raise_for_status()

        # Save the audio output.
        with open(output_file, "wb") as f:
            f.write(response.content)
        print(f"Successfully generated audio. Saved as '{output_file}'")
    except requests.exceptions.RequestException as e:
        print(f"Error calling API: {e}")
        # Explicit None check: 'if response' uses Response.__bool__, which is
        # False for 4xx/5xx statuses — that silently hid server error details.
        if response is not None and response.content:
            print(f"Server error details: {response.text}")
    finally:
        # Close the prompt file handle we opened above.
        if files:
            files["audio_prompt"].close()
if __name__ == "__main__":
    # Sample transcript exercising speaker tags and a non-verbal cue.
    transcript_parts = [
        "[S1] Dia is an open weights text to dialogue model. (laughs) ",
        "[S2] It allows full control over scripts and voices. ",
        "[S1] Wow. Really impressive.",
    ]
    sample_text = "".join(transcript_parts)

    # To condition generation on a reference voice, pass an audio prompt:
    # test_generation(sample_text, audio_prompt="path/to/my_voice.wav")
    print("Testing basic generation...")
    test_generation(sample_text)