commit a86357c1909de0c532b6e3e9fb05b2b3d03445cd Author: matst80 Date: Wed Mar 4 22:21:47 2026 +0100 initial diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..7ba915e --- /dev/null +++ b/.gitignore @@ -0,0 +1,26 @@ +# Python environment +__pycache__/ +*.py[cod] +*$py.class +venv/ +.venv/ +env/ +.env + +# Data and artifacts +*.wav +*.mp3 +*.log +/tmp/ + +# Model and other Large Files +*.bin +*.pt +*.pth +*.h5 +*.onnx + +# IDEs +.vscode/ +.idea/ +.DS_Store diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..3546988 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,31 @@ +# Use NVIDIA CUDA 12.6 image for compatibility +FROM nvidia/cuda:12.6.2-devel-ubuntu22.04 + +# Set working directory +WORKDIR /app + +# Install system dependencies +RUN apt-get update && apt-get install -y \ + python3 \ + python3-pip \ + git \ + ffmpeg \ + libsndfile1 \ + && rm -rf /var/lib/apt/lists/* + +# Install Python dependencies +COPY requirements.txt . +RUN pip3 install --no-cache-dir -r requirements.txt + +# Copy application files +COPY . . + +# Set environment variables +ENV MODEL_ID="nari-labs/Dia-1.6B" +ENV PORT=8000 + +# Expose API port +EXPOSE 8000 + +# Run the FastAPI server +CMD ["python3", "main.py"] diff --git a/README.md b/README.md new file mode 100644 index 0000000..d63d79b --- /dev/null +++ b/README.md @@ -0,0 +1,87 @@ +# Dia-1.6B API Server + +API server for [nari-labs/Dia-1.6B](https://huggingface.co/nari-labs/Dia-1.6B), a 1.6 billion-parameter text-to-speech (TTS) model designed for realistic dialogue generation. + +## Features +- 🗣️ **Realistic Dialogue**: Directly generates natural-sounding conversations from transcripts. +- 🎭 **Emotion and Tone**: Supports non-verbal cues like `(laughs)`, `(coughs)`, and `(clears throat)`. +- 👥 **Multi-Speaker Support**: Uses tags like `[S1]` and `[S2]` to alternate between speakers. +- 🎙️ **Audio Prompting**: Supports voice conditioning and cloning via audio prompts. 
+- 🚀 **FastAPI Implementation**: High-performance, documented API endpoints. + +## Prerequisites +- **Python 3.9+** +- **NVIDIA GPU (Recommended)**: 10GB+ VRAM for optimal performance. +- **CUDA 12.6+** (Mandatory for inference). + +## Installation + +1. **Clone the repository and navigate into the folder:** + ```bash + git clone + cd dia-api-server + ``` + +2. **Create a virtual environment:** + ```bash + python -m venv .venv + source .venv/bin/activate # On Windows: .venv\Scripts\activate + ``` + +3. **Install dependencies:** + ```bash + pip install -r requirements.txt + ``` + +## Usage + +### Running the Server +```bash +python main.py +``` +The server will be available at `http://localhost:8000`. + +### API Documentation +Once the server is running, you can access the interactive documentation at: +- Swagger UI: `http://localhost:8000/docs` +- Redoc: `http://localhost:8000/redoc` + +### Example Endpoint: `/generate` (POST) +**Parameters:** +- `text` (Form data): The transcript including speaker tags. +- `audio_prompt` (Form file, optional): An audio file to condition the generation. + +**Response:** +Returns a `StreamingResponse` as a `audio/wav` binary stream. + +### Test Script +You can use `test_api.py` to verify the server: +```bash +python test_api.py +``` + +## Docker Deployment (Recommended) +Developing and running locally may be complicated due to CUDA requirements. Here is a sample `Dockerfile` for deployment: + +```dockerfile +FROM nvidia/cuda:12.6.0-devel-ubuntu22.04 + +WORKDIR /app + +RUN apt-get update && apt-get install -y \ + python3 \ + python3-pip \ + git \ + ffmpeg \ + && rm -rf /var/lib/apt/lists/* + +COPY requirements.txt . +RUN pip3 install -r requirements.txt + +COPY . . + +CMD ["python3", "main.py"] +``` + +## License +Refer to the [nari-labs/Dia-1.6B](https://huggingface.co/nari-labs/Dia-1.6B#🪪-license) license on Hugging Face. 
import io
import os
import tempfile
import torch
import soundfile as sf
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.responses import StreamingResponse
from typing import Optional
import uvicorn

# Dia is an optional import so the server can still start (and report a clear
# 503 from /generate) in environments where the package is not installed.
try:
    from dia.model import Dia
except ImportError:
    Dia = None

app = FastAPI(
    title="Dia-1.6B API Server",
    description="API server for Nari Labs Dia-1.6B TTS model. Supports realistic dialogue generation, speaker tags, and audio prompting.",
    version="1.0.0"
)

# Global model instance, populated by the startup hook below.
model = None

# Dia-1.6B emits 44.1 kHz audio; used when encoding the WAV response.
DIA_SAMPLE_RATE = 44100


@app.on_event("startup")
async def load_model():
    """Load the Dia model once at startup.

    Failures are logged rather than raised so the server still comes up and
    /health can report ``model_loaded: false``.
    """
    global model
    if Dia is None:
        print("Warning: 'dia' library not found. Model will not be loaded.")
        return

    model_id = os.getenv("MODEL_ID", "nari-labs/Dia-1.6B")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading model {model_id} on {device}...")

    try:
        model = Dia.from_pretrained(model_id)
        if device == "cuda":
            model = model.to(device)
        print("Model loaded successfully.")
    except Exception as e:
        print(f"Error loading model: {e}")


@app.post("/generate", summary="Generate audio from text")
async def generate(
    text: str = Form(..., description="The transcript text to generate audio for. Use speaker tags like [S1], [S2]."),
    audio_prompt: Optional[UploadFile] = File(None, description="Optional audio file for voice cloning/conditioning.")
):
    """
    Generate realistic dialogue audio from a transcript.
    Supports speaker tags [S1], [S2], and non-verbal cues like (laughs).

    Returns a ``StreamingResponse`` with an ``audio/wav`` body, or raises
    HTTP 503 when the model is unavailable and HTTP 500 on generation errors.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded or 'dia' library missing.")

    prompt_path = None
    try:
        if audio_prompt:
            # Use NamedTemporaryFile instead of f"/tmp/prompt_{filename}":
            # the client-controlled filename allowed path traversal and
            # collided between concurrent requests. Keep the extension so
            # downstream decoders can sniff the format.
            suffix = os.path.splitext(audio_prompt.filename or "")[1]
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
                tmp.write(await audio_prompt.read())
                prompt_path = tmp.name

        # model.generate returns a numpy array of audio samples.
        output = model.generate(text, audio_prompt=prompt_path)

        # Encode the samples as an in-memory WAV file.
        output_buffer = io.BytesIO()
        sf.write(output_buffer, output, DIA_SAMPLE_RATE, format='WAV')
        output_buffer.seek(0)

        return StreamingResponse(
            output_buffer,
            media_type="audio/wav",
            headers={"Content-Disposition": "attachment; filename=output.wav"}
        )
    except Exception as e:
        print(f"Generation error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Remove the temp prompt file even when generation fails; the
        # original only cleaned up on the success path and leaked the file
        # on any exception.
        if prompt_path and os.path.exists(prompt_path):
            os.remove(prompt_path)


@app.get("/health", summary="Health check")
def health():
    """Check if the server is alive and the model is loaded."""
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "device": "cuda" if torch.cuda.is_available() else "cpu"
    }


if __name__ == "__main__":
    # PORT is also set in the Dockerfile; default matches EXPOSE 8000.
    port = int(os.getenv("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=port)
def test_generation(text: str, output_file: str = "output.wav", audio_prompt: str = None):
    """
    Test the Dia API generation endpoint.

    Args:
        text (str): Transcript text with speaker tags.
        output_file (str): Filename to save the output audio.
        audio_prompt (str): Optional path to an audio prompt file.
    """
    # Multipart form payload: 'data' carries form fields, 'files' the upload.
    data = {"text": text}
    files = {}

    if audio_prompt and Path(audio_prompt).exists():
        files["audio_prompt"] = open(audio_prompt, "rb")

    print(f"Calling generation endpoint with text: '{text[:100]}...'")

    # Initialize before the try block: if requests.post itself raises
    # (connection refused, timeout), the except branch below used to hit a
    # NameError on 'response', masking the real error.
    response = None
    try:
        # Use POST request with timeout
        response = requests.post(URL, data=data, files=files, timeout=60)
        response.raise_for_status()

        # Save the audio output
        with open(output_file, "wb") as f:
            f.write(response.content)

        print(f"Successfully generated audio. Saved as '{output_file}'")

    except requests.exceptions.RequestException as e:
        print(f"Error calling API: {e}")
        if response is not None and response.content:
            print(f"Server error details: {response.text}")
    finally:
        # Close any opened prompt file handles.
        for handle in files.values():
            handle.close()


if __name__ == "__main__":
    # Sample transcript with realistic dialogue features
    SAMPLE_TEXT = (
        "[S1] Dia is an open weights text to dialogue model. (laughs) "
        "[S2] It allows full control over scripts and voices. "
        "[S1] Wow. Really impressive."
    )

    # Check if we should use a specific prompt
    # test_generation(SAMPLE_TEXT, audio_prompt="path/to/my_voice.wav")

    print("Testing basic generation...")
    test_generation(SAMPLE_TEXT)