@@ -43,7 +43,7 @@ def __init__(
4343
4444 if conv_layers > 0 :
4545 self .extra_modeling = True
46- self .precompute_max_pos = 4096 # ~44s of 24khz audio
46+ self .precompute_max_pos = 8192 # 8192 is ~87.38s of 24khz audio; 4096 is ~43.69s of 24khz audio
4747 self .register_buffer ("freqs_cis" , precompute_freqs_cis (text_dim , self .precompute_max_pos ), persistent = False )
4848 self .text_blocks = nn .Sequential (
4949 * [ConvNeXtV2Block (text_dim , text_dim * conv_mult ) for _ in range (conv_layers )]
@@ -53,32 +53,33 @@ def __init__(
5353
def average_upsample_text_by_mask(self, text, text_mask):
    """Stretch each sample's valid text tokens evenly over the full sequence length.

    For every batch row, the positions selected by ``text_mask`` are gathered and
    each selected token is repeated so that the repeats add up to ``text.shape[1]``:
    every token is repeated ``seq_len // n_valid`` times, and the last
    ``seq_len % n_valid`` tokens receive one extra repeat. Rows whose mask selects
    nothing are left as zeros.

    The target length is the (padded) sequence length of ``text`` itself; the
    original author's note says the text is already padded to the audio sequence
    length before this is called (not verifiable from this method alone).

    Args:
        text: Tensor indexed as ``[batch, seq_len, text_dim]``.
        text_mask: Tensor indexed as ``[batch, seq_len]``; a true / non-zero
            entry marks a valid text position.

    Returns:
        A new tensor with the same shape, dtype and device as ``text``.
    """
    seq_len = text.shape[1]
    upsampled_text = torch.zeros_like(text)

    for i in range(text.shape[0]):
        # Gather the valid tokens of this row; the count is taken from what was
        # actually gathered so it can never disagree with the selected positions.
        valid_data = text[i, text_mask[i].bool()]  # [n_valid, text_dim]
        n_valid = valid_data.shape[0]
        if n_valid == 0:
            continue  # nothing to spread; row stays zero

        base_repeat, remainder = divmod(seq_len, n_valid)
        repeats = torch.full((n_valid,), base_repeat, dtype=torch.long, device=text.device)
        if remainder:
            # The trailing `remainder` tokens absorb the leftover frames.
            repeats[n_valid - remainder:] += 1

        # repeats sums to seq_len by construction; passing output_size lets torch
        # skip computing that sum itself.
        upsampled_text[i] = torch.repeat_interleave(valid_data, repeats, dim=0, output_size=seq_len)

    return upsampled_text
8485
0 commit comments