from batchalign.models import Wave2VecFAModel
from batchalign.document import *
from batchalign.pipelines.base import *
from batchalign.utils import *
from batchalign.utils.dp import *
from batchalign.constants import *

import re
import logging
import warnings

import pycantonese as pc

L = logging.getLogger("batchalign")


class Wave2VecFAEngineCantonese(BatchalignEngine):
    tasks = [ Task.FORCED_ALIGNMENT ]

    @staticmethod
    def cantonese_to_mms_chars(text: str) -> str:
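        """Romanize Cantonese text into the form the MMS-style wav2vec aligner expects.

        Converts the text to Jyutping with pycantonese, strips the tone
        digits 1-6, and joins word boundaries with apostrophes (so, roughly,
        "你好" becomes "neihou" or "nei'hou" depending on segmentation).
        If pycantonese cannot romanize the text, it is returned unchanged.
        """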
        pairs = pc.characters_to_jyutping(text)

        try:
            jyut = " ".join(j for _, j in pairs)
        except TypeError:
            # at least one character had no Jyutping (j is None); bail out
            return text
        jyut = re.sub(r"[1-6]", "", jyut)
        jyut = re.sub(r"\s+", "'", jyut).strip("'")
        return jyut

    def _hook_status(self, status_hook):
        self.status_hook = status_hook

    def __init__(self):
        self.status_hook = None
        self.__wav2vec = Wave2VecFAModel()

    def process(self, doc: Document, **kwargs):
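        """Forced-align a Cantonese Document against its media.

        Utterances are batched into roughly 15-second groups, each group is
        romanized and run through the wav2vec aligner, and the resulting
        timings are mapped back onto the original words with a
        character-level dynamic-programming alignment. A final pass cleans
        up the word-level timings (dropping punctuation times, bumping word
        ends forward, and bounding everything by the utterance alignments).
        """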
        # check that the document has a media path to align to
        assert doc.media is not None and doc.media.url is not None, f"We cannot forced-align something that doesn't have a media path! Provided media tier='{doc.media}'"
        assert "yue" in doc.langs, "Please use the normal wav2vec engine to align non-Cantonese speech."

        # load the audio file
        L.debug(f"Wave2Vec FA is loading url {doc.media.url}...")
        f = self.__wav2vec.load(doc.media.url)
        L.debug("Wav2Vec FA finished loading media.")

        # collect utterances into roughly 15-second segments to be aligned;
        # we have to do this because the aligner does poorly with very short segments
        groups = []
        group = []
        seg_start = 0

        for i in doc.content:
            if not isinstance(i, Utterance):
                continue
            if i.alignment is None:
                warnings.warn("We found at least one utterance without utterance-level alignment; this is usually not an issue, but if the entire transcript is unaligned, it means that utterance-level timing recovery (which is fuzzy, using ASR) failed due to the audio clarity. Please supply utterance-level links on this transcript before running forced alignment.")
                continue

            # push the previous group onto the list once it spans more than 15 seconds
            if (i.alignment[-1] - seg_start) > 15*1000:
                groups.append(group)
                group = []
                seg_start = i.alignment[0]

            # append the contents to the running group
            for word in i.content:
                group.append((word, i.alignment))

        if group:
            groups.append(group)

        L.debug("Begin Wav2Vec inference...")

        for indx, grp in enumerate(groups):
            L.info(f"Wave2Vec FA processing segment {indx+1}/{len(groups)}...")
            if self.status_hook is not None:
                self.status_hook(indx+1, len(groups))

            # perform alignment on the audio spanned by this group
            try:
                transcript = [word[0].text for word in grp]
                # drop ANY punctuation and expand underscores into spaces
                transcript = [w.replace("_", " ") for w in transcript
                              if w.strip() not in MOR_PUNCT + ENDING_PUNCT]

                transcript_mms = {r: self.cantonese_to_mms_chars(r) for r in transcript}
                transcript_mms_rev = {v: k for k, v in transcript_mms.items()}
                res = self.__wav2vec(
                    audio=f.chunk(grp[0][1][0], grp[-1][1][1]),
                    text=[transcript_mms[r] for r in transcript]
                )
            except Exception:
                # the group contains nothing alignable (e.g. it is empty); skip it
                continue

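            # The aligner worked on romanized, punctuation-free text, so its
            # timings cannot be zipped directly back onto the original
            # tokens. Instead we align at the character level: each original
            # character carries the index of its source word ("reference
            # targets"), each romanized character carries the index of its
            # timing ("payload targets"), and the dynamic-programming align()
            # matches the two streams so timings can flow back to the words.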
            # create reference backplates, which are the word ids to set the timing for
            ref_targets = []
            for windx, (word, _) in enumerate(grp):
                for char in word.text:
                    ref_targets.append(ReferenceTarget(char, payload=windx))
            # create target backplates for the timings
            payload_targets = []
            timings = []
            try:
                for tindx, (word, time) in enumerate(res):
                    timings.append(time)
                    for char in transcript_mms_rev[word]:
                        payload_targets.append(PayloadTarget(char, payload=tindx))
            except Exception:
                # the aligner returned a word we cannot map back; skip the group
                continue
            # alignment!
            alignments = align(payload_targets, ref_targets, tqdm=False)

            # set the ids back to the text ids
            # we do this BACKWARDS because we want the first timestamp
            # we get about a word to win
            alignments.reverse()
            offset = grp[0][1][0]
            for elem in alignments:
                if isinstance(elem, Match):
                    start, end = timings[elem.payload]
                    grp[elem.reference_payload][0].time = (int(round(start + offset)),
                                                           int(round(end + offset)))

        L.debug("Correcting text...")

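        # Cleanup pass over the raw aligner output: punctuation and
        # annotation tokens should not carry timings; every other word's end
        # is bumped forward to the start of the next timed word (falling
        # back to the next aligned utterance's start, or to a fixed half
        # second) so the timeline has no gaps; finally, each word is clamped
        # to its utterance's alignment and nonsensical spans are dropped.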
        # we now set the end alignment of each word to the start of the next
        for doc_ut, ut in enumerate(doc.content):
            if not isinstance(ut, Utterance):
                continue

            # correct each word by bumping it forward,
            # and if it's not a word we remove the timing
            for indx, w in enumerate(ut.content):
                if w.type in [TokenType.PUNCT, TokenType.FEAT, TokenType.ANNOT]:
                    w.time = None
                elif indx == len(ut.content)-1 and w.text in ENDING_PUNCT:
                    w.time = None
                elif indx != len(ut.content)-1:
                    # search forward for the next compatible time
                    tmp = indx+1
                    while tmp < len(ut.content)-1 and ut.content[tmp].time is None:
                        tmp += 1
                    if w.time is None:
                        continue
                    if ut.content[tmp].time is None:
                        # seek forward one utterance to find its start time
                        next_ut = doc_ut + 1
                        while next_ut < len(doc.content)-1 and (not isinstance(doc.content[next_ut], Utterance) or doc.content[next_ut].alignment is None):
                            next_ut += 1
                        if next_ut < len(doc.content) and isinstance(doc.content[next_ut], Utterance) and doc.content[next_ut].alignment:
                            w.time = (w.time[0], doc.content[next_ut].alignment[0])
                        else:
                            w.time = (w.time[0], w.time[0]+500) # give half a second because we don't know
                    else:
                        # bump this word's end forward to the next timed word's start
                        w.time = (w.time[0], ut.content[tmp].time[0])

                    # just in case, bound the time by the utterance-derived timings
                    if ut.alignment and ut.alignment[0] is not None:
                        w.time = (max(w.time[0], ut.alignment[0]), min(w.time[1], ut.alignment[1]))
                    # if we ended up with timings that don't make sense, drop them
                    if w.time and w.time[0] >= w.time[1]:
                        w.time = None

            # clear any built-in timing (i.e. we should use utterance-derived timing)
            ut.time = None
            # correct the text
            if ut.alignment and ut.text is not None:
                if '\x15' not in ut.text:
                    ut.text = (ut.text+f" \x15{ut.alignment[0]}_{ut.alignment[1]}\x15").strip()
                else:
                    ut.text = re.sub(r"\x15\d+_\d+\x15",
                                     f"\x15{ut.alignment[0]}_{ut.alignment[1]}\x15", ut.text).strip()
            elif ut.text is not None:
                ut.text = re.sub(r"\x15\d+_\d+\x15", "", ut.text).strip()

        return doc
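
# A minimal usage sketch (hypothetical setup; assumes a Document whose
# utterances already carry utterance-level alignments, per the assertions
# in process()):
#
#   engine = Wave2VecFAEngineCantonese()
#   doc = ...  # a Document with media.url set and "yue" in doc.langs
#   doc = engine.process(doc)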