initial
This commit is contained in:
26
.gitignore
vendored
Normal file
26
.gitignore
vendored
Normal file
@@ -0,0 +1,26 @@
|
||||
# Python environment
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
venv/
|
||||
.venv/
|
||||
env/
|
||||
.env
|
||||
|
||||
# Data and artifacts
|
||||
*.wav
|
||||
*.mp3
|
||||
*.log
|
||||
/tmp/
|
||||
|
||||
# Model and other Large Files
|
||||
*.bin
|
||||
*.pt
|
||||
*.pth
|
||||
*.h5
|
||||
*.onnx
|
||||
|
||||
# IDEs
|
||||
.vscode/
|
||||
.idea/
|
||||
.DS_Store
|
||||
31
Dockerfile
Normal file
31
Dockerfile
Normal file
@@ -0,0 +1,31 @@
|
||||
# Use NVIDIA CUDA 12.6 image for compatibility
FROM nvidia/cuda:12.6.2-devel-ubuntu22.04

# Set working directory
WORKDIR /app

# Install system dependencies
# ffmpeg and libsndfile1 are required for audio decoding/encoding (soundfile);
# git is needed to pip-install the dia package from GitHub (see requirements.txt)
RUN apt-get update && apt-get install -y \
    python3 \
    python3-pip \
    git \
    ffmpeg \
    libsndfile1 \
    && rm -rf /var/lib/apt/lists/*

# Install Python dependencies first so this layer is cached across code changes
COPY requirements.txt .
RUN pip3 install --no-cache-dir -r requirements.txt

# Copy application files
COPY . .

# Set environment variables (both are read by main.py at startup)
ENV MODEL_ID="nari-labs/Dia-1.6B"
ENV PORT=8000

# Expose API port
EXPOSE 8000

# Run the FastAPI server
CMD ["python3", "main.py"]
|
||||
87
README.md
Normal file
87
README.md
Normal file
@@ -0,0 +1,87 @@
|
||||
# Dia-1.6B API Server
|
||||
|
||||
API server for [nari-labs/Dia-1.6B](https://huggingface.co/nari-labs/Dia-1.6B), a 1.6 billion-parameter text-to-speech (TTS) model designed for realistic dialogue generation.
|
||||
|
||||
## Features
|
||||
- 🗣️ **Realistic Dialogue**: Directly generates natural-sounding conversations from transcripts.
|
||||
- 🎭 **Emotion and Tone**: Supports non-verbal cues like `(laughs)`, `(coughs)`, and `(clears throat)`.
|
||||
- 👥 **Multi-Speaker Support**: Uses tags like `[S1]` and `[S2]` to alternate between speakers.
|
||||
- 🎙️ **Audio Prompting**: Supports voice conditioning and cloning via audio prompts.
|
||||
- 🚀 **FastAPI Implementation**: High-performance, documented API endpoints.
|
||||
|
||||
## Prerequisites
|
||||
- **Python 3.9+**
|
||||
- **NVIDIA GPU (Recommended)**: 10GB+ VRAM for optimal performance.
|
||||
- **CUDA 12.6+** (required for GPU inference; the server falls back to CPU if CUDA is unavailable).
|
||||
|
||||
## Installation
|
||||
|
||||
1. **Clone the repository and navigate into the folder:**
|
||||
```bash
|
||||
git clone <repo-url>
|
||||
cd dia-api-server
|
||||
```
|
||||
|
||||
2. **Create a virtual environment:**
|
||||
```bash
|
||||
python -m venv .venv
|
||||
source .venv/bin/activate # On Windows: .venv\Scripts\activate
|
||||
```
|
||||
|
||||
3. **Install dependencies:**
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### Running the Server
|
||||
```bash
|
||||
python main.py
|
||||
```
|
||||
The server will be available at `http://localhost:8000`.
|
||||
|
||||
### API Documentation
|
||||
Once the server is running, you can access the interactive documentation at:
|
||||
- Swagger UI: `http://localhost:8000/docs`
|
||||
- Redoc: `http://localhost:8000/redoc`
|
||||
|
||||
### Example Endpoint: `/generate` (POST)
|
||||
**Parameters:**
|
||||
- `text` (Form data): The transcript including speaker tags.
|
||||
- `audio_prompt` (Form file, optional): An audio file to condition the generation.
|
||||
|
||||
**Response:**
|
||||
Returns a `StreamingResponse` as an `audio/wav` binary stream.
|
||||
|
||||
### Test Script
|
||||
You can use `test_api.py` to verify the server:
|
||||
```bash
|
||||
python test_api.py
|
||||
```
|
||||
|
||||
## Docker Deployment (Recommended)
|
||||
Developing and running locally may be complicated due to CUDA requirements. Here is a sample `Dockerfile` for deployment:
|
||||
|
||||
```dockerfile
|
||||
FROM nvidia/cuda:12.6.2-devel-ubuntu22.04
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN apt-get update && apt-get install -y \
|
||||
python3 \
|
||||
python3-pip \
|
||||
git \
|
||||
ffmpeg \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip3 install -r requirements.txt
|
||||
|
||||
COPY . .
|
||||
|
||||
CMD ["python3", "main.py"]
|
||||
```
|
||||
|
||||
## License
|
||||
Refer to the [nari-labs/Dia-1.6B](https://huggingface.co/nari-labs/Dia-1.6B#🪪-license) license on Hugging Face.
|
||||
104
main.py
Normal file
104
main.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import io
import os
import tempfile
from typing import Optional

import soundfile as sf
import torch
import uvicorn
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
|
||||
|
||||
# We try to import Dia. In a real environment, this would be installed via requirements.txt
try:
    from dia.model import Dia
except ImportError:
    # Fallback for development if not installed: the startup hook skips
    # loading and /generate answers 503 instead of crashing at import time.
    Dia = None
|
||||
|
||||
# FastAPI application; interactive docs are served at /docs and /redoc.
app = FastAPI(
    title="Dia-1.6B API Server",
    description="API server for Nari Labs Dia-1.6B TTS model. Supports realistic dialogue generation, speaker tags, and audio prompting.",
    version="1.0.0"
)

# Global model instance — None until the startup hook loads it successfully.
model = None
|
||||
|
||||
@app.on_event("startup")
async def load_model():
    """Load the Dia TTS model once when the server starts.

    Leaves the global ``model`` as ``None`` when the ``dia`` library is
    missing or loading fails; /generate then responds with 503.
    """
    global model

    # Nothing to do when the dia package could not be imported.
    if Dia is None:
        print("Warning: 'dia' library not found. Model will not be loaded.")
        return

    model_id = os.getenv("MODEL_ID", "nari-labs/Dia-1.6B")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading model {model_id} on {device}...")

    try:
        model = Dia.from_pretrained(model_id)
        # Only move to GPU when one is actually available.
        if device == "cuda":
            model = model.to(device)
        print("Model loaded successfully.")
    except Exception as e:
        # Startup is best-effort: log and keep serving (health endpoint
        # will report model_loaded=False).
        print(f"Error loading model: {e}")
|
||||
|
||||
@app.post("/generate", summary="Generate audio from text")
async def generate(
    text: str = Form(..., description="The transcript text to generate audio for. Use speaker tags like [S1], [S2]."),
    audio_prompt: Optional[UploadFile] = File(None, description="Optional audio file for voice cloning/conditioning.")
):
    """
    Generate realistic dialogue audio from a transcript.

    Supports speaker tags [S1], [S2], and non-verbal cues like (laughs).

    Returns a streaming ``audio/wav`` response. Raises HTTP 503 when the
    model is not loaded and HTTP 500 when generation fails.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded or 'dia' library missing.")

    prompt_path = None
    try:
        if audio_prompt:
            # Persist the uploaded prompt to a unique temp file.
            # NamedTemporaryFile (instead of f"/tmp/prompt_{filename}")
            # avoids path traversal via a crafted client filename and
            # collisions between concurrent requests.
            suffix = os.path.splitext(audio_prompt.filename or "")[1]
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
                tmp.write(await audio_prompt.read())
                prompt_path = tmp.name

        # Generate audio using the model.
        # According to documentation, generate returns a numpy array.
        # Signature: model.generate(text, audio_prompt=None, ...)
        output = model.generate(text, audio_prompt=prompt_path)

        # Serialize the waveform to WAV in memory.
        # NOTE(review): 44100 Hz assumed to be Dia's output rate — confirm.
        output_buffer = io.BytesIO()
        sf.write(output_buffer, output, 44100, format='WAV')
        output_buffer.seek(0)

        return StreamingResponse(
            output_buffer,
            media_type="audio/wav",
            headers={"Content-Disposition": "attachment; filename=output.wav"}
        )

    except Exception as e:
        print(f"Generation error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Clean up the temp prompt on every path — the original leaked the
        # file whenever generation raised.
        if prompt_path and os.path.exists(prompt_path):
            os.remove(prompt_path)
|
||||
|
||||
@app.get("/health", summary="Health check")
def health():
    """Liveness probe: reports whether the model is loaded and which device is active."""
    has_cuda = torch.cuda.is_available()
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "device": "cuda" if has_cuda else "cpu",
    }
|
||||
|
||||
if __name__ == "__main__":
    # Port is configurable via the PORT env var (defaults to 8000, see Dockerfile).
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", 8000)))
|
||||
8
requirements.txt
Normal file
8
requirements.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
fastapi
|
||||
uvicorn
|
||||
torch>=2.0
|
||||
transformers
|
||||
soundfile
|
||||
pydantic
|
||||
python-multipart
|
||||
git+https://github.com/nari-labs/dia.git
|
||||
56
test_api.py
Normal file
56
test_api.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import requests
|
||||
from pathlib import Path
|
||||
|
||||
# API endpoint — generation route of the locally running server (see main.py).
URL = "http://localhost:8000/generate"
|
||||
|
||||
def test_generation(text: str, output_file: str = "output.wav", audio_prompt: str = None):
    """
    Test the Dia API generation endpoint.

    Args:
        text (str): Transcript text with speaker tags.
        output_file (str): Filename to save the output audio.
        audio_prompt (str): Optional path to an audio prompt file.
    """
    # Using 'Form' and 'File' parameters in requests
    data = {"text": text}
    files = {}

    if audio_prompt and Path(audio_prompt).exists():
        files["audio_prompt"] = open(audio_prompt, "rb")

    print(f"Calling generation endpoint with text: '{text[:100]}...'")

    # Initialize so the except branch can safely inspect it: the original
    # code hit an UnboundLocalError (masking the real error) whenever
    # requests.post() itself raised, e.g. on a connection failure.
    response = None
    try:
        # Use POST request with timeout
        response = requests.post(URL, data=data, files=files, timeout=60)
        response.raise_for_status()

        # Save the audio output; 'with' guarantees the file is closed.
        with open(output_file, "wb") as f:
            f.write(response.content)

        print(f"Successfully generated audio. Saved as '{output_file}'")

    except requests.exceptions.RequestException as e:
        print(f"Error calling API: {e}")
        if response is not None and response.content:
            print(f"Server error details: {response.text}")
    finally:
        # Always release the uploaded prompt file handle.
        if files:
            files["audio_prompt"].close()
|
||||
|
||||
if __name__ == "__main__":
    # Sample transcript exercising the model's dialogue features:
    # speaker tags ([S1]/[S2]) and a non-verbal cue ("(laughs)").
    SAMPLE_TEXT = (
        "[S1] Dia is an open weights text to dialogue model. (laughs) "
        "[S2] It allows full control over scripts and voices. "
        "[S1] Wow. Really impressive."
    )

    # To condition generation on a reference voice, pass an audio prompt:
    # test_generation(SAMPLE_TEXT, audio_prompt="path/to/my_voice.wav")

    print("Testing basic generation...")
    test_generation(SAMPLE_TEXT)
|
||||
Reference in New Issue
Block a user