Skip to content

Commit 769c55b

Browse files
committed
wav2vec support is added
1 parent 184dabc commit 769c55b

File tree

7 files changed

+47
-26
lines changed

7 files changed

+47
-26
lines changed

batchalign/cli/cli.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,12 @@ def batchalign(ctx, verbose):
107107
@common_options
108108
@click.option("--whisper/--rev",
109109
default=False, help="For utterance timing recovery, OpenAI Whisper (ASR) instead of Rev.AI (default).")
110+
@click.option("--wav2vec/--whisper_fa",
111+
default=False, help="Use Whisper instead of Wav2Vec for English (defaults for Whisper for non-English)")
110112
@click.option("--pauses", type=bool, default=False, help="Should we try to bullet each word or should we try to add pauses in between words by grouping them? Default: no pauses.", is_flag=True)
111113

112114
@click.pass_context
113-
def align(ctx, in_dir, out_dir, whisper, **kwargs):
115+
def align(ctx, in_dir, out_dir, whisper, wav2vec, **kwargs):
114116
"""Align transcripts against corresponding media files."""
115117
def loader(file):
116118
return (
@@ -121,12 +123,22 @@ def loader(file):
121123
def writer(doc, output):
122124
CHATFile(doc=doc).write(output)
123125

124-
_dispatch("align", "eng", 1,
125-
["cha"], ctx,
126-
in_dir, out_dir,
127-
loader, writer, C,
128-
utr="whisper_utr" if whisper else "rev_utr",
129-
**kwargs)
126+
if not wav2vec:
127+
_dispatch("align", "eng", 1,
128+
["cha"], ctx,
129+
in_dir, out_dir,
130+
loader, writer, C,
131+
fa="whisper_fa",
132+
utr="whisper_utr" if whisper else "rev_utr",
133+
**kwargs)
134+
else:
135+
_dispatch("align", "eng", 1,
136+
["cha"], ctx,
137+
in_dir, out_dir,
138+
loader, writer, C,
139+
fa="wav2vec_fa",
140+
utr="whisper_utr" if whisper else "rev_utr",
141+
**kwargs)
130142

131143
#################### TRANSCRIBE ################################
132144

batchalign/pipelines/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from .cleanup import NgramRetraceEngine, DisfluencyReplacementEngine
77
from .speaker import NemoSpeakerEngine
88

9-
from .fa import WhisperFAEngine
9+
from .fa import WhisperFAEngine, Wave2VecFAEngine
1010
from .utr import WhisperUTREngine, RevUTREngine
1111

1212
from .analysis import EvaluationEngine

batchalign/pipelines/dispatch.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from batchalign import (WhisperEngine, WhisperFAEngine, StanzaEngine, RevEngine,
77
NgramRetraceEngine, DisfluencyReplacementEngine, WhisperUTREngine,
88
RevUTREngine, EvaluationEngine, WhisperXEngine, NemoSpeakerEngine,
9-
StanzaUtteranceEngine, CorefEngine)
9+
StanzaUtteranceEngine, CorefEngine, Wave2VecFAEngine)
1010
from batchalign import BatchalignPipeline
1111
from batchalign.models import resolve
1212

@@ -127,7 +127,8 @@ def dispatch_pipeline(pkg_str, lang, num_speakers=None, **arg_overrides):
127127
engines.append(StanzaUtteranceEngine())
128128
elif engine == "stanza_coref":
129129
engines.append(CorefEngine())
130-
130+
elif engine == "wav2vec_fa":
131+
engines.append(Wave2VecFAEngine())
131132

132133
L.debug(f"Done initalizing packages.")
133134
return BatchalignPipeline(*engines)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from .whisper_fa import WhisperFAEngine
2+
from .wave2vec_fa import Wave2VecFAEngine

batchalign/pipelines/fa/wave2vec_fa.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,21 @@ def process(self, doc:Document, **kwargs):
2727
# check that the document has a media path to align to
2828
assert doc.media != None and doc.media.url != None, f"We cannot forced-align something that doesn't have a media path! Provided media tier='{doc.media}'"
2929

30+
if doc.langs[0] != "eng":
31+
warnings.warn("Looks like you are not aligning English with wav2vec; this works for a lot of Roman languages, but outside of that your milage may vary.")
32+
3033
# load the audio file
3134
L.debug(f"Wave2Vec FA is loading url {doc.media.url}...")
3235
f = self.__wav2vec.load(doc.media.url)
33-
L.debug(f"Wave2Vec FA finished loading media.")
36+
L.debug(f"Wav2Vec FA finished loading media.")
3437

3538
# collect utterances 30 secondish segments to be aligned for whisper
3639
# we have to do this because whisper does poorly with very short segments
3740
groups = []
3841
group = []
3942
seg_start = 0
4043

41-
L.debug(f"Wave2Vec FA finished loading media.")
44+
L.debug(f"Wav2Vec FA finished loading media.")
4245

4346
for i in doc.content:
4447
if not isinstance(i, Utterance):
@@ -59,7 +62,7 @@ def process(self, doc:Document, **kwargs):
5962

6063
groups.append(group)
6164

62-
L.debug(f"Begin Whisper Inference...")
65+
L.debug(f"Begin Wav2Vec Inference...")
6366

6467
for indx, grp in enumerate(groups):
6568
L.info(f"Wave2Vec FA processing segment {indx+1}/{len(groups)}...")

batchalign/version

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
0.7.11-beta.3
2-
February 2nd, 2025
3-
Incorporate additional pauses
1+
0.7.11-beta.4
2+
February 6th, 2025
3+
Wav2vec support!

scratchpad.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@
4848

4949
# print(str(CHATFile(doc=doc)))
5050

51-
# doc = CHATFile(path="../talkbank-alignment/input/barry.cha").doc
51+
# doc = CHATFile(path="../talkbank-alignment/testing_playground_2/input/test.cha").doc
52+
# pipe = Wave2
53+
5254
# doc[3][0]
5355
# て
5456
# print(str(CHATFile(doc=res)))
@@ -99,6 +101,8 @@
99101
# ppe = pipeline
100102
# cha = CHATFile(path="../talkbank-alignment/testing_playground_2/input/test.cha")
101103
# doc = cha.doc
104+
# engine = Wave2VecFAEngine()
105+
# res = engine(doc)
102106

103107
# # print(str(CHATFile(doc=doc)))
104108
# result = ppe(doc)
@@ -263,15 +267,15 @@
263267
########### The Batchalign String Test Harness ###########
264268
# from batchalign.formats.chat.parser import chat_parse_utterance
265269

266-
file = "/Users/houjun/Documents/Projects/talkbank-alignment/input/spanish.mp3"
267-
function = "asr"
268-
lang = "spa"
269-
num_speakers = 1
270-
271-
ut = Document.new(media_path=file, lang=lang)
272-
pipeline = BatchalignPipeline.new(function, lang=lang, num_speakers=num_speakers, asr="rev")
273-
doc = pipeline(ut)
274-
doc
270+
# file = "/Users/houjun/Documents/Projects/talkbank-alignment/input/spanish.mp3"
271+
# function = "asr"
272+
# lang = "spa"
273+
# num_speakers = 1
274+
275+
# ut = Document.new(media_path=file, lang=lang)
276+
# pipeline = BatchalignPipeline.new(function, lang=lang, num_speakers=num_speakers, asr="rev")
277+
# doc = pipeline(ut)
278+
# doc
275279
# doc.content
276280
# # doc[0][-1]
277281
# # doc[0][-2].model_dump()

0 commit comments

Comments
 (0)