sidharthrajaram/StyleTTS2

add speed option to inference


Add a `speed` option to the `inference` function. Current signature:

```python
def inference(self,
              text: str,
              target_voice_path=None,
              output_wav_file=None,
              output_sample_rate=24000,
              alpha=0.3,
              beta=0.7,
              diffusion_steps=5,
              embedding_scale=1,
              ref_s=None):
    ...
```

I have tested dividing the predicted duration by a value. The predicted duration is computed by this line:

```python
duration = torch.sigmoid(duration).sum(axis=-1)
```

Adding a `speed=1` parameter to the function and dividing the predicted duration by `speed` changes the length of the generated .wav files as shown in the plot below. Speeds between 0.75 and 1.75 sound good, but outside that range the audio gets rough. The change:

```python
def inference(self,
              text: str,
              target_voice_path=None,
              output_wav_file=None,
              output_sample_rate=24000,
              alpha=0.3,
              beta=0.7,
              diffusion_steps=5,
              embedding_scale=1,
              speed=1,  # new: values > 1 speak faster, < 1 slower
              ref_s=None):
    ...
    # Divide the predicted per-token durations by `speed`
    duration = torch.sigmoid(duration).sum(axis=-1) / speed
    ...
```
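As a self-contained illustration of what the modified line computes, here is a pure-Python sketch that mirrors `torch.sigmoid(duration).sum(axis=-1) / speed` without torch. It assumes each token has a row of duration logits whose sigmoids are summed, as that line implies. `scaled_durations` and the positive-speed check are hypothetical additions for illustration, not part of the package.

```python
import math


def sigmoid(x: float) -> float:
    # Numerically stable logistic function (avoids overflow in math.exp).
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def scaled_durations(logits: list[list[float]], speed: float = 1.0) -> list[float]:
    """Pure-Python mirror of torch.sigmoid(duration).sum(axis=-1) / speed.

    `logits` holds one row of duration logits per token (hypothetical input).
    """
    if speed <= 0:
        # A zero or negative speed would divide by zero or flip the sign.
        raise ValueError(f"speed must be positive, got {speed}")
    return [sum(sigmoid(v) for v in row) / speed for row in logits]


# Two tokens: four zero logits (sigmoid 0.5 each) and two zero logits.
print(scaled_durations([[0.0] * 4, [0.0] * 2], speed=2.0))  # [1.0, 0.5]
```

Rejecting non-positive speeds up front gives a clear error instead of a division by zero deep inside inference.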

[Plot: speed vs. clip duration. Orange line: duration of the speed-1.0 clip divided by the speed parameter. Blue line: duration of the clip actually produced with that speed parameter.]
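One plausible reason the two lines can diverge at extreme speeds: if the code after the modified line rounds each duration to whole frames and clamps it to at least one frame per token (the upstream StyleTTS2 demo notebooks do this with `torch.round(...).clamp(min=1)`; verify against this package), total length no longer scales exactly as 1/speed. The sketch below is a hypothetical pure-Python model of that step using made-up durations, not code from the package.

```python
def frames_per_token(durations: list[float], speed: float) -> list[int]:
    """Hypothetical post-processing: scale by 1/speed, round, floor at 1 frame."""
    return [max(1, round(d / speed)) for d in durations]


# Made-up per-token durations (in frames) at speed 1.0.
durations = [3.4, 0.6, 5.2, 1.1, 2.7]
for speed in (0.5, 1.0, 1.5, 2.0):
    total = sum(frames_per_token(durations, speed))
    ideal = sum(durations) / speed
    print(f"speed {speed}: {total} frames (exact 1/speed scaling: {ideal:.2f})")
```

With these made-up numbers, speed 2.0 gives 8 frames against an ideal 6.5 (the one-frame floor stops short tokens from shrinking further), while speed 0.5 gives 25 against 26 (rounding losses).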

I had to convert the samples to MP4 so they would play here:

test_0.50.1.mp4
test_0.67.mp4
test_0.83.mp4
test_1.00.mp4
test_1.17.mp4
test_1.33.mp4
test_1.50.mp4
test_1.67.mp4
test_1.83.mp4
test_2.00.mp4

And here is the script I ran to test it after making those changes:

```python
import matplotlib.pyplot as plt
import numpy as np
import librosa
from styletts2 import tts

# No paths provided means default checkpoints/configs will be downloaded/cached.
my_tts = tts.StyleTTS2()

# Write one WAV file per speed (requires the patched inference() with `speed`).
speed_range = np.linspace(0.5, 2, 10)
for speed in speed_range:
    my_tts.inference(
        "Hello there, I am now a python package.",
        output_wav_file=f"test_{speed:.2f}.wav",
        speed=speed,
    )

# Measure the duration of each generated clip.
durations = {}
for speed in speed_range:
    duration = librosa.get_duration(path=f"test_{speed:.2f}.wav")
    print(f"test_{speed:.2f}.wav: {duration:.2f}s")
    durations[speed] = duration

# Expected line: the speed-1.0 clip's duration divided by each speed.
# Use the grid point closest to 1.0 instead of relying on exact float equality.
baseline_speed = min(speed_range, key=lambda s: abs(s - 1.0))
expected_durations = [durations[baseline_speed] / speed for speed in speed_range]
plt.plot(speed_range, list(durations.values()), label="Actual")
plt.plot(speed_range, expected_durations, label="Expected")
plt.xlabel("Speed")
plt.ylabel("Duration (s)")
plt.legend()
plt.show()
```
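To put a number on the gap between the two lines, a small helper can report the relative error per speed from the `durations` dict built in the script. `relative_errors` is a hypothetical helper, not part of the script as run. It uses the measured speed closest to 1.0 as the baseline and scales by that speed, so it still works if 1.0 is not exactly on the grid.

```python
def relative_errors(durations: dict[float, float],
                    reference_speed: float = 1.0) -> dict[float, float]:
    """Return (actual - expected) / expected for each speed.

    `expected` is baseline_duration * baseline_speed / speed, where the
    baseline is the measured speed closest to `reference_speed`.
    """
    baseline_speed = min(durations, key=lambda s: abs(s - reference_speed))
    # Duration the clip would have at speed 1.0 if scaling were exact.
    unit_duration = durations[baseline_speed] * baseline_speed
    errors = {}
    for speed, actual in durations.items():
        expected = unit_duration / speed
        errors[speed] = (actual - expected) / expected
    return errors


# Example with made-up measurements (seconds); positive = longer than expected.
example = {0.5: 6.0, 1.0: 3.0, 2.0: 1.8}
for speed, err in relative_errors(example).items():
    print(f"speed {speed:.2f}: {err:+.1%}")
```

Passing the script's `durations` dict instead of `example` gives one percentage per generated clip.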

This looks interesting! I might use it. Thanks.