99import mido
1010
1111from mido .midifiles .units import tick2second
12- from collections import defaultdict
12+ from collections import defaultdict , deque
1313from pathlib import Path
1414from typing import (
1515 Any ,
1616 Final ,
17+ Deque ,
1718 Concatenate ,
1819 Callable ,
1920 TypeAlias ,
2223 cast ,
2324)
2425
25- from ariautils .utils import load_config , load_maestro_metadata_json
26+ from ariautils .utils import load_config , load_maestro_metadata_json , get_logger
2627
28+ logger = get_logger (__package__ )
2729
2830# TODO:
2931# - Remove unneeded comments
30- # - Add asserts
32+ # - Add asserts (e.g., for test and metadata functions)
33+ # - Add docstrings to test_ functions
3134
3235
3336class MetaMessage (TypedDict ):
@@ -1136,6 +1139,136 @@ def test_min_length(
11361139 return True , total_duration_ms / 1e3
11371140
11381141
1142+ def test_silent_interval (
1143+ midi_dict : MidiDict , max_silence_s : float
1144+ ) -> tuple [bool , float ]:
1145+ if not midi_dict .note_msgs or not midi_dict .tempo_msgs :
1146+ return False , 0.0
1147+
1148+ longest_silence_s : float = 0.0
1149+ last_note_end_tick = midi_dict .note_msgs [0 ]["data" ]["end" ]
1150+
1151+ for note_msg in midi_dict .note_msgs [1 :]:
1152+ note_start_tick = note_msg ["data" ]["start" ]
1153+
1154+ if note_start_tick > last_note_end_tick :
1155+ longest_silence_s = max (
1156+ longest_silence_s ,
1157+ get_duration_ms (
1158+ start_tick = last_note_end_tick ,
1159+ end_tick = note_start_tick ,
1160+ tempo_msgs = midi_dict .tempo_msgs ,
1161+ ticks_per_beat = midi_dict .ticks_per_beat ,
1162+ )
1163+ / 1000.0 ,
1164+ )
1165+
1166+ if longest_silence_s >= max_silence_s :
1167+ return False , longest_silence_s
1168+
1169+ last_note_end_tick = max (last_note_end_tick , note_msg ["data" ]["end" ])
1170+
1171+ return True , longest_silence_s
1172+
1173+
1174+ def test_unique_pitch_count (
1175+ midi_dict : MidiDict , min_num_unique_pitches : int
1176+ ) -> tuple [bool , int ]:
1177+ if not midi_dict .note_msgs :
1178+ return False , 0
1179+
1180+ present_pitches = {
1181+ note_msg ["data" ]["pitch" ]
1182+ for note_msg in midi_dict .note_msgs
1183+ if note_msg ["channel" ] != 9
1184+ }
1185+
1186+ unique_pitches = len (present_pitches )
1187+ if unique_pitches < min_num_unique_pitches :
1188+ return False , unique_pitches
1189+ else :
1190+ return True , unique_pitches
1191+
1192+
1193+ def _test_unique_pitch_count_in_interval (
1194+ midi_dict : MidiDict ,
1195+ min_unique_pitch_cnt : int ,
1196+ interval_len_s : int ,
1197+ ) -> tuple [bool , int ]:
1198+ if not midi_dict .note_msgs :
1199+ return False , 0
1200+
1201+ note_events = [
1202+ (
1203+ note_msg ["data" ]["pitch" ],
1204+ midi_dict .tick_to_ms (note_msg ["data" ]["start" ]),
1205+ )
1206+ for note_msg in midi_dict .note_msgs
1207+ if note_msg ["channel" ] != 9
1208+ ]
1209+ note_events = sorted (note_events , key = lambda x : x [1 ])
1210+
1211+ WINDOW_STEP_S : Final [int ] = 1
1212+ interval_start_s = (
1213+ midi_dict .tick_to_ms (midi_dict .note_msgs [0 ]["tick" ]) / 1000.0
1214+ )
1215+ min_window_pitch_count_seen = 128
1216+ end_idx = 0
1217+ notes_in_window : Deque [tuple [int , int ]] = deque ()
1218+ while end_idx < len (note_events ):
1219+ interval_end_s = interval_start_s + interval_len_s
1220+
1221+ for note_msg_tuple in note_events [end_idx :]:
1222+ _ , _start_ms = note_msg_tuple
1223+ _start_s = _start_ms / 1000.0
1224+ if _start_s <= interval_end_s :
1225+ notes_in_window .append (note_msg_tuple )
1226+ end_idx += 1
1227+ else :
1228+ break
1229+
1230+ if len (notes_in_window ) > 0 :
1231+ while notes_in_window :
1232+ _ , _start_ms = notes_in_window [0 ]
1233+ _start_s = _start_ms / 1000.0
1234+ if _start_s < interval_start_s :
1235+ notes_in_window .popleft ()
1236+ else :
1237+ break
1238+
1239+ unique_pitches_in_window = {
1240+ note_tuple [0 ] for note_tuple in notes_in_window
1241+ }
1242+
1243+ min_window_pitch_count_seen = min (
1244+ min_window_pitch_count_seen ,
1245+ len (unique_pitches_in_window ),
1246+ )
1247+
1248+ interval_start_s += WINDOW_STEP_S
1249+
1250+ if min_window_pitch_count_seen < min_unique_pitch_cnt :
1251+ return False , min_window_pitch_count_seen
1252+ else :
1253+ return True , min_window_pitch_count_seen
1254+
1255+
1256+ def test_unique_pitch_count_in_interval (
1257+ midi_dict : MidiDict , test_params_list : list [dict ]
1258+ ) -> tuple [bool , int ]:
1259+
1260+ for test_params in test_params_list :
1261+ success , val = _test_unique_pitch_count_in_interval (
1262+ midi_dict = midi_dict ,
1263+ min_unique_pitch_cnt = test_params ["min_unique_pitch_cnt" ],
1264+ interval_len_s = test_params ["interval_len_s" ],
1265+ )
1266+ if success is False :
1267+ return False , val
1268+
1269+ return True , val
1270+
1271+
11391272def get_test_fn (
11401273 test_name : str ,
11411274) -> Callable [Concatenate [MidiDict , ...], tuple [bool , Any ]]:
@@ -1147,6 +1280,9 @@ def get_test_fn(
11471280 "total_note_frequency" : test_note_frequency ,
11481281 "note_frequency_per_instrument" : test_note_frequency_per_instrument ,
11491282 "min_length" : test_min_length ,
1283+ "silent_interval" : test_silent_interval ,
1284+ "unique_pitch_count" : test_unique_pitch_count ,
1285+ "unique_pitch_count_in_interval" : test_unique_pitch_count_in_interval ,
11501286 }
11511287
11521288 fn = name_to_fn .get (test_name , None )
0 commit comments