Files
dia-api-server/test_api.py
2026-03-04 22:21:47 +01:00

57 lines
1.8 KiB
Python

import requests
from pathlib import Path
# Generation endpoint that test_generation() POSTs to; points at a server
# on localhost port 8000, so one must be running locally before this script is used.
URL = "http://localhost:8000/generate"
def test_generation(text: str, output_file: str = "output.wav", audio_prompt: str = None):
    """
    Test the Dia API generation endpoint.

    Sends ``text`` (and, optionally, an audio prompt file) to ``URL`` as a
    multipart/form POST and writes the response body to ``output_file``.
    Request failures are reported on stdout rather than raised.

    Args:
        text (str): Transcript text with speaker tags.
        output_file (str): Filename to save the output audio.
        audio_prompt (str): Optional path to an audio prompt file. If the
            path does not exist the request is sent without a prompt.

    Returns:
        None

    Raises:
        OSError: If the prompt file cannot be opened or ``output_file``
            cannot be written (only ``requests`` errors are caught).
    """
    # Form field goes in `data`, the optional upload goes in `files`.
    data = {"text": text}
    files = {}

    print(f"Calling generation endpoint with text: '{text[:100]}...'")

    # Track the handle explicitly so cleanup does not depend on the shape
    # of the `files` dict. Opened immediately before the try so nothing can
    # raise between acquiring the handle and the `finally` that releases it.
    prompt_handle = None
    if audio_prompt:
        if Path(audio_prompt).exists():
            prompt_handle = open(audio_prompt, "rb")
            files["audio_prompt"] = prompt_handle
        else:
            print(f"Audio prompt '{audio_prompt}' not found; sending request without it.")

    try:
        # Use POST request with timeout
        response = requests.post(URL, data=data, files=files, timeout=60)
        response.raise_for_status()
        # Save the audio output
        with open(output_file, "wb") as f:
            f.write(response.content)
        print(f"Successfully generated audio. Saved as '{output_file}'")
    except requests.exceptions.RequestException as e:
        print(f"Error calling API: {e}")
        # Take the response from the exception: the local `response` is
        # unbound when the POST itself failed (connection error, timeout).
        # Compare against None explicitly because a Response is falsy for
        # 4xx/5xx statuses, which are exactly the ones worth reporting.
        error_response = e.response
        if error_response is not None and error_response.content:
            print(f"Server error details: {error_response.text}")
    finally:
        if prompt_handle is not None:
            prompt_handle.close()
if __name__ == "__main__":
    # Demo transcript: alternating [S1]/[S2] speaker turns, including a
    # parenthesised "(laughs)" cue. Each fragment carries its own trailing
    # space, so the pieces are joined with an empty separator.
    SAMPLE_TEXT = "".join(
        [
            "[S1] Dia is an open weights text to dialogue model. (laughs) ",
            "[S2] It allows full control over scripts and voices. ",
            "[S1] Wow. Really impressive.",
        ]
    )

    # To exercise the audio-prompt path, call instead:
    #   test_generation(SAMPLE_TEXT, audio_prompt="path/to/my_voice.wav")
    print("Testing basic generation...")
    test_generation(SAMPLE_TEXT)