Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ariautils/config/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
"abs_time_step_ms": 5000,
"max_dur_ms": 5000,
"time_step_ms": 10,
"include_pedal": true,
"composer_names": ["bach", "beethoven", "mozart", "chopin", "rachmaninoff", "liszt", "debussy", "schubert", "brahms", "ravel", "satie", "scarlatti"],
"form_names": ["sonata", "prelude", "nocturne", "étude", "waltz", "mazurka", "impromptu", "fugue"],
"genre_names": ["jazz", "classical"]
Expand Down
84 changes: 77 additions & 7 deletions ariautils/tokenizer/absolute.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,17 @@ def __init__(self, config_path: Path | str | None = None) -> None:
)
self.pad_id = self.tok_to_id[self.pad_tok]

# Pedal tokens appended to end of vocab
self.include_pedal = self.config["include_pedal"]
self.ped_on_tok = "<PED_ON>"
self.ped_off_tok = "<PED_OFF>"
if self.config["include_pedal"] is True:
self.add_tokens_to_vocab([self.ped_on_tok, self.ped_off_tok])

def export_data_aug(self) -> list[Callable[[list[Token]], list[Token]]]:
assert (
self.include_pedal is False
), f"Data augmentation doesn't support pedal"
return [
self.export_tempo_aug(max_tempo_aug=0.2, mixup=True),
self.export_pitch_aug(5),
Expand Down Expand Up @@ -291,6 +301,10 @@ def _tokenize_midi_dict(
if channel_to_instrument.get(c) is None and c != 9:
channel_to_instrument[c] = "piano"

if self.include_pedal:
assert len(channel_to_instrument.values()) == 1
assert set(channel_to_instrument.values()) == {"piano"}

# Calculate prefix
prefix: list[Token] = [
("prefix", "instrument", x)
Expand All @@ -316,14 +330,37 @@ def _tokenize_midi_dict(
else:
initial_onset_tick = midi_dict.note_msgs[0]["data"]["start"]

if self.include_pedal is True:
_msgs = midi_dict.note_msgs + midi_dict.pedal_msgs
_msgs.sort(
key=lambda msg: (
msg["data"]["start"]
if msg["type"] == "note"
else msg["tick"]
)
)
else:
_msgs = midi_dict.note_msgs

curr_time_since_onset = 0
for _, msg in enumerate(midi_dict.note_msgs):
for _, msg in enumerate(_msgs):
# Extract msg data
_channel = msg["channel"]
_pitch = msg["data"]["pitch"]
_velocity = msg["data"]["velocity"]
_start_tick = msg["data"]["start"]
_end_tick = msg["data"]["end"]
if msg["type"] == "note":
_type = "note"
_channel = msg["channel"]
_pitch = msg["data"]["pitch"]
_velocity = msg["data"]["velocity"]
_pedal_data = None
_start_tick = msg["data"]["start"]
_end_tick = msg["data"]["end"]
elif msg["type"] == "pedal":
_type = "pedal"
_channel = msg["channel"]
_pitch = None
_velocity = None
_pedal_data = msg["data"]
_start_tick = msg["tick"]
_end_tick = None

# Calculate time data
prev_time_since_onset = curr_time_since_onset
Expand Down Expand Up @@ -351,8 +388,23 @@ def _tokenize_midi_dict(
tokenized_seq.append(("drum", _pitch))
tokenized_seq.append(("onset", _note_onset))

elif _type == "pedal":
_pedal_onset = self._quantize_onset(
curr_time_since_onset % self.abs_time_step_ms
)
if _pedal_data == 1:
tokenized_seq.append(self.ped_on_tok)
tokenized_seq.append(("onset", _pedal_onset))
elif _pedal_data == 0:
tokenized_seq.append(self.ped_off_tok)
tokenized_seq.append(("onset", _pedal_onset))
else:
raise ValueError("Invalid pedal message")

else: # Non drum case (i.e. an instrument note)
_instrument = channel_to_instrument[_channel]
assert _velocity is not None
assert _end_tick is not None

# Update _end_tick if affected by pedal
for pedal_interval in channel_to_pedal_intervals[_channel]:
Expand Down Expand Up @@ -481,6 +533,10 @@ def _detokenize_midi_dict(self, tokenized_seq: list[Token]) -> MidiDict:
start = idx
break

if self.include_pedal:
assert len(instrument_msgs) == 1
assert instrument_msgs[0]["data"] == 0 # Piano

# Note messages
note_msgs: list[NoteMessage] = []
for tok_1, tok_2, tok_3 in zip(
Expand All @@ -503,14 +559,28 @@ def _detokenize_midi_dict(self, tokenized_seq: list[Token]) -> MidiDict:

if tok_1 == self.time_tok:
curr_tick += self.abs_time_step_ms

elif (
_tok_type_1 == "special"
or _tok_type_1 == "prefix"
or _tok_type_1 == "onset"
or _tok_type_1 == "dur"
):
continue
elif tok_1 in {self.ped_on_tok, self.ped_off_tok}:
assert isinstance(
tok_2[1], int
), f"Expected int for onset, got {tok_2[1]}"

_data = 1 if tok_1 == self.ped_on_tok else 0
_tick: int = curr_tick + tok_2[1]
pedal_msgs.append(
{
"type": "pedal",
"data": _data,
"tick": _tick,
"channel": 0,
}
)
elif _tok_type_1 == "drum" and _tok_type_2 == "onset":
assert isinstance(
tok_2[1], int
Expand Down
Binary file added tests/assets/data/transcription.mid
Binary file not shown.
10 changes: 10 additions & 0 deletions tests/test_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,16 @@ def _test_tokenize_detokenize(_load_path: Path) -> None:
load_path = TEST_DATA_DIRECTORY.joinpath("basic.mid")
_test_tokenize_detokenize(_load_path=load_path)

def test_tokenize_detokenize_pedal(self) -> None:
    """Tokenize ``transcription.mid``, detokenize the result, and save it.

    This is a smoke test: it makes no assertions and passes as long as
    the tokenize -> detokenize -> MIDI export chain raises no exception.
    """
    source_path = TEST_DATA_DIRECTORY.joinpath("transcription.mid")
    output_path = RESULTS_DATA_DIRECTORY.joinpath("transcription_pedal.mid")

    parsed_midi = MidiDict.from_midi(source_path)
    abs_tokenizer = AbsTokenizer()

    # Forward pass: MidiDict -> token sequence.
    tokens = abs_tokenizer.tokenize(midi_dict=parsed_midi)

    # Reverse pass: token sequence -> MidiDict -> MIDI file written to disk.
    restored = abs_tokenizer.detokenize(tokenized_seq=tokens)
    restored.to_midi().save(output_path)

def test_pitch_aug(self) -> None:
def _test_out_of_bounds(
tokenizer: AbsTokenizer,
Expand Down
Loading