57 lines
1.8 KiB
Python
57 lines
1.8 KiB
Python
import requests
|
|
from pathlib import Path
|
|
|
|
# Address of the locally hosted Dia generation endpoint that requests are posted to.
URL: str = "http://localhost:8000/generate"
|
|
|
|
def test_generation(
    text: str,
    output_file: str = "output.wav",
    audio_prompt: str = None,
    timeout: float = 60,
):
    """
    Test the Dia API generation endpoint.

    Posts ``text`` as form data (plus an optional audio prompt file upload)
    to ``URL`` and writes the raw response body to ``output_file``.
    Request failures (``requests.exceptions.RequestException``, including
    non-2xx status codes) are reported on stdout instead of being raised;
    errors while writing ``output_file`` still propagate.

    Args:
        text (str): Transcript text with speaker tags.
        output_file (str): Filename to save the output audio.
        audio_prompt (str): Optional path to an audio prompt file. If a path
            is given but does not point to an existing file, a warning is
            printed and the request is sent without a prompt.
        timeout (float): Seconds to wait for the server before giving up.
    """
    # Local import keeps this fix self-contained (no change to the file's import block).
    from contextlib import ExitStack

    # Form fields travel in `data`; the optional prompt is sent as a file upload.
    data = {"text": text}
    files = {}

    with ExitStack() as stack:
        if audio_prompt:
            prompt_path = Path(audio_prompt)
            # is_file() rather than exists(): a directory would pass exists()
            # and then fail inside open().
            if prompt_path.is_file():
                # Registered on the stack so the handle is closed on every exit path.
                files["audio_prompt"] = stack.enter_context(open(prompt_path, "rb"))
            else:
                print(
                    f"Warning: audio prompt '{audio_prompt}' not found; "
                    "sending request without it."
                )

        # Only mark the preview as truncated when it actually was.
        preview = text if len(text) <= 100 else f"{text[:100]}..."
        print(f"Calling generation endpoint with text: '{preview}'")

        try:
            response = requests.post(URL, data=data, files=files, timeout=timeout)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            print(f"Error calling API: {e}")
            # Use the response attached to the exception: the local `response`
            # is unbound when the request itself failed (connection error,
            # timeout), and a Response object is falsy for 4xx/5xx statuses.
            if e.response is not None and e.response.content:
                print(f"Server error details: {e.response.text}")
            return

    # Save the audio output
    with open(output_file, "wb") as f:
        f.write(response.content)

    print(f"Successfully generated audio. Saved as '{output_file}'")
|
|
|
|
if __name__ == "__main__":
    # Sample transcript using speaker tags ([S1]/[S2]) and a parenthesised cue like (laughs).
    SAMPLE_TEXT = "".join(
        [
            "[S1] Dia is an open weights text to dialogue model. (laughs) ",
            "[S2] It allows full control over scripts and voices. ",
            "[S1] Wow. Really impressive.",
        ]
    )

    # To send an audio prompt file along with the text, pass its path, e.g.:
    # test_generation(SAMPLE_TEXT, audio_prompt="path/to/my_voice.wav")

    print("Testing basic generation...")
    test_generation(SAMPLE_TEXT)
|