diff --git a/ariautils/config/config.json b/ariautils/config/config.json index dbd6d32..745e783 100644 --- a/ariautils/config/config.json +++ b/ariautils/config/config.json @@ -87,6 +87,7 @@ "abs_time_step_ms": 5000, "max_dur_ms": 5000, "time_step_ms": 10, + "include_pedal": true, "composer_names": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"], "form_names": ["sonata", "prelude", "nocturne", "étude", "waltz", "mazurka", "impromptu", "fugue"], "genre_names": ["jazz", "classical"] diff --git a/ariautils/tokenizer/absolute.py b/ariautils/tokenizer/absolute.py index b91ee6e..21df4fd 100644 --- a/ariautils/tokenizer/absolute.py +++ b/ariautils/tokenizer/absolute.py @@ -144,7 +144,17 @@ def __init__(self, config_path: Path | str | None = None) -> None: ) self.pad_id = self.tok_to_id[self.pad_tok] + # Pedal tokens appended to end of vocab + self.include_pedal = self.config["include_pedal"] + self.ped_on_tok = "" + self.ped_off_tok = "" + if self.config["include_pedal"] is True: + self.add_tokens_to_vocab([self.ped_on_tok, self.ped_off_tok]) + def export_data_aug(self) -> list[Callable[[list[Token]], list[Token]]]: + assert ( + self.include_pedal is False + ), f"Data augmentation doesn't support pedal" return [ self.export_tempo_aug(max_tempo_aug=0.2, mixup=True), self.export_pitch_aug(5), @@ -291,6 +301,10 @@ def _tokenize_midi_dict( if channel_to_instrument.get(c) is None and c != 9: channel_to_instrument[c] = "piano" + if self.include_pedal: + assert len(channel_to_instrument.values()) == 1 + assert set(channel_to_instrument.values()) == {"piano"} + # Calculate prefix prefix: list[Token] = [ ("prefix", "instrument", x) @@ -316,14 +330,37 @@ def _tokenize_midi_dict( else: initial_onset_tick = midi_dict.note_msgs[0]["data"]["start"] + if self.include_pedal is True: + _msgs = midi_dict.note_msgs + midi_dict.pedal_msgs + _msgs.sort( + key=lambda msg: ( + msg["data"]["start"] + if msg["type"] == "note" + else msg["tick"] + ) + ) + else: + _msgs = midi_dict.note_msgs + curr_time_since_onset = 0 - for _, msg in enumerate(midi_dict.note_msgs): + for _, msg in enumerate(_msgs): # Extract msg data - _channel = msg["channel"] - _pitch = msg["data"]["pitch"] - _velocity = msg["data"]["velocity"] - _start_tick = msg["data"]["start"] - _end_tick = msg["data"]["end"] + if msg["type"] == "note": + _type = "note" + _channel = msg["channel"] + _pitch = msg["data"]["pitch"] + _velocity = msg["data"]["velocity"] + _pedal_data = None + _start_tick = msg["data"]["start"] + _end_tick = msg["data"]["end"] + elif msg["type"] == "pedal": + _type = "pedal" + _channel = msg["channel"] + _pitch = None + _velocity = None + _pedal_data = msg["data"] + _start_tick = msg["tick"] + _end_tick = None # Calculate time data prev_time_since_onset = curr_time_since_onset @@ -351,8 +388,23 @@ def _tokenize_midi_dict( tokenized_seq.append(("drum", _pitch)) tokenized_seq.append(("onset", _note_onset)) + elif _type == "pedal": + _pedal_onset = self._quantize_onset( + curr_time_since_onset % self.abs_time_step_ms + ) + if _pedal_data == 1: + tokenized_seq.append(self.ped_on_tok) + tokenized_seq.append(("onset", _pedal_onset)) + elif _pedal_data == 0: + tokenized_seq.append(self.ped_off_tok) + tokenized_seq.append(("onset", _pedal_onset)) + else: + raise ValueError("Invalid pedal message") + else: # Non drum case (i.e. an instrument note) _instrument = channel_to_instrument[_channel] + assert _velocity is not None + assert _end_tick is not None # Update _end_tick if affected by pedal for pedal_interval in channel_to_pedal_intervals[_channel]: @@ -481,6 +533,10 @@ def _detokenize_midi_dict(self, tokenized_seq: list[Token]) -> MidiDict: start = idx break + if self.include_pedal: + assert len(instrument_msgs) == 1 + assert instrument_msgs[0]["data"] == 0 # Piano + # Note messages note_msgs: list[NoteMessage] = [] for tok_1, tok_2, tok_3 in zip( @@ -503,7 +559,6 @@ def _detokenize_midi_dict(self, tokenized_seq: list[Token]) -> MidiDict: if tok_1 == self.time_tok: curr_tick += self.abs_time_step_ms - elif ( _tok_type_1 == "special" or _tok_type_1 == "prefix" @@ -511,6 +566,21 @@ def _detokenize_midi_dict(self, tokenized_seq: list[Token]) -> MidiDict: or _tok_type_1 == "dur" ): continue + elif tok_1 in {self.ped_on_tok, self.ped_off_tok}: + assert isinstance( + tok_2[1], int + ), f"Expected int for onset, got {tok_2[1]}" + + _data = 1 if tok_1 == self.ped_on_tok else 0 + _tick: int = curr_tick + tok_2[1] + pedal_msgs.append( + { + "type": "pedal", + "data": _data, + "tick": _tick, + "channel": 0, + } + ) elif _tok_type_1 == "drum" and _tok_type_2 == "onset": assert isinstance( tok_2[1], int diff --git a/tests/assets/data/transcription.mid b/tests/assets/data/transcription.mid new file mode 100644 index 0000000..75652de Binary files /dev/null and b/tests/assets/data/transcription.mid differ diff --git a/tests/test_tokenizer.py b/tests/test_tokenizer.py index e13cf0b..2a5ea3c 100644 --- a/tests/test_tokenizer.py +++ b/tests/test_tokenizer.py @@ -111,6 +111,16 @@ def _test_tokenize_detokenize(_load_path: Path) -> None: load_path = TEST_DATA_DIRECTORY.joinpath("basic.mid") _test_tokenize_detokenize(_load_path=load_path) + def test_tokenize_detokenize_pedal(self) -> None: + load_path = TEST_DATA_DIRECTORY.joinpath("transcription.mid") + save_path = RESULTS_DATA_DIRECTORY.joinpath("transcription_pedal.mid") + + midi_dict = MidiDict.from_midi(load_path) + tokenizer = AbsTokenizer() + + seq = tokenizer.tokenize(midi_dict=midi_dict) + tokenizer.detokenize(tokenized_seq=seq).to_midi().save(save_path) + def test_pitch_aug(self) -> None: def _test_out_of_bounds( tokenizer: AbsTokenizer,