Skip to content

Commit a19784c

Browse files
authored
Add pitch frequency MidiDict test functions (#9)
* pyproject.toml fix * fix * fix config module * add note freq test fns * adjust config
1 parent 7a8b553 commit a19784c

File tree

3 files changed

+162
-3
lines changed

3 files changed

+162
-3
lines changed

ariautils/config/__init__.py

Whitespace-only changes.

ariautils/config/config.json

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,29 @@
3232
"args": {
3333
"min_seconds": 30
3434
}
35+
},
36+
"silent_interval":{
37+
"run": false,
38+
"args": {
39+
"max_silence_s": 30
40+
}
41+
},
42+
"unique_pitch_count":{
43+
"run": false,
44+
"args": {
45+
"min_num_unique_pitches": 8
46+
}
47+
},
48+
"unique_pitch_count_in_interval":{
49+
"run": false,
50+
"args": {
51+
"test_params":
52+
[
53+
{"min_unique_pitch_cnt": 2, "interval_len_s": 20},
54+
{"min_unique_pitch_cnt": 3, "interval_len_s": 30},
55+
{"min_unique_pitch_cnt": 5, "interval_len_s": 90}
56+
]
57+
}
3558
}
3659
},
3760
"pre_processing": {

ariautils/midi.py

Lines changed: 139 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@
99
import mido
1010

1111
from mido.midifiles.units import tick2second
12-
from collections import defaultdict
12+
from collections import defaultdict, deque
1313
from pathlib import Path
1414
from typing import (
1515
Any,
1616
Final,
17+
Deque,
1718
Concatenate,
1819
Callable,
1920
TypeAlias,
@@ -22,12 +23,14 @@
2223
cast,
2324
)
2425

25-
from ariautils.utils import load_config, load_maestro_metadata_json
26+
from ariautils.utils import load_config, load_maestro_metadata_json, get_logger
2627

28+
logger = get_logger(__package__)
2729

2830
# TODO:
2931
# - Remove unneeded comments
30-
# - Add asserts
32+
# - Add asserts (e.g., for test and metadata functions)
33+
# - Add docstrings to test_ functions
3134

3235

3336
class MetaMessage(TypedDict):
@@ -1136,6 +1139,136 @@ def test_min_length(
11361139
return True, total_duration_ms / 1e3
11371140

11381141

1142+
def test_silent_interval(
1143+
midi_dict: MidiDict, max_silence_s: float
1144+
) -> tuple[bool, float]:
1145+
if not midi_dict.note_msgs or not midi_dict.tempo_msgs:
1146+
return False, 0.0
1147+
1148+
longest_silence_s: float = 0.0
1149+
last_note_end_tick = midi_dict.note_msgs[0]["data"]["end"]
1150+
1151+
for note_msg in midi_dict.note_msgs[1:]:
1152+
note_start_tick = note_msg["data"]["start"]
1153+
1154+
if note_start_tick > last_note_end_tick:
1155+
longest_silence_s = max(
1156+
longest_silence_s,
1157+
get_duration_ms(
1158+
start_tick=last_note_end_tick,
1159+
end_tick=note_start_tick,
1160+
tempo_msgs=midi_dict.tempo_msgs,
1161+
ticks_per_beat=midi_dict.ticks_per_beat,
1162+
)
1163+
/ 1000.0,
1164+
)
1165+
1166+
if longest_silence_s >= max_silence_s:
1167+
return False, longest_silence_s
1168+
1169+
last_note_end_tick = max(last_note_end_tick, note_msg["data"]["end"])
1170+
1171+
return True, longest_silence_s
1172+
1173+
1174+
def test_unique_pitch_count(
1175+
midi_dict: MidiDict, min_num_unique_pitches: int
1176+
) -> tuple[bool, int]:
1177+
if not midi_dict.note_msgs:
1178+
return False, 0
1179+
1180+
present_pitches = {
1181+
note_msg["data"]["pitch"]
1182+
for note_msg in midi_dict.note_msgs
1183+
if note_msg["channel"] != 9
1184+
}
1185+
1186+
unique_pitches = len(present_pitches)
1187+
if unique_pitches < min_num_unique_pitches:
1188+
return False, unique_pitches
1189+
else:
1190+
return True, unique_pitches
1191+
1192+
1193+
def _test_unique_pitch_count_in_interval(
1194+
midi_dict: MidiDict,
1195+
min_unique_pitch_cnt: int,
1196+
interval_len_s: int,
1197+
) -> tuple[bool, int]:
1198+
if not midi_dict.note_msgs:
1199+
return False, 0
1200+
1201+
note_events = [
1202+
(
1203+
note_msg["data"]["pitch"],
1204+
midi_dict.tick_to_ms(note_msg["data"]["start"]),
1205+
)
1206+
for note_msg in midi_dict.note_msgs
1207+
if note_msg["channel"] != 9
1208+
]
1209+
note_events = sorted(note_events, key=lambda x: x[1])
1210+
1211+
WINDOW_STEP_S: Final[int] = 1
1212+
interval_start_s = (
1213+
midi_dict.tick_to_ms(midi_dict.note_msgs[0]["tick"]) / 1000.0
1214+
)
1215+
min_window_pitch_count_seen = 128
1216+
end_idx = 0
1217+
notes_in_window: Deque[tuple[int, int]] = deque()
1218+
while end_idx < len(note_events):
1219+
interval_end_s = interval_start_s + interval_len_s
1220+
1221+
for note_msg_tuple in note_events[end_idx:]:
1222+
_, _start_ms = note_msg_tuple
1223+
_start_s = _start_ms / 1000.0
1224+
if _start_s <= interval_end_s:
1225+
notes_in_window.append(note_msg_tuple)
1226+
end_idx += 1
1227+
else:
1228+
break
1229+
1230+
if len(notes_in_window) > 0:
1231+
while notes_in_window:
1232+
_, _start_ms = notes_in_window[0]
1233+
_start_s = _start_ms / 1000.0
1234+
if _start_s < interval_start_s:
1235+
notes_in_window.popleft()
1236+
else:
1237+
break
1238+
1239+
unique_pitches_in_window = {
1240+
note_tuple[0] for note_tuple in notes_in_window
1241+
}
1242+
1243+
min_window_pitch_count_seen = min(
1244+
min_window_pitch_count_seen,
1245+
len(unique_pitches_in_window),
1246+
)
1247+
1248+
interval_start_s += WINDOW_STEP_S
1249+
1250+
if min_window_pitch_count_seen < min_unique_pitch_cnt:
1251+
return False, min_window_pitch_count_seen
1252+
else:
1253+
return True, min_window_pitch_count_seen
1254+
1255+
1256+
def test_unique_pitch_count_in_interval(
1257+
midi_dict: MidiDict, test_params_list: list[dict]
1258+
) -> tuple[bool, int]:
1259+
1260+
for test_params in test_params_list:
1261+
success, val = _test_unique_pitch_count_in_interval(
1262+
midi_dict=midi_dict,
1263+
min_unique_pitch_cnt=test_params["min_unique_pitch_cnt"],
1264+
interval_len_s=test_params["interval_len_s"],
1265+
)
1266+
if success is False:
1267+
return False, val
1268+
1269+
return True, val
1270+
1271+
11391272
def get_test_fn(
11401273
test_name: str,
11411274
) -> Callable[Concatenate[MidiDict, ...], tuple[bool, Any]]:
@@ -1147,6 +1280,9 @@ def get_test_fn(
11471280
"total_note_frequency": test_note_frequency,
11481281
"note_frequency_per_instrument": test_note_frequency_per_instrument,
11491282
"min_length": test_min_length,
1283+
"silent_interval": test_silent_interval,
1284+
"unique_pitch_count": test_unique_pitch_count,
1285+
"unique_pitch_count_in_interval": test_unique_pitch_count_in_interval,
11501286
}
11511287

11521288
fn = name_to_fn.get(test_name, None)

0 commit comments

Comments
 (0)