Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 31 additions & 8 deletions ariautils/midi.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,13 +318,36 @@ def resolve_overlaps(self) -> "MidiDict":

return self

# TODO: Update docstring
def enforce_gaps(
self,
min_gap_ms: int,
min_gap_ms: int = 0,
min_gap_by_vel: dict[int, int] | None = None,
min_length_ms: int = 0,
) -> "MidiDict":
"""Enforce at least min_gap_ms between same-pitch notes on the same
MIDI channel, and drop any notes shorter than min_length_ms."""
"""Enforce a minimum gap between consecutive same-pitch notes.

Shortens the end time of a note if it's too close to the next note
of the same pitch and channel. The required gap can be a single value or
vary based on the velocity of the following note.

After gaps are enforced, any note shorter than `min_length_ms` is removed.

Args:
min_gap_ms (int): The default minimum gap in milliseconds. This
value is used as a fallback if `min_gap_by_vel` is provided but
a specific velocity is not found in the map.
min_gap_by_vel (dict[int, int]): An optional dictionary mapping a
MIDI velocity (int) to a desired minimum preceding gap (ms).
min_length_ms (int): The minimum duration for a note in milliseconds.
Notes shorter than this will be removed after processing.
"""

def _get_min_gap_ms(vel: int) -> int:
if min_gap_by_vel is not None:
return min_gap_by_vel[vel]
else:
return min_gap_ms

def _tempo_at_tick(tick: int) -> int:
# find tempo in effect at given tick
Expand All @@ -336,7 +359,6 @@ def _tempo_at_tick(tick: int) -> int:
break
return tempo

# Phase 1: shorten preceding notes to enforce min_gap_ms
note_groups: dict[tuple[int, int], list] = defaultdict(list)
for msg in self.note_msgs:
key = (msg["channel"], msg["data"]["pitch"])
Expand All @@ -347,10 +369,11 @@ def _tempo_at_tick(tick: int) -> int:
for prev, curr in zip(msgs, msgs[1:]):
prev_end_ms = self.tick_to_ms(prev["data"]["end"])
curr_start_ms = self.tick_to_ms(curr["data"]["start"])
gap = curr_start_ms - prev_end_ms
if gap < min_gap_ms:
# compute new end so that curr_start_ms - new_end_ms == min_gap_ms
new_end_ms = curr_start_ms - min_gap_ms
curr_gap_ms = curr_start_ms - prev_end_ms
curr_min_gap_ms = _get_min_gap_ms(curr["data"]["velocity"])
if curr_gap_ms < curr_min_gap_ms:
# Compute new end so that curr_start_ms - new_end_ms == curr_min_gap_ms
new_end_ms = curr_start_ms - curr_min_gap_ms
tempo = _tempo_at_tick(prev["data"]["end"])
new_end_tick = round(
second2tick(
Expand Down
Loading