add speed option to inference
Add a speed option to the infer function
StyleTTS2/src/styletts2/tts.py
Lines 186 to 195 in 350b888
I have tested dividing the predicted duration by a value here:
StyleTTS2/src/styletts2/tts.py
Line 267 in 350b888
Adding speed=1 to the function signature and dividing the predicted duration by speed gives the following .wav durations for various speeds. Speeds between 0.75 and 1.75 sound good, but outside that range the audio gets rough.
duration = torch.sigmoid(duration).sum(axis=-1) / speed
def inference(self,
              text: str,
              target_voice_path=None,
              output_wav_file=None,
              output_sample_rate=24000,
              alpha=0.3,
              beta=0.7,
              diffusion_steps=5,
              embedding_scale=1,
              speed=1,
              ref_s=None):
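To see why the measured durations need not match d(1)/speed exactly, here is a minimal NumPy sketch of the modified duration line. It assumes the predictor output has shape (..., num_bins) and that inference afterwards rounds each token's duration to whole frames with a minimum of one frame per token; that post-processing step is an assumption, not something confirmed against commit 350b888. predicted_durations, to_frames and the base values are hypothetical illustrations, not code from tts.py.

```python
import numpy as np


def predicted_durations(duration_logits: np.ndarray, speed: float = 1.0) -> np.ndarray:
    """Hypothetical NumPy mirror of: torch.sigmoid(duration).sum(axis=-1) / speed."""
    if speed <= 0:
        raise ValueError("speed must be positive")
    # Summing per-bin sigmoid values gives a continuous duration per token,
    # which is then shortened (speed > 1) or lengthened (speed < 1).
    return (1.0 / (1.0 + np.exp(-duration_logits))).sum(axis=-1) / speed


def to_frames(durations: np.ndarray) -> np.ndarray:
    """Assumed post-processing: round to whole frames, at least 1 frame per token."""
    return np.maximum(np.round(durations), 1).astype(int)


# Made-up continuous durations (in frames) for six tokens at speed=1.
base = np.array([0.6, 1.4, 2.2, 3.2, 5.3, 8.0])
reference = to_frames(base).sum()
for speed in (0.5, 1.0, 2.0):
    total = to_frames(base / speed).sum()
    print(f"speed={speed}: {total} frames (ideal {reference / speed:.1f})")
```

With these made-up values it prints 41, 20 and 12 frames against ideals of 40.0, 20.0 and 10.0: rounding and the one-frame minimum stop short tokens from shrinking in proportion, so the fastest setting comes out longer than 1/speed predicts.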
The orange line is the duration of the original (speed=1) clip divided by the speed parameter.
The blue line is the duration of the clip actually produced with that speed parameter.
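Written out under a simplifying assumption (each of the N tokens has a continuous duration d_i at speed 1, which is divided by the speed s, rounded, and clamped to at least one frame, and every frame maps to a fixed audio length Δ), the two curves are roughly:

$$
T_{\text{expected}}(s) = \frac{T(1)}{s},
\qquad
T_{\text{actual}}(s) \approx \Delta \sum_{i=1}^{N} \max\!\left(1,\ \operatorname{round}\!\left(\frac{d_i}{s}\right)\right) \;\ge\; N\Delta
$$

The lower bound NΔ does not shrink as s grows, so under this assumption the blue line would level off above the orange one at high speeds. This is a hypothesis about the plot, not a measured result.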
I had to convert the files to MP4 so they would play here:
test_0.50.1.mp4
test_0.67.mp4
test_0.83.mp4
test_1.00.mp4
test_1.17.mp4
test_1.33.mp4
test_1.50.mp4
test_1.67.mp4
test_1.83.mp4
test_2.00.mp4
Here is the code I ran to test it after making those changes:
import matplotlib.pyplot as plt
from styletts2 import tts
import numpy as np
import librosa

# No paths provided means default checkpoints/configs will be downloaded/cached.
my_tts = tts.StyleTTS2()

# Write one output WAV file per speed value.
speed_range = np.linspace(0.5, 2, 10)
for speed in speed_range:
    out = my_tts.inference(
        "Hello there, I am now a python package.",
        output_wav_file=f"test_{speed:.2f}.wav",
        speed=speed,
    )

# Measure the duration of each generated file.
durations = {}
for speed in speed_range:
    duration = librosa.get_duration(path=f"test_{speed:.2f}.wav")
    print(f"test_{speed:.2f}.wav: {duration:.2f}s")
    durations[speed] = duration

# Ideal curve: the speed=1 duration divided by each speed.
expected_durations = [durations[1] / speed for speed in speed_range]
plt.plot(speed_range, list(durations.values()), label="Actual")
plt.plot(speed_range, expected_durations, label="Expected")
plt.xlabel("Speed")
plt.ylabel("Duration (s)")
plt.legend()
plt.show()
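The script looks up durations[1], which relies on one of the values from np.linspace(0.5, 2, 10) comparing equal to 1. A slightly more defensive sketch is below; duration_ratios is a hypothetical helper, not part of styletts2. It uses the measured speed closest to 1.0 as the reference and reports actual/expected per speed, so a ratio above 1 means the clip came out longer than the ideal.

```python
def duration_ratios(durations: dict[float, float]) -> dict[float, float]:
    """Return actual/expected duration per speed.

    expected(speed) = d(ref) * ref / speed, where ref is the measured speed
    closest to 1.0, so no exact 1.0 key is required.
    """
    ref_speed = min(durations, key=lambda s: abs(s - 1.0))
    ref_duration = durations[ref_speed]
    ratios = {}
    for speed, actual in durations.items():
        # Ideal scaling: output duration is inversely proportional to speed.
        expected = ref_duration * ref_speed / speed
        ratios[speed] = actual / expected
    return ratios
```

It would be called as duration_ratios(durations) after the measurement loop. With illustrative (not measured) values {0.5: 4.0, 1.0: 2.0, 2.0: 1.2} it returns ratios of 1.0, 1.0 and 1.2.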
This looks interesting! I might use it. Thanks.