Skip to content

Commit 3eecd94

Browse files
committed
support back avg upsampling for batch, cover up non-mask case
1 parent d9a6945 commit 3eecd94

File tree

1 file changed

+20
-19
lines changed
  • src/f5_tts/model/backbones

1 file changed

+20
-19
lines changed

src/f5_tts/model/backbones/dit.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def __init__(
4343

4444
if conv_layers > 0:
4545
self.extra_modeling = True
46-
self.precompute_max_pos = 4096 # ~44s of 24khz audio
46+
self.precompute_max_pos = 8192 # 8192 is ~87.38s of 24khz audio; 4096 is ~43.69s of 24khz audio
4747
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
4848
self.text_blocks = nn.Sequential(
4949
*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
@@ -53,32 +53,33 @@ def __init__(
5353

5454
def average_upsample_text_by_mask(self, text, text_mask):
5555
batch, text_len, text_dim = text.shape
56-
assert batch == 1
5756

58-
valid_mask = text_mask[0]
59-
audio_len = text_len
60-
valid_len = valid_mask.sum().item()
61-
62-
if valid_len == 0:
63-
return torch.zeros_like(text)
57+
audio_len = text_len # cuz text already padded to same length as audio sequence
58+
text_lens = text_mask.sum(dim=1) # [batch]
6459

6560
upsampled_text = torch.zeros_like(text)
6661

67-
valid_ind = torch.where(valid_mask)[0]
68-
valid_data = text[0, valid_ind, :] # [valid_len, text_dim]
62+
for i in range(batch):
63+
text_len = text_lens[i].item()
64+
65+
if text_len == 0:
66+
continue
67+
68+
valid_ind = torch.where(text_mask[i])[0]
69+
valid_data = text[i, valid_ind, :] # [text_len, text_dim]
6970

70-
base_repeat = audio_len // valid_len
71-
remainder = audio_len % valid_len
71+
base_repeat = audio_len // text_len
72+
remainder = audio_len % text_len
7273

73-
indices = []
74-
for j in range(valid_len):
75-
repeat_count = base_repeat + (1 if j >= valid_len - remainder else 0)
76-
indices.extend([j] * repeat_count)
74+
indices = []
75+
for j in range(text_len):
76+
repeat_count = base_repeat + (1 if j >= text_len - remainder else 0)
77+
indices.extend([j] * repeat_count)
7778

78-
indices = torch.tensor(indices[:audio_len], device=text.device, dtype=torch.long)
79-
upsampled = valid_data[indices] # [audio_len, text_dim]
79+
indices = torch.tensor(indices[:audio_len], device=text.device, dtype=torch.long)
80+
upsampled = valid_data[indices] # [audio_len, text_dim]
8081

81-
upsampled_text[0, :audio_len, :] = upsampled
82+
upsampled_text[i, :audio_len, :] = upsampled
8283

8384
return upsampled_text
8485

0 commit comments

Comments
 (0)