Skip to content

Commit c172ad3

Browse files
committed
[ci skip] machine translation
1 parent 3b1b4b2 commit c172ad3

File tree

12 files changed

+131
-22
lines changed

12 files changed

+131
-22
lines changed

batchalign/cli/cli.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,28 @@ def writer(doc, output):
196196
loader, writer, C,
197197
asr=asr, **kwargs)
198198

199+
#################### TRANSLATE ################################
200+
201+
@batchalign.command()
@common_options
@click.pass_context
def translate(ctx, in_dir, out_dir, **kwargs):
    """Translate the transcript to English."""

    # Reads one CHAT file from disk and hands back its parsed document.
    def loader(file):
        # special_mor_ is always requested here; no per-file check for an
        # existing %mor tier is performed before loading.
        return CHATFile(path=os.path.abspath(file), special_mor_=True).doc

    # Serializes the processed document back out as a CHAT file.
    def writer(doc, output):
        chat_file = CHATFile(doc=doc)
        chat_file.write(output)

    # Positional arguments are passed in exactly the order _dispatch expects:
    # command name, language, speaker count, accepted extensions, click
    # context, input/output directories, the two I/O callbacks, and console.
    _dispatch(
        "translate", "eng", 1, ["cha"], ctx,
        in_dir, out_dir,
        loader, writer, C,
    )
220+
199221
#################### MORPHOTAG ################################
200222

201223
@batchalign.command()

batchalign/cli/dispatch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
"benchmark": "asr,eval",
4949
"utseg": "utterance",
5050
"coref": "coref",
51+
"translate": "translate",
5152
}
5253

5354
# this is the main runner used by all functions

batchalign/document.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class Task(IntEnum):
3131
MORPHOSYNTAX = 11
3232
COREF = 12
3333
WER = 13
34+
TRANSLATE = 14
3435

3536

3637
DEBUG__G = 0
@@ -54,6 +55,7 @@ class TaskType(IntEnum):
5455
Task.DISFLUENCY_ANALYSIS: TaskType.PROCESSING,
5556
Task.COREF: TaskType.PROCESSING,
5657
Task.WER: TaskType.ANALYSIS,
58+
Task.TRANSLATE: TaskType.PROCESSING,
5759

5860
Task.DEBUG__G: TaskType.GENERATION,
5961
Task.DEBUG__P: TaskType.PROCESSING,
@@ -73,6 +75,7 @@ class TaskType(IntEnum):
7375
Task.DISFLUENCY_ANALYSIS: "Disfluncy Analysis",
7476
Task.COREF: "Coreference Resolution",
7577
Task.WER: "Word Error Rate",
78+
Task.TRANSLATE: "Translation",
7679
Task.DEBUG__G: "TEST_GENERATION",
7780
Task.DEBUG__P: "TEST_PROCESSING",
7881
Task.DEBUG__A: "TEST_ANALYSIS",
@@ -150,6 +153,7 @@ class Utterance(BaseModel):
150153
tier: Tier = Field(default=Tier())
151154
content: Sentence
152155
text: Optional[str] = Field(default=None)
156+
translation: Optional[str] = Field(default=None)
153157
time: Optional[Tuple[int,int]] = Field(default=None)
154158
custom_dependencies: List[CustomLine] = Field(default=[])
155159

batchalign/formats/chat/generator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,8 @@ def generate_chat_utterance(utterance: Utterance, special_mor=False, write_wor=T
9595
result.append("%wor:\t"+" ".join(wor_elems))
9696
if has_coref:
9797
result.append("%coref:\t"+" ".join(coref_elems))
98-
98+
if utterance.translation != None:
99+
result.append("%xtra:\t"+utterance.translation)
99100

100101

101102
#### EXTRA LINE GENERATION ####

batchalign/formats/chat/parser.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,7 @@ def chat_parse_doc(lines, special_mor=False):
280280
mor = None
281281
gra = None
282282
wor = None
283+
translation = None
283284
additional = []
284285

285286
while raw[0][0] == "%":
@@ -291,6 +292,8 @@ def chat_parse_doc(lines, special_mor=False):
291292
gra = line
292293
elif beg.strip() == "wor" or beg.strip() == "xwor":
293294
wor = line
295+
elif beg.strip() == "xtra":
296+
translation = line
294297
else:
295298
additional.append(CustomLine(id=beg.strip(),
296299
type=CustomLineType.DEPENDENT,
@@ -309,7 +312,8 @@ def chat_parse_doc(lines, special_mor=False):
309312
"content": parsed,
310313
"text": text,
311314
"delim": delim,
312-
"custom_dependencies": additional
315+
"custom_dependencies": additional,
316+
"translation": translation
313317
})
314318

315319
timing = re.findall(rf"\x15(\d+)_(\d+)\x15", text)

batchalign/pipelines/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@
1212
from .analysis import EvaluationEngine
1313
from .utterance import StanzaUtteranceEngine
1414

15+
from .translate import SeamlessTranslationModel

batchalign/pipelines/dispatch.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from batchalign import (WhisperEngine, WhisperFAEngine, StanzaEngine, RevEngine,
77
NgramRetraceEngine, DisfluencyReplacementEngine, WhisperUTREngine,
88
RevUTREngine, EvaluationEngine, WhisperXEngine, NemoSpeakerEngine,
9-
StanzaUtteranceEngine, CorefEngine, Wave2VecFAEngine)
9+
StanzaUtteranceEngine, CorefEngine, Wave2VecFAEngine, SeamlessTranslationModel)
1010
from batchalign import BatchalignPipeline
1111
from batchalign.models import resolve
1212

@@ -28,6 +28,7 @@
2828
"eval": "evaluation",
2929
"utterance": "stanza_utt",
3030
"coref": "stanza_coref",
31+
"translate": "seamless_translate",
3132
}
3233

3334
LANGUAGE_OVERRIDE_PACKAGES = {
@@ -129,6 +130,8 @@ def dispatch_pipeline(pkg_str, lang, num_speakers=None, **arg_overrides):
129130
engines.append(CorefEngine())
130131
elif engine == "wav2vec_fa":
131132
engines.append(Wave2VecFAEngine())
133+
elif engine == "seamless_translate":
134+
engines.append(SeamlessTranslationModel())
132135

133136
L.debug(f"Done initalizing packages.")
134137
return BatchalignPipeline(*engines)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .seamless import SeamlessTranslationModel
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from batchalign.models import WhisperFAModel
2+
from batchalign.document import *
3+
from batchalign.pipelines.base import *
4+
from batchalign.utils import *
5+
from batchalign.utils.dp import *
6+
from batchalign.constants import *
7+
8+
from transformers import AutoProcessor, SeamlessM4TModel
9+
10+
import logging
11+
L = logging.getLogger("batchalign")
12+
13+
import re
14+
15+
# !uv pip install sentencepiece
16+
17+
import pycountry
18+
import warnings
19+
20+
class SeamlessTranslationModel(BatchalignEngine):
    """Batchalign engine that adds an English translation to each utterance.

    Every ``Utterance`` in the document that does not already carry a
    translation is run text-to-text through a SeamlessM4T checkpoint and the
    decoded result is stored on ``utterance.translation``.
    """

    tasks = [ Task.TRANSLATE ]

    def _hook_status(self, status_hook):
        """Register a progress callback invoked as ``status_hook(done, total)``."""
        self.status_hook = status_hook

    def __init__(self, model_name="facebook/hf-seamless-m4t-medium"):
        """Load the processor and model weights.

        Parameters
        ----------
        model_name : str
            Hugging Face checkpoint identifier passed to ``from_pretrained``
            for both the processor and the model. Defaults to the medium
            SeamlessM4T checkpoint that was previously hard-coded.
        """
        self.status_hook = None
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.model = SeamlessM4TModel.from_pretrained(model_name)

    def process(self, doc:Document, **kwargs):
        """Translate every untranslated utterance of ``doc`` in place.

        Parameters
        ----------
        doc : Document
            Document whose ``content`` is scanned; its first entry in
            ``langs`` is used as the source language.
        **kwargs
            Accepted for interface compatibility; unused.

        Returns
        -------
        Document
            The same ``doc`` object, with ``translation`` filled in.
        """

        total = len(doc.content)
        # Loop-invariant: the set of punctuation marks to space out.
        punctuation = MOR_PUNCT + ENDING_PUNCT
        # Resolved lazily so a document with nothing to translate never
        # touches doc.langs (matching the original's behavior on such input).
        src_lang = None

        for indx, i in enumerate(doc.content):
            # Only utterances lacking a translation are sent to the model.
            if isinstance(i, Utterance) and not i.translation:
                if src_lang is None:
                    # "zho" is remapped to "cmn" — presumably the code the
                    # SeamlessM4T processor expects for Mandarin; confirm.
                    src_lang = "cmn" if doc.langs[0] == "zho" else doc.langs[0]

                text = i.strip(join_with_spaces=False, include_retrace=True, include_fp=True)
                text_inputs = self.processor(text=text, src_lang=src_lang, return_tensors="pt")
                output_tokens = self.model.generate(**text_inputs, tgt_lang="eng", generate_speech=False)
                translation = self.processor.decode(output_tokens[0].tolist()[0], skip_special_tokens=True)

                # Insert a space before each punctuation mark so that it
                # stands as its own token in the stored translation.
                for j in punctuation:
                    translation = translation.replace(j, " "+j)
                i.translation = translation

            # Report progress for every item, including skipped ones, so the
            # callback always reaches `total` by the end of the document.
            if self.status_hook is not None:
                self.status_hook(indx+1, total)

        return doc
52+
53+

batchalign/version

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
0.7.13-post.1
2-
Feburary 14nd, 2025
3-
Remove hash sign.
1+
0.7.14
2+
February 19th, 2025
3+
machine translation!

0 commit comments

Comments
 (0)