MaskSR
Publication Date: 4 June 2024, link
Distorted Speech Encoder:
The distorted input is converted to a power-law compressed magnitude spectrogram using the STFT:
import torch

audio = torch.randn(1, 3 * 44100)  # 3 seconds at a 44.1 kHz sample rate
y = torch.stft(audio, n_fft=2048, hop_length=512,
               window=torch.hann_window(2048), return_complex=True)
print(y.shape)   # [B, F, T] = [1, 1025, 259]
w = torch.abs(y)  # magnitude spectrogram, [B, F, T]
w = w.pow(0.3)    # power-law compressed spectrogram X^0.3, [B, F, T]
print(w.shape)
Cosine Scheduler
During training, a masking ratio in \([0, 1]\) is applied to the \( 9 \times T \) codebook tokens by sampling \( r \sim \mathcal{U}(0, 1) \).
The cosine schedule function is \( f(r) = \cos(\frac{\pi}{2}r) \), which ensures that \( f(0) = 1 \) and \( f(1) = 0 \).
The cross-entropy loss is computed only on the masked positions.
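A minimal sketch of this masking step (the function name and the `[MASK]` index `mask_id` are my own, not from the paper):

```python
import math
import torch

def cosine_mask(tokens, mask_id):
    """Mask a [Q, T] token grid with a cosine-scheduled ratio.

    tokens:  [Q, T] codebook indices (Q = 9 codebooks here).
    mask_id: special index standing in for the [MASK] token.
    """
    r = torch.rand(1).item()            # r ~ U(0, 1)
    ratio = math.cos(math.pi / 2 * r)   # f(r) = cos(pi/2 * r), in [0, 1]
    mask = torch.rand_like(tokens, dtype=torch.float) < ratio  # True = masked
    masked = tokens.masked_fill(mask, mask_id)
    return masked, mask
```

The loss is then restricted to the masked entries, e.g. `F.cross_entropy(logits[mask], tokens[mask])`.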
Classifier-free Guidance
The conditional logit \( l_c \) measures \( P(x \mid c) \), where the prompt \( c \) is the speech encoder output.
The unconditional logit \( l_u \) measures \( P(x) \), i.e. the model predicts the output without any conditioning. During inference, we replace the entire conditioning sequence with a learned null embedding repeated \( T \) times, run the model, and use the resulting logits as \( l_u \). The guided logit used for inference is \( l_g = (1 + w) l_c - w l_u \), where \( w \) is the guidance scale.
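The guidance combination itself is a one-line operation; a sketch (function name is illustrative):

```python
import torch

def cfg_logits(l_c, l_u, w):
    """Classifier-free guidance: l_g = (1 + w) * l_c - w * l_u.

    l_c: conditional logits (speech-encoder features as the prompt)
    l_u: unconditional logits (condition replaced by the null embedding)
    w:   guidance scale; w = 0 recovers the plain conditional logits
    """
    return (1 + w) * l_c - w * l_u
```

Larger \( w \) pushes the prediction further toward the conditional distribution and away from the unconditional one.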
During training, in each epoch we randomly select 10% of the training examples and replace their conditioning sequence with the null embedding, so the model also learns the unconditional distribution.
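This condition dropout can be sketched as follows (function and argument names are my own, assuming a per-example Bernoulli drop rather than a fixed per-epoch subset):

```python
import torch

def drop_condition(cond, null_embed, p=0.1):
    """Randomly replace each example's conditioning sequence with the
    null embedding (condition dropout for classifier-free guidance).

    cond:       [B, T, D] speech-encoder features
    null_embed: [D] learned null embedding, repeated over the T positions
    p:          drop probability (10% in the text above)
    """
    B, T, D = cond.shape
    drop = torch.rand(B) < p          # which examples lose their condition
    null = null_embed.expand(T, D)    # null embedding repeated T times
    cond = cond.clone()
    cond[drop] = null
    return cond
```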