Demystifying AI dubbing - Part 3: Text-To-Speech with F5-TTS

In the first part of the series, we discussed the different parts involved in a modern dubbing pipeline:

  1. Speech To Text - Convert an audio source into text
  2. Translation - Translate text from the source language to the target language
  3. Text To Speech - Produce speech from the text translation in step 2
  4. Lip sync - Synchronize lip movements of characters in a video with an audio source

In the second part, we dove into the speech-to-text component of the dubbing pipeline, using Whisper as a concrete example of how a modern speech-to-text model works.

In this part, we’ll discuss the next step: transforming the translated text into natural-sounding speech by cloning the speaker’s voice and generating new audio from text.

As a reference to learn how a modern TTS system works, we’ll use F5-TTS, a state-of-the-art open-source Text-To-Speech model available at https://github.com/SWivid/F5-TTS.

F5-TTS

F5-TTS architecture

F5-TTS is a recently published Text-To-Speech model that can clone a speaker’s voice and produce natural-sounding speech for a given text prompt.

The model consists of the following components:

  1. A Conditional Flow Matching (CFM) pipeline, used to generate audio (in the form of a mel-spectrogram) by driving the Diffusion Transformer.
  2. A Diffusion Transformer, used inside the CFM to iteratively predict natural-sounding audio (in the form of a mel-spectrogram) given the speaker audio and text.
  3. A Vocoder, used to decode the final mel-spectrogram into an audio waveform.
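
Before diving into the code, here is roughly how those three components fit together. The glue function below is a hypothetical sketch based on the interfaces we’ll see later in this post; the names, the duration heuristic, and the text concatenation are placeholders, not the repository’s actual API.

import torch

def f5_tts_generate(cfm, vocoder, ref_mel: torch.Tensor, ref_text: str, gen_text: str) -> torch.Tensor:
    """Hypothetical glue: the CFM drives the Diffusion Transformer to produce a mel-spectrogram, the vocoder decodes it."""
    duration = ref_mel.shape[1] * 2  # placeholder heuristic: total frames, including the reference part
    mel, _trajectory = cfm.sample(
        cond=ref_mel,                      # reference (speaker) audio as a mel-spectrogram [b, n, d]
        text=[ref_text + " " + gen_text],  # transcript of the reference plus the text to generate
        duration=duration,
    )
    mel = mel[:, ref_mel.shape[1]:, :]     # keep only the newly generated frames
    return vocoder.decode(mel.permute(0, 2, 1))  # [b, n, d] -> [b, d, n] -> waveform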

The F5-TTS flow starts with the main function

def main():
    ...
    reg1 = r"(?=\[\w+\])"
    chunks = re.split(reg1, gen_text)
    ...
    generated_audio_segments = []
    ...
    for text in chunks:
        ...
        gen_text_ = text.strip()
        ...
        audio_segment, final_sample_rate, spectrogram = infer_process(
            ref_audio_,
            ref_text_,
            gen_text_,
            ema_model,
            vocoder,
            ...
        )
        generated_audio_segments.append(audio_segment)
        ...

    final_wave = np.concatenate(generated_audio_segments)
    ...
    with open(wave_path, "wb") as f:
        sf.write(f.name, final_wave, final_sample_rate)

The function receives (via global variables parsed from command-line arguments):

  1. gen_text - The text to generate
  2. ref_audio - A reference audio file of the speaker’s voice to clone
  3. ref_text - A reference text transcription of the speech in ref_audio (what the speaker says in the reference clip)

It then splits the text into chunks (the regex splits at [voice] tags, which the CLI uses for multi-voice generation), runs each chunk through infer_process to generate audio, concatenates the resulting audio segments, and saves them to a new audio file.
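
To make the splitting concrete, here is what that lookahead regex does on a small made-up multi-voice prompt:

import re

gen_text = "[main] Hello there. [town] Good morning, traveler!"
chunks = re.split(r"(?=\[\w+\])", gen_text)
print(chunks)
# ['', '[main] Hello there. ', '[town] Good morning, traveler!']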

So what does infer_process do?

def infer_process(
    ...
):
    ...
    # Split the input text into batches
    gen_text_batches = chunk_text(gen_text, max_chars=max_chars)

    return next(
        infer_batch_process(
            ...
            gen_text_batches,
            ...
        )
    )

It splits the text into batches of at most max_chars characters (roughly at sentence boundaries) and drives infer_batch_process, a generator that produces the audio for those batches.
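
A minimal sketch of what such character-budgeted chunking might look like (illustrative only, not the repository’s chunk_text; the default budget is arbitrary):

import re

def chunk_text_sketch(text: str, max_chars: int = 200) -> list[str]:
    # Greedily pack sentences into chunks of at most ~max_chars characters.
    sentences = re.split(r"(?<=[.!?;])\s+", text.strip())
    chunks, current = [], ""
    for sentence in sentences:
        if current and len(current) + len(sentence) + 1 > max_chars:
            chunks.append(current)
            current = sentence
        else:
            current = f"{current} {sentence}".strip()
    if current:
        chunks.append(current)
    return chunks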

Let’s look into infer_batch_process, which is the main logic in this loop.

def infer_batch_process(
    ref_audio,
    ref_text,
    gen_text_batches,
    model_obj,
    vocoder,
    ...
):
    audio, sr = ref_audio
    if audio.shape[0] > 1:
        audio = torch.mean(audio, dim=0, keepdim=True)

    rms = torch.sqrt(torch.mean(torch.square(audio)))
    if rms < target_rms:
        audio = audio * target_rms / rms

    generated_waves = []
    spectrograms = []

    def process_batch(gen_text):
        ...
        with torch.inference_mode():
            generated, _ = model_obj.sample(
                cond=audio,
                text=final_text_list,
                duration=duration,
                steps=nfe_step,
                cfg_strength=cfg_strength,
                sway_sampling_coef=sway_sampling_coef,
            )

            generated = generated.to(torch.float32)  # generated mel spectrogram
            generated = generated[:, ref_audio_len:, :]
            generated = generated.permute(0, 2, 1)
            ...
            generated_wave = vocoder.decode(generated)

            if rms < target_rms:
                generated_wave = generated_wave * rms / target_rms

            # wav -> numpy
            generated_wave = generated_wave.squeeze().cpu().numpy()

            for j in range(0, len(generated_wave), chunk_size):
                yield generated_wave[j : j + chunk_size], target_sample_rate

    for gen_text in progress.tqdm(gen_text_batches) if progress is not None else gen_text_batches:
        for chunk in process_batch(gen_text):
            yield chunk

Before the loop, the reference audio is downmixed to mono and loudness-normalized to target_rms; the inverse scaling is applied to each generated wave so the output loudness matches the reference. The loop itself is responsible for the following:

  1. Iterating over gen_text_batches, each element of which is roughly a sentence-sized chunk of text.
  2. Running model_obj.sample on each chunk, generating a mel-spectrogram with the Conditional Flow Matching pipeline.
  3. Running the vocoder on the mel-spectrogram to produce a waveform.
  4. Splitting each waveform into chunk_size pieces and yielding them one by one.

infer_batch_process is basically a generator yielding audio chunks lazily.
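
Since it is a generator, a caller can consume the audio as it is produced. A minimal, hypothetical consumer (not from the repository, and assuming the variables prepared in main are in scope) might look like this:

import numpy as np
import soundfile as sf

chunks = []
for chunk, sample_rate in infer_batch_process(ref_audio, ref_text, gen_text_batches, ema_model, vocoder):
    chunks.append(chunk)  # each chunk is a small numpy slice of the waveform
    # ...or hand the chunk to an audio stream here for near-real-time playback

sf.write("output.wav", np.concatenate(chunks), sample_rate)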

It’s now clear that the main parts responsible for the core logic of F5-TTS are the CFM and the vocoder. So let’s dive in.

Conditional Flow Matching

Conditional Flow Matching (CFM) is a method for sampling from complex probability distributions (e.g., generating audio or images). Rather than learning to remove noise step by step, the model learns a velocity (flow) field that transports samples from a simple noise distribution to the data distribution; generation then amounts to integrating that field with an ODE solver.

This is in contrast to Denoising Diffusion Probabilistic Models (DDPMs), which typically require many (often hundreds of) denoising steps to generate a natural-looking sample.

The heavy computational cost of DDPMs is a major obstacle for building practical audio and video generative models, as producing a full clip means running that expensive sampling loop over and over.

In F5-TTS, the authors use CFM to drastically improve generation speed: the trained flow field can be integrated in significantly fewer steps while still producing highly realistic audio.
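
To make the “few steps” claim concrete, here is a minimal Euler-method sketch of how a learned flow field turns noise into a sample. This is illustrative only (velocity_model is a hypothetical stand-in); F5-TTS delegates the integration to torchdiffeq’s odeint, as we’ll see below.

import torch

def euler_sample(velocity_model, x0: torch.Tensor, steps: int = 32) -> torch.Tensor:
    # Integrate dx/dt = v(x, t) from t=0 (pure noise) to t=1 (clean sample).
    t_grid = torch.linspace(0.0, 1.0, steps + 1)
    x = x0
    for t0, t1 in zip(t_grid[:-1], t_grid[1:]):
        v = velocity_model(x, t0)  # predicted flow vector field at (x, t0)
        x = x + (t1 - t0) * v      # one Euler step toward the data distribution
    return x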

The core logic of the CFM is implemented in the sample method. Let’s see exactly how it works.

def sample(
    self,
    cond: float["b n d"] | float["b nw"],  # noqa: F722 # Reference Audio
    text: int["b nt"] | list[str],  # noqa: F722 # Text to generate
    duration: int | int["b"],  # noqa: F821
    *,
    lens: int["b"] | None = None,  # noqa: F821
    ...
):
    ...
    batch, cond_seq_len, device = *cond.shape[:2], cond.device
    ...
    cond_mask = lens_to_mask(lens)
    ...
    step_cond = torch.where(
        cond_mask, cond, torch.zeros_like(cond)
    )  # allow direct control (cut cond audio) with lens passed in

    if batch > 1:
        mask = lens_to_mask(duration)
    else:  # save memory and speed up, as single inference need no mask currently
        mask = None

    def fn(t, x):
        ...
        # predict flow (cond and uncond), for classifier-free guidance
        pred_cfg = self.transformer(
            x=x,
            cond=step_cond,
            text=text,
            time=t,
            mask=mask,
            cfg_infer=True,
            cache=True,
        )
        pred, null_pred = torch.chunk(pred_cfg, 2, dim=0)
        return pred + (pred - null_pred) * cfg_strength

    # noise input
    # to make sure batch inference result is same with different batch size, and for sure single inference
    # still some difference maybe due to convolutional layers
    y0 = []
    for dur in duration:
        ...
        y0.append(torch.randn(dur, self.num_channels, device=self.device, dtype=step_cond.dtype))
    ...
    t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype)
    ...
    t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)

    trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
    ...
    sampled = trajectory[-1]
    out = sampled
    ...
    if exists(vocoder):
        ...
        out = vocoder(out)

    return out, trajectory

The sample function proceeds as follows:

  1. Generate random noise as the initial state for the model.
  2. Generate uniform timesteps between 0 and 1 to represent the progression from noise to clean audio.
  3. Call fn, a function that predicts the flow vector field at each timestep using the Diffusion Transformer, conditioned on:
  • speaker audio (reference)
  • text to generate
  • timestep
  • current noisy state
  4. Use ODE numerical integration to iteratively transform the random noise into a clean mel-spectrogram at timestep 1.0 (the final encoded audio representation).
  5. Optionally run the mel-spectrogram through a vocoder (Vocos: Closing the Gap Between Time-Domain and Fourier-Based Neural Vocoders for High-Quality Audio Synthesis) to decode it into the final audio waveform. A decoding sketch follows this list.
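
For the vocoder step, F5-TTS uses Vocos. If I recall the Vocos API correctly, decoding a mel-spectrogram outside the model looks roughly like this; treat the checkpoint name and the tensor shapes as assumptions rather than facts from the repository:

import torch
from vocos import Vocos

vocoder = Vocos.from_pretrained("charactr/vocos-mel-24khz")  # assumed 24 kHz, 100-bin mel checkpoint
mel = torch.randn(1, 100, 400)   # [batch, n_mels, frames], placeholder values
wave = vocoder.decode(mel)       # -> [batch, samples] waveform at 24 kHz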

The function does a few interesting things worth noting. First, the Diffusion Transformer uses classifier-free guidance to produce its flow predictions (the flow vector field). The model is evaluated on both a conditioned input (speaker audio and text) and an unconditioned one; in the code, the two are stacked into a single doubled batch and then split back apart with torch.chunk. The two predictions are combined using a weighted difference controlled by cfg_strength. This effectively “guides” the model towards producing outputs more aligned with the conditioning, improving fidelity without requiring a separate classifier.
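
The guidance combination itself is a single line; restated as a standalone helper (the default strength here is just a placeholder):

import torch

def cfg_combine(pred_cond: torch.Tensor, pred_uncond: torch.Tensor, cfg_strength: float = 2.0) -> torch.Tensor:
    # cfg_strength = 0 returns the conditional prediction unchanged;
    # larger values push the result further away from the unconditional prediction.
    return pred_cond + (pred_cond - pred_uncond) * cfg_strength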

Second, F5-TTS uses sway sampling to warp the uniform timesteps fed to the ODE solver. Instead of a strictly linear schedule, it adds a cosine-shaped offset; with a negative coefficient, this concentrates solver steps near the start of the trajectory, where the model lays out the overall structure of the speech, and spends fewer steps on late-stage refinement. The authors report that this improves quality for a fixed number of solver steps.
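
Here is a quick way to see the effect of the sway formula from sample() (the coefficient value is illustrative; its sign is what matters):

import torch

steps = 8
t = torch.linspace(0, 1, steps + 1)
sway_coef = -1.0  # a negative coefficient pulls timesteps toward t = 0
t_sway = t + sway_coef * (torch.cos(torch.pi / 2 * t) - 1 + t)

print(t)       # uniform: 0.000, 0.125, 0.250, ...
print(t_sway)  # swayed:  0.000, 0.019, 0.076, ... (denser near 0, sparser near 1)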

Finally, the usage of ODE integration (via torchdiffeq) provides a major advantage over DDPMs. Rather than requiring hundreds of forward passes to iteratively denoise (as in diffusion models), the ODE solver can integrate across timesteps in far fewer steps. This drastically reduces generation time while maintaining high quality. Essentially, it treats the generation process as solving a continuous-time differential equation from random noise to clean audio.

The bulk of the remaining logic hides inside the Diffusion Transformer, so how does this transformer predict the flow vector field? Let’s have a look.

Diffusion Transformer

Diffusion Transformer Architecture

This definitely warrants an entire post of its own, but for our purposes a brief review of the Diffusion Transformer architecture, initially presented in Scalable Diffusion Models with Transformers and used here to generate the final mel-spectrogram, will suffice.

class DiT(nn.Module):
    def __init__(
        ...
    ):
        self.time_embed = TimestepEmbedding(dim)
        ...
        self.text_embed = TextEmbedding(
            text_num_embeds, text_dim, mask_padding=text_mask_padding, conv_layers=conv_layers
        )
        self.input_embed = InputEmbedding(mel_dim, text_dim, dim)

        self.rotary_embed = RotaryEmbedding(dim_head)
        ...
        self.transformer_blocks = nn.ModuleList(
            [
                DiTBlock(
                    ...
                )
                for _ in range(depth)
            ]
        )
        ...
        self.norm_out = AdaLayerNorm_Final(dim)  # final modulation
        self.proj_out = nn.Linear(dim, mel_dim)
    ...
    def forward(
        self,
        x: float["b n d"],  # noised input audio  # noqa: F722
        cond: float["b n d"],  # masked cond audio  # noqa: F722
        text: int["b nt"],  # text  # noqa: F722
        time: float["b"] | float[""],  # time step  # noqa: F821 F722
        mask: bool["b n"] | None = None,  # noqa: F722
        drop_audio_cond: bool = False,  # cfg for cond audio
        drop_text: bool = False,  # cfg for text
        ...
        cache: bool = False,
    ):
        batch, seq_len = x.shape[0], x.shape[1]
        if time.ndim == 0:
            time = time.repeat(batch)

        # 1. Timestamp Embedding
        # t: conditioning time, text: text, x: noised audio + cond audio + text
        t = self.time_embed(time)

        # 2. Classifier Free Guidance
        x_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False, cache=cache)
        x_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True, cache=cache)
        x = torch.cat((x_cond, x_uncond), dim=0)
        t = torch.cat((t, t), dim=0)
        mask = torch.cat((mask, mask), dim=0) if mask is not None else None
        ...
        # 3. Transformer blocks
        for block in self.transformer_blocks:
            x = block(x, t, mask=mask, rope=rope)

        # Layer norm
        x = self.norm_out(x, t)
        # Linear transformation to make the output fit the expected output dimensions
        output = self.proj_out(x)

        return output

The Diffusion Transformer forward function works as follows:

  1. The current timestep is embedded by the timestep embedding layer:
# t: conditioning time, text: text, x: noised audio + cond audio + text
t = self.time_embed(time)
This is basically a wrapper around a sinusoidal position embedding (SinusPositionEmbedding) followed by a small MLP - two linear layers with a SiLU activation in between. A quick shape check follows the class definitions below.

class SinusPositionEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x, scale=1000):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
        emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

class TimestepEmbedding(nn.Module):
    def __init__(self, dim, freq_embed_dim=256):
        super().__init__()
        self.time_embed = SinusPositionEmbedding(freq_embed_dim)
        self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, timestep: float["b"]):
        time_hidden = self.time_embed(timestep)
        time_hidden = time_hidden.to(timestep.dtype)
        time = self.time_mlp(time_hidden)  # b d
        return time
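
As promised, a quick shape check (illustrative; the embedding dimension is chosen arbitrarily, and the classes above along with their torch/math/nn imports are assumed to be in scope):

import torch

embed = TimestepEmbedding(dim=1024)
t = torch.tensor([0.1, 0.9])  # a batch of two timesteps in [0, 1]
print(embed(t).shape)         # torch.Size([2, 1024])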
  2. The inputs for classifier-free guidance are prepared: conditional and unconditional embeddings are computed and concatenated along the batch dimension:
# 2. Classifier Free Guidance
x_cond = self.get_input_embed(x, cond, text, drop_audio_cond=False, drop_text=False)
x_uncond = self.get_input_embed(x, cond, text, drop_audio_cond=True, drop_text=True)
x = torch.cat((x_cond, x_uncond), dim=0)
t = torch.cat((t, t), dim=0)
mask = torch.cat((mask, mask), dim=0) if mask is not None else None

Which works as follows:

def get_input_embed(
    self,
    x,  # b n d
    cond,  # b n d
    text,  # b nt
    drop_audio_cond: bool = False,
    drop_text: bool = False,
    cache: bool = True,
):
    seq_len = x.shape[1]
    text_embed = self.text_embed(text, seq_len, drop_text=drop_text)

    x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)

    return x

The text is embedded using an embedding layer, positional encoding is added, and the result is run through ConvNeXt V2 blocks. It is then passed, together with the noise (x) and speaker audio (cond), into the input_embed function, which concatenates the three, projects them with a linear layer, and applies a convolutional position embedding with a skip connection.

class TextEmbedding(nn.Module):
    def __init__(self, text_num_embeds, text_dim, mask_padding=True, conv_layers=0, conv_mult=2):
        super().__init__()
        self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim)  # use 0 as filler token

        self.mask_padding = mask_padding  # mask filler and batch padding tokens or not

        if conv_layers > 0:
            self.extra_modeling = True
            self.precompute_max_pos = 4096  # ~44s of 24khz audio
            self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
            self.text_blocks = nn.Sequential(
                *[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
            )
        ...

    def forward(self, text: int["b nt"], seq_len, drop_text=False):  # noqa: F722
        text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
        text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
        batch, text_len = text.shape[0], text.shape[1]
        text = F.pad(text, (0, seq_len - text_len), value=0)
        if self.mask_padding:
            text_mask = text == 0

        if drop_text:  # cfg for text
            text = torch.zeros_like(text)

        text = self.text_embed(text)  # b n -> b n d

        # possible extra modeling
        if self.extra_modeling:
            # sinus pos emb
            batch_start = torch.zeros((batch,), dtype=torch.long)
            pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
            text_pos_embed = self.freqs_cis[pos_idx]
            text = text + text_pos_embed

            # convnextv2 blocks
            ...
            text = self.text_blocks(text)

        return text

class InputEmbedding(nn.Module):
    def __init__(self, mel_dim, text_dim, out_dim):
        super().__init__()
        self.proj = nn.Linear(mel_dim * 2 + text_dim, out_dim)
        self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)

    def forward(self, x: float["b n d"], cond: float["b n d"], text_embed: float["b n d"], drop_audio_cond=False):  # noqa: F722
        if drop_audio_cond:  # cfg for cond audio
            cond = torch.zeros_like(cond)

        x = self.proj(torch.cat((x, cond, text_embed), dim=-1))
        x = self.conv_pos_embed(x) + x
        return x
  3. The Transformer blocks are executed sequentially, passing the output of each block as the input to the next:
# 3. Transformer blocks
for block in self.transformer_blocks:
    x = block(x, t, mask=mask, rope=rope)

Finally, the output is normalized (with a final adaptive layer norm modulated by the timestep embedding) and projected back to the mel-spectrogram dimension.

# Layer norm
x = self.norm_out(x, t)
# Linear transformation to make the output fit the expected output dimensions
output = self.proj_out(x)

return output

Wrapping up

In this post, we explored how F5-TTS bridges the gap between text and natural-sounding audio. We dissected its core components - Conditional Flow Matching (CFM) for efficient generation and the Diffusion Transformer for modeling complex audio-text relationships.

By leveraging ODE-based integration and innovations like classifier-free guidance and sway sampling, F5-TTS achieves high-quality results in significantly fewer steps than traditional diffusion models - an essential feature for real-time applications like AI dubbing.

But our journey doesn’t stop at generating audio. In the next post, we’ll dive into LatentSync and see how we synchronize the generated speech with video, creating new facial expressions and lip movements that match the dubbed audio naturally.