Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 28 additions & 21 deletions ariautils/midi.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,10 @@ def from_midi(cls, mid_path: str | Path) -> "MidiDict":
"""Loads a MIDI file from path and returns MidiDict."""

mid = mido.MidiFile(mid_path)
return cls(**midi_to_dict(mid))
midi_dict = cls(**midi_to_dict(mid))
midi_dict.metadata["abs_load_path"] = str(Path(mid_path).absolute())

return midi_dict

def calculate_hash(self) -> str:
msg_dict_to_hash = cast(dict, self.get_msg_dict())
Expand Down Expand Up @@ -892,9 +895,13 @@ def to_ascii(s: str) -> str:


def meta_composer_filename(
mid: mido.MidiFile, msg_data: MidiDictData, composer_names: list
midi_dict_data: MidiDictData, composer_names: list
) -> dict[str, str]:
file_name = Path(str(mid.filename)).stem
abs_load_path = midi_dict_data["metadata"].get("abs_load_path")
if abs_load_path is None:
return {}

file_name = Path(abs_load_path).stem
matched_names_unique = set()
for name in composer_names:
if _match_word(file_name, name):
Expand All @@ -909,9 +916,13 @@ def meta_composer_filename(


def meta_form_filename(
mid: mido.MidiFile, msg_data: MidiDictData, form_names: list
midi_dict_data: MidiDictData, form_names: list
) -> dict[str, str]:
file_name = Path(str(mid.filename)).stem
abs_load_path = midi_dict_data["metadata"].get("abs_load_path")
if abs_load_path is None:
return {}

file_name = Path(abs_load_path).stem
matched_names_unique = set()
for name in form_names:
if _match_word(file_name, name):
Expand All @@ -926,10 +937,10 @@ def meta_form_filename(


def meta_composer_metamsg(
mid: mido.MidiFile, msg_data: MidiDictData, composer_names: list
midi_dict_data: MidiDictData, composer_names: list
) -> dict[str, str]:
matched_names_unique = set()
for msg in msg_data["meta_msgs"]:
for msg in midi_dict_data["meta_msgs"]:
for name in composer_names:
if _match_word(msg["data"], name):
matched_names_unique.add(name)
Expand All @@ -944,8 +955,7 @@ def meta_composer_metamsg(

# TODO: Needs testing
def meta_maestro_json(
mid: mido.MidiFile,
msg_data: MidiDictData,
midi_dict_data: MidiDictData,
composer_names: list,
form_names: list,
) -> dict[str, str]:
Expand All @@ -957,11 +967,12 @@ def meta_maestro_json(
the form file_name: {"composer": str, "title": str}.
"""

_file_name = Path(str(mid.filename)).name
_file_name_without_ext = os.path.splitext(_file_name)[0]
metadata = load_maestro_metadata_json().get(
_file_name_without_ext + ".midi", None
)
abs_load_path = midi_dict_data["metadata"].get("abs_load_path")
if abs_load_path is None:
return {}

file_name = Path(abs_load_path).stem
metadata = load_maestro_metadata_json().get(file_name + ".midi", None)
if metadata == None:
return {}

Expand All @@ -986,22 +997,18 @@ def meta_maestro_json(
return res


def meta_abs_path(mid: mido.MidiFile, msg_data: MidiDictData) -> dict[str, str]:
return {"abs_path": str(Path(str(mid.filename)).absolute())}


# TODO: Add metadata function compatible with aria-midi
def get_metadata_fn(
metadata_process_name: str,
) -> Callable[Concatenate[mido.MidiFile, MidiDictData, ...], dict[str, str]]:
) -> Callable[Concatenate[MidiDictData, ...], dict[str, str]]:
name_to_fn: dict[
str,
Callable[Concatenate[mido.MidiFile, MidiDictData, ...], dict[str, str]],
Callable[Concatenate[MidiDictData, ...], dict[str, str]],
] = {
"composer_filename": meta_composer_filename,
"composer_metamsg": meta_composer_metamsg,
"form_filename": meta_form_filename,
"maestro_json": meta_maestro_json,
"abs_path": meta_abs_path,
}

fn = name_to_fn.get(metadata_process_name, None)
Expand Down
Loading