Skip to content

Commit 5d18e2a

Browse files
committed
fa
1 parent db420cb commit 5d18e2a

File tree

9 files changed

+254
-11
lines changed

9 files changed

+254
-11
lines changed

batchalign/cli/cli.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def batchalign(ctx, verbose):
111111
@click.option("--wav2vec/--whisper_fa",
112112
default=True, help="Use Whisper instead of Wav2Vec for English (defaults for Whisper for non-English)")
113113
@click.option("--iic", is_flag=True, default=False, help="Use IIC forced alignment (for Chinese).")
114+
@click.option("--wav2vec_yue", is_flag=True, default=False, help="Use Wav2Vec with chantonese fixes forced alignment (for Chinese).")
114115
@click.option("--tencent/--rev",
115116
default=False, help="Use Tencent instead of Rev.AI (default).")
116117
@click.option("--funaudio/--rev",
@@ -119,7 +120,7 @@ def batchalign(ctx, verbose):
119120
@click.option("--wor/--nowor",
120121
default=True, help="Should we write word level alignment line? Default to yes.")
121122
@click.pass_context
122-
def align(ctx, in_dir, out_dir, whisper, wav2vec, iic, tencent, funaudio, **kwargs):
123+
def align(ctx, in_dir, out_dir, whisper, wav2vec, iic, wav2vec_yue, tencent, funaudio, **kwargs):
123124
"""Align transcripts against corresponding media files."""
124125
def loader(file):
125126
return (
@@ -133,6 +134,8 @@ def writer(doc, output):
133134
# Determine FA engine
134135
if iic:
135136
fa_engine = "iic_fa"
137+
elif wav2vec_yue:
138+
fa_engine = "wav2vec_fa_canto"
136139
elif not wav2vec:
137140
fa_engine = "whisper_fa"
138141
else:

batchalign/pipelines/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .cleanup import NgramRetraceEngine, DisfluencyReplacementEngine
88
from .speaker import NemoSpeakerEngine
99

10-
from .fa import WhisperFAEngine, Wave2VecFAEngine, IICFAEngine
10+
from .fa import WhisperFAEngine, Wave2VecFAEngine, IICFAEngine, Wave2VecFAEngineCantonese
1111
from .utr import WhisperUTREngine, RevUTREngine, TencentUTREngine, FunAudioUTREngine
1212

1313
from .analysis import EvaluationEngine

batchalign/pipelines/dispatch.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
StanzaUtteranceEngine, CorefEngine, Wave2VecFAEngine, TencentEngine,
1010
OAIWhisperEngine, TencentUTREngine, AliyunEngine, FunAudioEngine,
1111
FunAudioUTREngine, SeamlessTranslationModel, GoogleTranslateEngine,
12-
OAIWhisperEngine, PyannoteEngine, IICFAEngine)
12+
OAIWhisperEngine, PyannoteEngine, IICFAEngine, Wave2VecFAEngineCantonese)
1313

1414
from batchalign import BatchalignPipeline
1515
from batchalign.models import resolve
@@ -135,6 +135,8 @@ def dispatch_pipeline(pkg_str, lang, num_speakers=None, **arg_overrides):
135135
engines.append(CorefEngine())
136136
elif engine == "wav2vec_fa":
137137
engines.append(Wave2VecFAEngine())
138+
elif engine == "wav2vec_fa_canto":
139+
engines.append(Wave2VecFAEngineCantonese())
138140
elif engine == "iic_fa":
139141
engines.append(IICFAEngine())
140142
elif engine == "seamless_translate":
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .whisper_fa import WhisperFAEngine
22
from .wave2vec_fa import Wave2VecFAEngine
3+
from .wave2vec_fa_canto import Wave2VecFAEngineCantonese
34
from .iic_fa import IICFAEngine
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
from batchalign.models import Wave2VecFAModel
from batchalign.document import *
from batchalign.pipelines.base import *
from batchalign.utils import *
from batchalign.utils.dp import *
from batchalign.constants import *

import pycantonese as pc

import logging
L = logging.getLogger("batchalign")

import re

import pycountry
import warnings


class Wave2VecFAEngineCantonese(BatchalignEngine):
    """Forced alignment for Cantonese ("yue") speech via Wav2Vec/MMS.

    Cantonese characters are romanized to toneless Jyutping before being
    fed to the acoustic model (whose vocabulary is Latin characters), and
    the model's word timings are then mapped back onto the original
    characters with a dynamic-programming alignment.
    """

    tasks = [ Task.FORCED_ALIGNMENT ]

    @staticmethod
    def cantonese_to_mms_chars(text: str) -> str:
        """Render Cantonese `text` as an MMS-compatible Jyutping string.

        Tone digits are stripped and syllable boundaries are joined with
        apostrophes (e.g. "你好" -> "nei'hou"). If the text cannot be
        romanized, it is returned unchanged.
        """
        pairs = pc.characters_to_jyutping(text)

        try:
            jyut = " ".join(j for _, j in pairs)
        except TypeError:
            # characters_to_jyutping yields None for unromanizable spans;
            # joining then raises TypeError — fall back to the raw text
            return text
        jyut = re.sub(r"[1-6]", "", jyut)            # drop tone numbers
        jyut = re.sub(r"\s+", "'", jyut).strip("'")  # apostrophe-separate syllables
        return jyut

    def _hook_status(self, status_hook):
        # Register a progress callback: status_hook(done, total).
        self.status_hook = status_hook

    def __init__(self):
        self.status_hook = None
        self.__wav2vec = Wave2VecFAModel()

    def process(self, doc: Document, **kwargs):
        """Forced-align `doc` against its media; returns `doc` with word times.

        Requires a media URL and utterance-level alignments; utterances
        without alignment are skipped with a warning.
        """
        # check that the document has a media path to align to
        assert doc.media != None and doc.media.url != None, f"We cannot forced-align something that doesn't have a media path! Provided media tier='{doc.media}'"
        assert "yue" in doc.langs, "Please use normal wav2vec to align non-cantonese speech."

        # load the audio file
        L.debug(f"Wave2Vec FA is loading url {doc.media.url}...")
        f = self.__wav2vec.load(doc.media.url)
        L.debug(f"Wav2Vec FA finished loading media.")

        # collect utterances into ~15-second segments to be aligned;
        # the model does poorly with very short segments
        groups = []
        group = []
        seg_start = 0

        for i in doc.content:
            if not isinstance(i, Utterance):
                continue
            if i.alignment is None:
                warnings.warn("We found at least one utterance without utterance-level alignment; this is usually not an issue, but if the entire transcript is unaligned, it means that utterance level timing recovery (which is fuzzy using ASR) failed due to the audio clarity. On this transcript, before running forced-alignment, please supply utterance-level links.")
                continue

            # pop the previous group onto the stack once it exceeds 15 s
            if (i.alignment[-1] - seg_start) > 15*1000:
                groups.append(group)
                group = []
                seg_start = i.alignment[0]

            # append the contents to the running group, remembering the
            # utterance-level alignment each word came from
            for word in i.content:
                group.append((word, i.alignment))

        groups.append(group)

        L.debug(f"Begin Wav2Vec Inference...")

        for indx, grp in enumerate(groups):
            L.info(f"Wave2Vec FA processing segment {indx+1}/{len(groups)}...")
            if self.status_hook != None:
                self.status_hook(indx+1, len(groups))

            # perform alignment on this segment
            try:
                transcript = [word[0].text for word in grp]
                # drop punctuation tokens; expand underscore compounds
                for p in MOR_PUNCT + ENDING_PUNCT:
                    transcript = [i.replace("_", " ") for i in transcript if i.strip() != p]

                # map each surface form to its Jyutping rendering, and back
                transcript_mms = {r:self.cantonese_to_mms_chars(r) for r in transcript}
                transcript_mms_rev = {v:k for k,v in transcript_mms.items()}
                res = self.__wav2vec(
                    audio=f.chunk(grp[0][1][0], grp[-1][1][1]),
                    text=[transcript_mms[r] for r in transcript]
                )
            except Exception:
                # segment contains nothing alignable; skip it
                # (narrowed from a bare except so Ctrl-C still works)
                continue

            # reference backplates: one target per character, tagged with the
            # index of the word in `grp` whose timing it should receive
            ref_targets = []
            for widx, (word, _) in enumerate(grp):
                for char in word.text:
                    ref_targets.append(ReferenceTarget(char, payload=widx))

            # payload backplates: model-output characters tagged with the
            # index of their timing in `timings`
            payload_targets = []
            timings = []
            try:
                for tidx, (word, time) in enumerate(res):
                    timings.append(time)
                    for char in transcript_mms_rev[word]:
                        payload_targets.append(PayloadTarget(char, payload=tidx))
            except Exception:
                # model emitted a form we can't map back; skip the segment
                continue

            # alignment!
            alignments = align(payload_targets, ref_targets, tqdm=False)

            # apply matches BACKWARDS so that, for each word, the earliest
            # timestamp is written last and therefore wins
            alignments.reverse()
            for elem in alignments:
                if isinstance(elem, Match):
                    # timings are relative to the chunk; offset by segment start
                    grp[elem.reference_payload][0].time = (int(round((timings[elem.payload][0] +
                                                                      grp[0][1][0]))),
                                                           int(round((timings[elem.payload][1] +
                                                                      grp[0][1][0]))))

        L.debug(f"Correcting text...")

        # we now patch up word-level timings utterance by utterance
        for doc_ut, ut in enumerate(doc.content):
            if not isinstance(ut, Utterance):
                continue

            # correct each word by bumping it forward;
            # if it's not a word, remove the timing
            for indx, w in enumerate(ut.content):
                if w.type in [TokenType.PUNCT, TokenType.FEAT, TokenType.ANNOT]:
                    w.time = None
                elif indx == len(ut.content)-1 and w.text in ENDING_PUNCT:
                    w.time = None
                elif indx != len(ut.content)-1:
                    # search forward for the next word with a usable time
                    tmp = indx+1
                    while tmp < len(ut.content)-1 and ut.content[tmp].time == None:
                        tmp += 1
                    if w.time == None:
                        continue
                    if ut.content[tmp].time == None:
                        # no timed word follows in this utterance: seek forward
                        # to the next aligned utterance for a boundary time.
                        # BUGFIX: index doc.content[next_ut] — the original
                        # tested isinstance(doc.content, Utterance), which is
                        # always False for the content list, so this branch
                        # always fell through to the +500 ms guess.
                        next_ut = doc_ut + 1
                        while next_ut < len(doc.content)-1 and (not isinstance(doc.content[next_ut], Utterance) or doc.content[next_ut].alignment == None):
                            next_ut += 1
                        if next_ut < len(doc.content) and isinstance(doc.content[next_ut], Utterance) and doc.content[next_ut].alignment:
                            w.time = (w.time[0], doc.content[next_ut].alignment[0])
                        else:
                            w.time = (w.time[0], w.time[0]+500) # give half a second because we don't know

                    # just in case, bound the time by the utterance-derived timings
                    if ut.alignment and ut.alignment[0] != None:
                        w.time = (max(w.time[0], ut.alignment[0]), min(w.time[1], ut.alignment[1]))
                    # if we ended up with timings that don't make sense, drop it
                    if w.time and w.time[0] >= w.time[1]:
                        w.time = None

            # clear any built-in timing (i.e. we should use utterance-derived timing)
            ut.time = None
            # correct the text: write (or rewrite) the \x15-delimited bullet
            if ut.alignment and ut.text != None:
                if '\x15' not in ut.text:
                    ut.text = (ut.text+f" \x15{ut.alignment[0]}_{ut.alignment[1]}\x15").strip()
                else:
                    ut.text = re.sub(r"\x15\d+_\d+\x15",
                                     f"\x15{ut.alignment[0]}_{ut.alignment[1]}\x15", ut.text).strip()
            elif ut.text != None:
                ut.text = re.sub(r"\x15\d+_\d+\x15", f"", ut.text).strip()

        return doc

batchalign/pipelines/utr/tencent_utr.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ def __init__(self, key:str=None, lang="eng"):
8080
self.__client = AsrClient(cred, "ap-hongkong")
8181

8282

83-
def replace_cantonese_words(self, word):
83+
@staticmethod
84+
def replace_cantonese_words(word):
8485
"""Function to replace Cantonese words with custom replacements."""
8586
word_replacements = {
8687
"系": "係",
@@ -194,7 +195,7 @@ def process(self, doc, **kwargs):
194195
roman_cache_end = i.StartMs
195196
for j in i.Words:
196197
word = j.Word
197-
if self.__lang == "yue":
198+
if lang == "yue":
198199
word = cc.convert(word)
199200

200201
word = self.replace_cantonese_words(word)

batchalign/version

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
0.7.22-post.34
2-
December 30th, 2025
3-
CA
1+
0.7.23
2+
January 4th, 2026
3+
Cantonese FA

scratchpad.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,57 @@
1616

1717
########
1818

19+
# from batchalign import *
20+
# from batchalign.formats.chat.parser import chat_parse_utterance
21+
22+
# # !uv pip install pycantonese
23+
# doc = CHATFile(path="/Users/houjun/Documents/Projects/talkbank-alignment/cantonese/input/030021s.cha").doc
24+
25+
# pipe = BatchalignPipeline(TencentUTREngine(), Wave2VecFAEngineCantonese())
26+
# doc = pipe(doc)
27+
28+
# # [i.alignment for i in doc.content if isinstance(i, Utterance)]
29+
# # [i.time for i in doc if isinstance(i, Ut]
30+
31+
32+
# # res = pipe(doc)
33+
# CHATFile(doc=doc).write("/Users/houjun/Documents/Projects/talkbank-alignment/cantonese/input/030021s.out.cha")
34+
35+
# # f = Wave2VecFAEngineCantonese()._Wave2VecFAEngineCantonese__wav2vec.load(doc.media.url)
36+
# # f
37+
# # f.tensor.size(0)//f.rate
38+
# # 1+1
39+
# # 1+1
40+
41+
42+
43+
44+
45+
46+
# pipe = BatchalignPipeline(Wave2VecFAEngine())
47+
# pipe
48+
# res
49+
# audio =
50+
# model = Wave2VecFAModel()
51+
# audio = model.load(doc.media.url)
52+
# 1+1
53+
# text = str(res[12]).split(".")[0]
54+
# full_audio = audio
55+
# # res[12]
56+
# audio = audio.chunk(12460, 14610)
57+
# audio
58+
# res[6][0]
59+
# 1+1
60+
# f1 = '饭'
61+
# f2 = '飯'
62+
63+
# cc.convert(f1)
64+
65+
# doc.langs
66+
67+
# import ipdb
68+
# ipdb.set_trace()
1969

20-
from batchalign import *
21-
from batchalign.formats.chat.parser import chat_parse_utterance
2270

2371
# # 1+1
2472
# text = "Hello are you the f b i ?"
@@ -35,6 +83,8 @@
3583
# 1+1
3684

3785

86+
# !uv pip install pycantonese
87+
3888
# doc = CHATFile(path="../talkbank-alignment/input/011116.cha").doc
3989
# newdoc = Document(content=[doc[4]], langs=["heb"])
4090
# pipe = StanzaEngine()

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ def read(fname):
7575
"pyannote.audio",
7676
"onnxruntime",
7777
"certifi>=2025.10.5",
78-
"regex"
78+
"regex",
79+
"pycantonese"
7980
],
8081
extras_require={
8182
'dev': [

0 commit comments

Comments
 (0)