Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 238 additions & 13 deletions ariautils/midi.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
cast,
)

from ariautils.utils import load_config, load_maestro_metadata_json, get_logger
from ariautils.utils import load_maestro_metadata_json, get_logger

logger = get_logger(__package__)

Expand Down Expand Up @@ -1199,15 +1199,15 @@ def test_mean_note_velocity(

def test_mean_note_len(
midi_dict: MidiDict,
min_mean_len: float,
max_mean_len: float,
min_mean_len_ms: float,
max_mean_len_ms: float,
) -> tuple[bool, float]:
"""Tests the average note length of MIDI messages in milliseconds.

Args:
midi_dict (MidiDict): MidiDict to test.
min_mean_len (float): Minimum average note length.
max_mean_len (float): Maximum average note length.
min_mean_len_ms (float): Minimum average note length.
max_mean_len_ms (float): Maximum average note length.

Returns:
bool: True if passed test, else False.
Expand All @@ -1225,7 +1225,7 @@ def test_mean_note_len(
]

mean_length = sum(note_lengths) / len(note_lengths)
is_valid = min_mean_len <= mean_length <= max_mean_len
is_valid = min_mean_len_ms <= mean_length <= max_mean_len_ms

return is_valid, mean_length

Expand Down Expand Up @@ -1308,9 +1308,6 @@ def _test_unique_pitch_count_in_interval(
min_unique_pitch_cnt: int,
interval_len_s: float,
) -> tuple[bool, tuple[int, float]]:
if not midi_dict.note_msgs:
return False, (0, 0)

note_events = [
(
note_msg["data"]["pitch"],
Expand All @@ -1321,7 +1318,10 @@ def _test_unique_pitch_count_in_interval(
]
note_events = sorted(note_events, key=lambda x: x[1])

WINDOW_STEP_S: Final[int] = 1
if not note_events:
return False, (0, 0)

WINDOW_STEP_S: Final[float] = 1
interval_start_s = (
midi_dict.tick_to_ms(midi_dict.note_msgs[0]["tick"]) / 1000.0
)
Expand Down Expand Up @@ -1382,7 +1382,7 @@ def test_unique_pitch_count_in_interval(

Returns:
bool: True if all tests passed, else False.
tuple[int, int, float]: Results from first failure or final test.
tuple[int, int, float]: Result from first failure or final test:
- interval_len_s: Window length in seconds
- pitch_cnt: Number of unique pitches found
- window_start_s: Start time of the window in seconds
Expand All @@ -1407,7 +1407,132 @@ def test_unique_pitch_count_in_interval(
return True, (test_params["interval_len_s"], pitch_cnt, window_start_s)


# 2.5-3.0 seems like a good pretraining threshold
def _test_note_density_in_interval(
midi_dict: "MidiDict",
max_notes_per_second: int,
max_notes_per_second_per_pitch: int,
interval_len_s: float,
) -> tuple[bool, tuple[float, float, int]]:
note_events = [
(
midi_dict.tick_to_ms(note_msg["data"]["start"]) / 1000.0,
note_msg["data"]["pitch"],
)
for note_msg in midi_dict.note_msgs
if note_msg["channel"] != 9
]
note_events.sort()

if not note_events:
return False, (0.0, 0.0, 0)

WINDOW_STEP_S: Final[float] = 1
interval_start_s = note_events[0][0]
max_window_note_cnt_seen = 0
max_window_start_s: int = 0
max_pitch_cnt_seen = 0
end_idx = 0
notes_in_window: Deque[tuple[float, int]] = deque()
pitch_cnts: dict[int, int] = {}

while end_idx < len(note_events):
interval_end_s = interval_start_s + interval_len_s

for time_s, pitch in note_events[end_idx:]:
if time_s <= interval_end_s:
notes_in_window.append((time_s, pitch))
pitch_cnts[pitch] = pitch_cnts.get(pitch, 0) + 1
end_idx += 1
else:
break

if notes_in_window:
while notes_in_window and notes_in_window[0][0] < interval_start_s:
_, old_pitch = notes_in_window.popleft()
pitch_cnts[old_pitch] -= 1
if pitch_cnts[old_pitch] == 0:
del pitch_cnts[old_pitch]

notes_in_window_cnt = len(notes_in_window)
max_pitch_cnt = max(pitch_cnts.values()) if pitch_cnts else 0

if notes_in_window_cnt > max_window_note_cnt_seen:
max_window_note_cnt_seen = notes_in_window_cnt
max_window_start_s = int(interval_start_s)

if max_pitch_cnt > max_pitch_cnt_seen:
max_pitch_cnt_seen = max_pitch_cnt

interval_start_s += WINDOW_STEP_S

max_allowed_notes = max_notes_per_second * interval_len_s
max_allowed_pitch_notes = max_notes_per_second_per_pitch * interval_len_s

is_valid = (
max_window_note_cnt_seen <= max_allowed_notes
and max_pitch_cnt_seen <= max_allowed_pitch_notes
)

return is_valid, (
max_window_note_cnt_seen / interval_len_s,
max_pitch_cnt_seen / interval_len_s,
max_window_start_s,
)


def test_note_density_in_interval(
midi_dict: "MidiDict", test_params_list: list[dict]
) -> tuple[bool, tuple[float, float, float, int]]:
"""Tests if note density exceeds thresholds within sliding time windows.

Args:
midi_dict: MidiDict to test.
test_params_list: List of parameter dicts, each containing:
- max_notes_per_second (int): Maximum allowed notes per second
- max_notes_per_second_per_pitch (int): Maximum allowed notes per
second for each pitch
- interval_len_s (float): Length of sliding window in seconds

Returns:
bool: True if all tests passed, else False
tuple[float, float, float, int]: Result from first failure or final test:
- interval_len_s: Window length in seconds
- max_notes_per_second: Maximum notes per second seen
- max_notes_per_second_per_pitch: Maximum notes per second for each
pitch
- window_start_s: Start time of the window in seconds
"""

for test_params in test_params_list:
success, (
notes_per_second,
notes_per_second_per_pitch,
interval_start_s,
) = _test_note_density_in_interval(
midi_dict=midi_dict,
max_notes_per_second=test_params["max_notes_per_second"],
max_notes_per_second_per_pitch=test_params[
"max_notes_per_second_per_pitch"
],
interval_len_s=test_params["interval_len_s"],
)

if success is False:
return False, (
test_params["interval_len_s"],
notes_per_second,
notes_per_second_per_pitch,
interval_start_s,
)

return True, (
test_params["interval_len_s"],
notes_per_second,
notes_per_second_per_pitch,
interval_start_s,
)


def test_note_timing_entropy(
midi_dict: MidiDict,
min_length_entropy: float,
Expand Down Expand Up @@ -1477,7 +1602,6 @@ def test_note_timing_entropy(
return is_valid, (len_entropy, onset_delta_entropy)


# 3.0 feels about right for pretraining
def test_note_pitch_entropy(
midi_dict: MidiDict, min_entropy: float
) -> tuple[bool, float]:
Expand Down Expand Up @@ -1516,6 +1640,105 @@ def test_note_pitch_entropy(
return is_valid, entropy


def test_repetitive_content(
midi_dict: MidiDict,
min_length_m: float,
num_chunks: int,
kl_tolerance: float,
) -> tuple[bool, float]:
"""Tests if a MIDI file is repetitive by comparing pitch distributions.

Calculates KL-Divergence between pitch distributions for evenly spaced
chunks:

https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence

Args:
midi_dict (MidiDict): MidiDict to test.
min_length_m: Minimum length in minutes required for the test.
num_chunks: Number of chunks to divide the MIDI file into.
kl_tolerance: Maximum allowed KL-divergence between distributions
Lower values mean distributions must be more similar.

Returns:
bool: False if KL-divergence within tolerance.
float: Maximum KL-divergence found between any two chunks.
"""
note_msgs_nd = [msg for msg in midi_dict.note_msgs if msg["channel"] != 9]
if not note_msgs_nd:
return False, 0.0

start_time_s = (
midi_dict.tick_to_ms(note_msgs_nd[0]["data"]["start"]) / 1000.0
)
end_time_s = (
midi_dict.tick_to_ms(note_msgs_nd[-1]["data"]["start"]) / 1000.0
)
duration_s = end_time_s - start_time_s

if duration_s / 60.0 < min_length_m:
return True, 0.0

chunk_size_s = duration_s / num_chunks
chunk_distributions: list[dict[int, float]] = []
chunk_boundaries_s = [
start_time_s + (i * chunk_size_s) for i in range(num_chunks + 1)
]

curr_chunk = 0
msg_idx = 0
while curr_chunk < num_chunks and msg_idx < len(note_msgs_nd):
chunk_start_ms = chunk_boundaries_s[curr_chunk] * 1000.0
chunk_end_ms = chunk_boundaries_s[curr_chunk + 1] * 1000.0

curr_chunk_pitches: dict[int, int] = {p: 0 for p in range(0, 128)}
while msg_idx < len(note_msgs_nd):
note_start_ms = midi_dict.tick_to_ms(
note_msgs_nd[msg_idx]["data"]["start"]
)

if note_start_ms >= chunk_end_ms:
break

if note_start_ms >= chunk_start_ms:
pitch = note_msgs_nd[msg_idx]["data"]["pitch"]
curr_chunk_pitches[pitch] += 1

msg_idx += 1

total = sum(curr_chunk_pitches.values())
if total > 0:
distribution: dict[int, float] = {
k: v / total for k, v in curr_chunk_pitches.items() if v > 0
}
chunk_distributions.append(distribution)

curr_chunk += 1

if len(chunk_distributions) < 2:
return True, 0.0

# Calculate KL divergence between all pairs of distributions
max_kl = 0.0
for i in range(len(chunk_distributions)):
for j in range(i + 1, len(chunk_distributions)):
dist1 = chunk_distributions[i]
dist2 = chunk_distributions[j]

kl_div = 0.0
for pitch in range(0, 128):
p = dist1.get(pitch, 1e-5)
q = dist2.get(pitch, 1e-5)
kl_div += p * log2(p / q)

max_kl = max(max_kl, abs(kl_div))

is_valid = max_kl > kl_tolerance

return is_valid, max_kl


# TODO: Refactor tests into a new module
def get_test_fn(
test_name: str,
) -> Callable[Concatenate[MidiDict, ...], tuple[bool, Any]]:
Expand All @@ -1532,8 +1755,10 @@ def get_test_fn(
"silent_interval": test_silent_interval,
"unique_pitch_count": test_unique_pitch_count,
"unique_pitch_count_in_interval": test_unique_pitch_count_in_interval,
"note_density_in_interval": test_note_density_in_interval,
"note_timing_entropy": test_note_timing_entropy,
"note_pitch_entropy": test_note_pitch_entropy,
"repetitive_content": test_repetitive_content,
}

fn = name_to_fn.get(test_name, None)
Expand Down
Loading