From 1886628ea904324f16b4418fb86c472f73572f6c Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 4 Dec 2024 16:55:45 +0000 Subject: [PATCH 1/5] change metadat fn format --- ariautils/config/config.json | 25 ------------------------- ariautils/midi.py | 22 ++++++++++------------ 2 files changed, 10 insertions(+), 37 deletions(-) diff --git a/ariautils/config/config.json b/ariautils/config/config.json index 90593f1..dbd6d32 100644 --- a/ariautils/config/config.json +++ b/ariautils/config/config.json @@ -1,29 +1,4 @@ { - "data": { - "pre_processing": { - "remove_instruments": { - "run": true, - "args": { - "piano": false, - "chromatic": true, - "organ": false, - "guitar": false, - "bass": false, - "strings": false, - "ensemble": false, - "brass": false, - "reed": false, - "pipe": false, - "synth_lead": false, - "synth_pad": true, - "synth_effect": true, - "ethnic": true, - "percussive": true, - "sfx": true - } - } - } - }, "tokenizer": { "rel": { "ignore_instruments": { diff --git a/ariautils/midi.py b/ariautils/midi.py index 6f17285..ecc4589 100644 --- a/ariautils/midi.py +++ b/ariautils/midi.py @@ -895,9 +895,9 @@ def to_ascii(s: str) -> str: def meta_composer_filename( - midi_dict_data: MidiDictData, composer_names: list + midi_dict: MidiDict, composer_names: list ) -> dict[str, str]: - abs_load_path = midi_dict_data["metadata"].get("abs_load_path") + abs_load_path = midi_dict.metadata.get("abs_load_path") if abs_load_path is None: return {} @@ -915,10 +915,8 @@ def meta_composer_filename( return {} -def meta_form_filename( - midi_dict_data: MidiDictData, form_names: list -) -> dict[str, str]: - abs_load_path = midi_dict_data["metadata"].get("abs_load_path") +def meta_form_filename(midi_dict: MidiDict, form_names: list) -> dict[str, str]: + abs_load_path = midi_dict.metadata.get("abs_load_path") if abs_load_path is None: return {} @@ -937,10 +935,10 @@ def meta_form_filename( def meta_composer_metamsg( - midi_dict_data: MidiDictData, composer_names: list + midi_dict: MidiDict, composer_names: list ) -> dict[str, str]: matched_names_unique = set() - for msg in midi_dict_data["meta_msgs"]: + for msg in midi_dict.meta_msgs: for name in composer_names: if _match_word(msg["data"], name): matched_names_unique.add(name) @@ -955,7 +953,7 @@ def meta_composer_metamsg( # TODO: Needs testing def meta_maestro_json( - midi_dict_data: MidiDictData, + midi_dict: MidiDict, composer_names: list, form_names: list, ) -> dict[str, str]: @@ -967,7 +965,7 @@ def meta_maestro_json( the form file_name: {"composer": str, "title": str}. """ - abs_load_path = midi_dict_data["metadata"].get("abs_load_path") + abs_load_path = midi_dict.metadata.get("abs_load_path") if abs_load_path is None: return {} @@ -1000,10 +998,10 @@ def meta_maestro_json( # TODO: Add metadata function compatible with aria-midi def get_metadata_fn( metadata_process_name: str, -) -> Callable[Concatenate[MidiDictData, ...], dict[str, str]]: +) -> Callable[Concatenate[MidiDict, ...], dict[str, str]]: name_to_fn: dict[ str, - Callable[Concatenate[MidiDictData, ...], dict[str, str]], + Callable[Concatenate[MidiDict, ...], dict[str, str]], ] = { "composer_filename": meta_composer_filename, "composer_metamsg": meta_composer_metamsg, From 2e028b5fbedda0358c92957b4c231410acf3e737 Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 4 Dec 2024 17:42:44 +0000 Subject: [PATCH 2/5] add info to pitch_count test --- ariautils/midi.py | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/ariautils/midi.py b/ariautils/midi.py index ecc4589..183f8ad 100644 --- a/ariautils/midi.py +++ b/ariautils/midi.py @@ -1184,9 +1184,9 @@ def _test_unique_pitch_count_in_interval( midi_dict: MidiDict, min_unique_pitch_cnt: int, interval_len_s: int, -) -> tuple[bool, int]: +) -> tuple[bool, tuple[int, float]]: if not midi_dict.note_msgs: - return False, 0 + return False, (0, 0) note_events = [ ( @@ -1203,6 +1203,7 @@ def _test_unique_pitch_count_in_interval( midi_dict.tick_to_ms(midi_dict.note_msgs[0]["tick"]) / 1000.0 ) min_window_pitch_count_seen = 128 + min_window_start_s = 0 end_idx = 0 notes_in_window: Deque[tuple[int, int]] = deque() while end_idx < len(note_events): @@ -1234,29 +1235,39 @@ def _test_unique_pitch_count_in_interval( min_window_pitch_count_seen, len(unique_pitches_in_window), ) + if len(unique_pitches_in_window) < min_window_pitch_count_seen: + min_window_pitch_count_seen = len(unique_pitches_in_window) + min_window_start_s = interval_start_s interval_start_s += WINDOW_STEP_S if min_window_pitch_count_seen < min_unique_pitch_cnt: - return False, min_window_pitch_count_seen + return False, (min_window_pitch_count_seen, min_window_start_s) else: - return True, min_window_pitch_count_seen + return True, (min_window_pitch_count_seen, min_window_start_s) def test_unique_pitch_count_in_interval( midi_dict: MidiDict, test_params_list: list[dict] -) -> tuple[bool, int]: +) -> tuple[bool, tuple[int, int, float]]: for test_params in test_params_list: - success, val = _test_unique_pitch_count_in_interval( - midi_dict=midi_dict, - min_unique_pitch_cnt=test_params["min_unique_pitch_cnt"], - interval_len_s=test_params["interval_len_s"], + success, (pitch_cnt, window_start_s) = ( + _test_unique_pitch_count_in_interval( + midi_dict=midi_dict, + min_unique_pitch_cnt=test_params["min_unique_pitch_cnt"], + interval_len_s=test_params["interval_len_s"], + ) ) + if success is False: - return False, val + return False, ( + test_params["interval_len_s"], + pitch_cnt, + window_start_s, + ) - return True, val + return True, (test_params["interval_len_s"], pitch_cnt, window_start_s) def get_test_fn( From 20c6ecfccdd335fd1d1408f19a8bbefba5786473 Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 4 Dec 2024 17:59:59 +0000 Subject: [PATCH 3/5] fix --- ariautils/midi.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/ariautils/midi.py b/ariautils/midi.py index 183f8ad..967e1f7 100644 --- a/ariautils/midi.py +++ b/ariautils/midi.py @@ -1203,7 +1203,7 @@ def _test_unique_pitch_count_in_interval( midi_dict.tick_to_ms(midi_dict.note_msgs[0]["tick"]) / 1000.0 ) min_window_pitch_count_seen = 128 - min_window_start_s = 0 + min_window_start_s = 0.0 end_idx = 0 notes_in_window: Deque[tuple[int, int]] = deque() while end_idx < len(note_events): @@ -1231,10 +1231,6 @@ def _test_unique_pitch_count_in_interval( note_tuple[0] for note_tuple in notes_in_window } - min_window_pitch_count_seen = min( - min_window_pitch_count_seen, - len(unique_pitches_in_window), - ) if len(unique_pitches_in_window) < min_window_pitch_count_seen: min_window_pitch_count_seen = len(unique_pitches_in_window) min_window_start_s = interval_start_s From 669e66fe2557e8b835e9f530ffe6582c98356140 Mon Sep 17 00:00:00 2001 From: Louis Date: Thu, 5 Dec 2024 19:15:26 +0000 Subject: [PATCH 4/5] add test functions --- ariautils/midi.py | 310 +++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 282 insertions(+), 28 deletions(-) diff --git a/ariautils/midi.py b/ariautils/midi.py index 967e1f7..435cab7 100644 --- a/ariautils/midi.py +++ b/ariautils/midi.py @@ -1,7 +1,6 @@ """Utils for MIDI processing.""" import re -import os import json import hashlib import copy @@ -10,6 +9,7 @@ from mido.midifiles.units import tick2second from collections import defaultdict, deque +from math import log2 from pathlib import Path from typing import ( Any, @@ -1019,37 +1019,63 @@ def get_metadata_fn( def test_max_programs(midi_dict: MidiDict, max: int) -> tuple[bool, int]: - """Returns false if midi_dict uses more than {max} programs.""" + """Tests the number of programs present. + + Args: + midi_dict (MidiDict): MidiDict to test. + max (int): Maximum allowed number of unique programs. + + Returns: + bool: True if number of programs <= max, False otherwise + int: Actual number of unique programs found + """ present_programs = set( map( lambda msg: msg["data"], midi_dict.instrument_msgs, ) ) + is_valid = len(present_programs) <= max - if len(present_programs) <= max: - return True, len(present_programs) - else: - return False, len(present_programs) + return is_valid, len(present_programs) def test_max_instruments(midi_dict: MidiDict, max: int) -> tuple[bool, int]: + """Tests the number of instruments present. + + Args: + midi_dict (MidiDict): MidiDict to test. + max (int): Maximum allowed number of unique instruments. + + Returns: + bool: True if number of instruments <= max, False otherwise. + int: Number of unique instruments found. + """ present_instruments = set( map( lambda msg: midi_dict.program_to_instrument[msg["data"]], midi_dict.instrument_msgs, ) ) + is_valid = len(present_instruments) <= max - if len(present_instruments) <= max: - return True, len(present_instruments) - else: - return False, len(present_instruments) + return is_valid, len(present_instruments) def test_note_frequency( midi_dict: MidiDict, max_per_second: float, min_per_second: float ) -> tuple[bool, float]: + """Tests if overall note frequency falls within specified bounds. + + Args: + midi_dict (MidiDict): MidiDict to test. + max_per_second (float): Maximum allowed notes per second. + min_per_second (float): Minimum required notes per second. + + Returns: + bool: True if frequency is within bounds, False otherwise. + float: Actual notes per second found. + """ if not midi_dict.note_msgs: return False, 0.0 @@ -1065,16 +1091,25 @@ def test_note_frequency( return False, 0.0 notes_per_second = (num_notes * 1e3) / total_duration_ms + is_valid = min_per_second <= notes_per_second <= max_per_second - if notes_per_second < min_per_second or notes_per_second > max_per_second: - return False, notes_per_second - else: - return True, notes_per_second + return is_valid, notes_per_second def test_note_frequency_per_instrument( midi_dict: MidiDict, max_per_second: float, min_per_second: float ) -> tuple[bool, float]: + """Tests if the note frequency per instrument falls within specified bounds. + + Args: + midi_dict (MidiDict): MidiDict to test. + max_per_second (float): Maximum notes per second per instrument. + min_per_second (float): Minimum notes per second per instrument. + + Returns: + bool: True if frequency is within bounds, False otherwise. + float: Actual notes per second per instrument found. + """ num_instruments = len( set( map( @@ -1084,7 +1119,7 @@ def test_note_frequency_per_instrument( ) ) - if not midi_dict.note_msgs: + if not midi_dict.note_msgs or not midi_dict.instrument_msgs: return False, 0.0 num_notes = len(midi_dict.note_msgs) @@ -1099,20 +1134,26 @@ def test_note_frequency_per_instrument( return False, 0.0 notes_per_second = (num_notes * 1e3) / total_duration_ms - note_freq_per_instrument = notes_per_second / num_instruments - if ( - note_freq_per_instrument < min_per_second - or note_freq_per_instrument > max_per_second - ): - return False, note_freq_per_instrument - else: - return True, note_freq_per_instrument + is_valid = min_per_second <= note_freq_per_instrument <= max_per_second + + return is_valid, note_freq_per_instrument def test_min_length( midi_dict: MidiDict, min_seconds: int ) -> tuple[bool, float]: + """Tests the min length is above a threshold. + + Args: + midi_dict (MidiDict): MidiDict to test. + min_seconds (int): Minimum length threshold in seconds. + + Returns: + bool: True if longer than minimum length, else False. + float: Length in seconds. + """ + if not midi_dict.note_msgs: return False, 0.0 @@ -1122,16 +1163,87 @@ def test_min_length( tempo_msgs=midi_dict.tempo_msgs, ticks_per_beat=midi_dict.ticks_per_beat, ) + is_valid = total_duration_ms / 1e3 >= min_seconds + + return is_valid, total_duration_ms / 1e3 + + +def test_mean_note_velocity( + midi_dict: MidiDict, + min_mean_velocity: float, + max_mean_velocity: float, +) -> tuple[bool, float]: + """Tests the average velocity of non-drum note messages. + + Args: + midi_dict (MidiDict): MidiDict to test. + min_mean_velocity (float): Min average velocity. + max_mean_velocity (float): Max average velocity. + + Returns: + bool: True if passed test, else False. + float: Average velocity value. + """ + + note_msgs_nd = [msg for msg in midi_dict.note_msgs if msg["channel"] != 9] + if not note_msgs_nd: + return False, 0.0 + + velocity_values = [msg["data"]["velocity"] for msg in note_msgs_nd] + + avg_velocity = sum(velocity_values) / len(velocity_values) + is_valid = min_mean_velocity <= avg_velocity <= max_mean_velocity + + return is_valid, avg_velocity - if total_duration_ms / 1e3 < min_seconds: - return False, total_duration_ms / 1e3 - else: - return True, total_duration_ms / 1e3 + +def test_mean_note_len( + midi_dict: MidiDict, + min_mean_len: float, + max_mean_len: float, +) -> tuple[bool, float]: + """Tests the average note length of MIDI messages in milliseconds. + + Args: + midi_dict (MidiDict): MidiDict to test. + min_mean_len (float): Minimum average note length. + max_mean_len (float): Maximum average note length. + + Returns: + bool: True if passed test, else False. + float: Average note length in milliseconds. + """ + + note_msgs_nd = [msg for msg in midi_dict.note_msgs if msg["channel"] != 9] + if not note_msgs_nd: + return False, 0.0 + + note_lengths = [ + midi_dict.tick_to_ms(msg["data"]["end"]) + - midi_dict.tick_to_ms(msg["data"]["start"]) + for msg in note_msgs_nd + ] + + mean_length = sum(note_lengths) / len(note_lengths) + is_valid = min_mean_len <= mean_length <= max_mean_len + + return is_valid, mean_length def test_silent_interval( midi_dict: MidiDict, max_silence_s: float ) -> tuple[bool, float]: + """Tests the length of silent gaps between notes. + + Args: + midi_dict (MidiDict): MidiDict to test. + max_silence_s (float): Maximum allowed silence in seconds. + + Returns: + bool: True if no silences exceed maximum, else False. + float: Duration of longest silence found, in seconds. + """ + if not midi_dict.note_msgs or not midi_dict.tempo_msgs: return False, 0.0 @@ -1164,6 +1276,17 @@ def test_silent_interval( def test_unique_pitch_count( midi_dict: MidiDict, min_num_unique_pitches: int ) -> tuple[bool, int]: + """Tests the number of unique pitches present. + + Args: + midi_dict (MidiDict): MidiDict to test. + min_num_unique_pitches (int): Minimum unique pitches. + + Returns: + bool: True if minimum pitch count met, else False. + int: Total number of unique pitches found. + """ + if not midi_dict.note_msgs: return False, 0 @@ -1183,7 +1306,7 @@ def test_unique_pitch_count( def _test_unique_pitch_count_in_interval( midi_dict: MidiDict, min_unique_pitch_cnt: int, - interval_len_s: int, + interval_len_s: float, ) -> tuple[bool, tuple[int, float]]: if not midi_dict.note_msgs: return False, (0, 0) @@ -1246,6 +1369,24 @@ def _test_unique_pitch_count_in_interval( def test_unique_pitch_count_in_interval( midi_dict: MidiDict, test_params_list: list[dict] ) -> tuple[bool, tuple[int, int, float]]: + """Tests if sufficient unique pitches occur within sliding time windows. + + For each set of test parameters, checks if there are enough unique MIDI + pitches within a sliding time window of specified length. + + Args: + midi_dict (MidiDict): MidiDict to test. + test_params_list (list[dict]): List of parameter dicts, each containing: + - min_unique_pitch_cnt (int): Minimum required unique pitches + - interval_len_s (float): Length of sliding window in seconds + + Returns: + bool: True if all tests passed, else False. + tuple[int, int, float]: Results from first failure or final test. + - interval_len_s: Window length in seconds + - pitch_cnt: Number of unique pitches found + - window_start_s: Start time of the window in seconds + """ for test_params in test_params_list: success, (pitch_cnt, window_start_s) = ( @@ -1266,6 +1407,115 @@ def test_unique_pitch_count_in_interval( return True, (test_params["interval_len_s"], pitch_cnt, window_start_s) +# 2.5-3.0 seems like a good pretraining threshold +def test_note_timing_entropy( + midi_dict: MidiDict, + min_length_entropy: float, + min_onset_delta_entropy: float, +) -> tuple[bool, tuple[float, float]]: + """Tests the entropy of the distribution of note lengths and onsets. + + Targets files with very un-random note duration and onset distributions. + Typically these consist of quantized files with moderate entropy, and + degenerate files very low entropy. Note lengths values are rounded to 10ms + and truncated at 5000ms. + + https://en.wikipedia.org/wiki/Entropy_(information_theory) + + Args: + midi_dict (MidiDict): MidiDict to test. + min_length_entropy (float): Minimum entropy of the distribution of + note-lengths. + min_onset_delta_entropy (float): Minimum entropy of the distribution of + onset-differences between subsequent notes. + + Returns: + bool: True if passed test, else False. + tuple[float, float]: Tuple containing: + - length_entropy: Entropy of the note length distribution + - onset_delta_entropy: Entropy of the note-onset delta distribution + """ + + note_msgs_nd = [msg for msg in midi_dict.note_msgs if msg["channel"] != 9] + if not note_msgs_nd: + return False, (0.0, 0.0) + + note_lens = [] + note_onset_deltas = [] + for prev_msg, msg in zip(note_msgs_nd, note_msgs_nd[1:]): + prev_msg_start_ms = midi_dict.tick_to_ms(prev_msg["data"]["start"]) + start_ms = midi_dict.tick_to_ms(msg["data"]["start"]) + end_ms = midi_dict.tick_to_ms(msg["data"]["end"]) + + note_onset_delta_ms = start_ms - prev_msg_start_ms + duration_ms = end_ms - start_ms + duration_ms = min(round(duration_ms, -1), 5000) + + note_lens.append(duration_ms) + note_onset_deltas.append(note_onset_delta_ms) + + total = len(note_lens) + len_counts = {} + onset_deltas_counts = {} + for interval in note_lens: + len_counts[interval] = len_counts.get(interval, 0) + 1 + for interval in note_onset_deltas: + onset_deltas_counts[interval] = onset_deltas_counts.get(interval, 0) + 1 + + len_entropy = -sum( + (cnt / total) * log2(cnt / total) for cnt in len_counts.values() + ) + onset_delta_entropy = -sum( + (cnt / total) * log2(cnt / total) + for cnt in onset_deltas_counts.values() + ) + is_valid = ( + len_entropy >= min_length_entropy + and onset_delta_entropy >= min_onset_delta_entropy + ) + + return is_valid, (len_entropy, onset_delta_entropy) + + +# 3.0 feels about right for pretraining +def test_note_pitch_entropy( + midi_dict: MidiDict, min_entropy: float +) -> tuple[bool, float]: + """Tests the entropy of the pitch distribution in a MIDI file. + + Targets degenerate files, as well as files with very simple harmonic and + melodic structure. + + Args: + midi_dict (MidiDict): MidiDict to test. + min_entropy (float): Minimum required entropy value. + + Returns: + bool: True if minimum entropy met, else False. + float: Calculated entropy value. + """ + + note_msgs_nd = [msg for msg in midi_dict.note_msgs if msg["channel"] != 9] + if not note_msgs_nd: + return False, 0.0 + + pitches = [] + for msg in note_msgs_nd: + pitches.append(msg["data"]["pitch"]) + + total = len(pitches) + counts = {} + for interval in pitches: + counts[interval] = counts.get(interval, 0) + 1 + + entropy = -sum( + (count / total) * log2(count / total) for count in counts.values() + ) + is_valid = entropy >= min_entropy + + return is_valid, entropy + + def get_test_fn( test_name: str, ) -> Callable[Concatenate[MidiDict, ...], tuple[bool, Any]]: @@ -1277,9 +1527,13 @@ def get_test_fn( "total_note_frequency": test_note_frequency, "note_frequency_per_instrument": test_note_frequency_per_instrument, "min_length": test_min_length, + "mean_note_len": test_mean_note_len, + "mean_note_velocity": test_mean_note_velocity, "silent_interval": test_silent_interval, "unique_pitch_count": test_unique_pitch_count, "unique_pitch_count_in_interval": test_unique_pitch_count_in_interval, + "note_timing_entropy": test_note_timing_entropy, + "note_pitch_entropy": test_note_pitch_entropy, } fn = name_to_fn.get(test_name, None) From 96c35a436eca682838151d8d4f58c04fab8a2d31 Mon Sep 17 00:00:00 2001 From: Louis Date: Thu, 5 Dec 2024 19:18:06 +0000 Subject: [PATCH 5/5] fix mypy --- ariautils/midi.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ariautils/midi.py b/ariautils/midi.py index 435cab7..7ea6c29 100644 --- a/ariautils/midi.py +++ b/ariautils/midi.py @@ -1455,8 +1455,8 @@ def test_note_timing_entropy( note_onset_deltas.append(note_onset_delta_ms) total = len(note_lens) - len_counts = {} - onset_deltas_counts = {} + len_counts: dict[float, int] = {} + onset_deltas_counts: dict[float, int] = {} for interval in note_lens: len_counts[interval] = len_counts.get(interval, 0) + 1 for interval in note_onset_deltas: @@ -1504,7 +1504,7 @@ def test_note_pitch_entropy( pitches.append(msg["data"]["pitch"]) total = len(pitches) - counts = {} + counts: dict[int, int] = {} for interval in pitches: counts[interval] = counts.get(interval, 0) + 1