# Aria

This repository contains training, inference, and evaluation code for the paper [*Scaling Self-Supervised Representation Learning for Symbolic Piano Performance (ISMIR 2025)*](https://example.com/), as well as implementations of our real-time piano continuation demo. *Aria* is a pretrained autoregressive generative model for symbolic music, based on the LLaMA 3.2 (1B) architecture, which was trained on ~60k hours of MIDI transcriptions of expressive solo-piano recordings. Alongside the base model, we are releasing a checkpoint finetuned to improve generative quality, as well as a checkpoint finetuned to produce general-purpose piano MIDI embeddings using a SimCSE-style contrastive training objective.

📖 Read our [release blog post](https://example.com/) and [paper](https://example.com/)
🤗 Access our models via the [HuggingFace page](https://huggingface.co/loubb/aria-medium-base)
📊 Get access to our training dataset [Aria-MIDI](https://huggingface.co/datasets/loubb/aria-midi) and train your own models

## Installation

Installation requires Python 3.11+. To install the package and all dependencies with pip:

```bash
git clone https://github.com/EleutherAI/aria
cd aria
pip install -e ".[all]"
```

## Quickstart

Download model weights from the official HuggingFace pages. We release our pretrained base model, as well as checkpoints finetuned for piano continuation and for generating MIDI embeddings (a programmatic download sketch follows the list):

- `aria-medium-base` ([huggingface](https://huggingface.co/loubb/aria-medium-base), [direct-download](https://huggingface.co/loubb/aria-medium-base/resolve/main/model.safetensors?download=true))
- `aria-medium-gen` ([huggingface](https://huggingface.co/loubb/aria-medium-gen), [direct-download](https://huggingface.co/loubb/aria-medium-gen/resolve/main/model.safetensors?download=true))
- `aria-medium-embedding` ([huggingface](https://huggingface.co/loubb/aria-medium-embedding), [direct-download](https://huggingface.co/loubb/aria-medium-embedding/resolve/main/model.safetensors?download=true))

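
Weights can also be fetched programmatically. Below is a minimal sketch using the `huggingface_hub` client (an extra dependency, not installed by this package) and the `model.safetensors` filename from the direct-download links above:

```python
# Sketch (assumes the optional huggingface_hub package): download a checkpoint
# to the local HF cache and print the path to pass as --checkpoint_path.
from huggingface_hub import hf_hub_download

checkpoint_path = hf_hub_download(
    repo_id="loubb/aria-medium-gen",  # or loubb/aria-medium-base / loubb/aria-medium-embedding
    filename="model.safetensors",
)
print(checkpoint_path)
```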

### Inference (Prompt Continuation)

We provide optimized model implementations for PyTorch (CUDA) and MLX (Apple Silicon). You can generate continuations of a MIDI file using the CLI, e.g., with the CUDA backend on Linux:

```bash
aria generate \
  --backend torch_cuda \
  --checkpoint_path <path-to-model-weights> \
  --prompt_midi_path <path-to-midi-file-to-continue> \
  --prompt_duration <length-in-seconds-for-prompt> \
  --variations <number-of-variations-to-generate> \
  --temp 0.98 \
  --min_p 0.035 \
  --length 2048 \
  --save_dir <dir-to-save-results>
```

Since the model has not been post-trained with instruction tuning or RLHF (similar to pre-instruct GPT models), it is very sensitive to input quality and performs best when prompted with well-played music. To get prompt MIDI files, see the `example-prompts/` directory, explore the [Aria-MIDI](https://huggingface.co/datasets/loubb/aria-midi) dataset, or transcribe your own files using our [piano-transcription model](https://github.com/EleutherAI/aria-amt). For a full list of sampling options, run `aria generate -h`. If you wish to run inference on the CPU, please see the platform-agnostic implementation on our [HuggingFace page](https://huggingface.co/loubb/aria-medium-base).

### Intended Use and Limitations

Aria performs best when **continuing existing piano MIDI files** rather than generating music from scratch. While multi-track tokenization and generation are supported, the model was trained primarily on **single-track expressive piano performances**, and we recommend using single-track inputs for optimal results.

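
If you are unsure whether a file is single-track, a quick check (and, if needed, a flatten) can be done with a general-purpose MIDI library. The sketch below uses the third-party `mido` package and a placeholder filename; it is not part of the Aria tooling.

```python
# Sketch (assumes the third-party mido package; "prompt.mid" is a placeholder):
# inspect a MIDI file and merge multiple tracks into a single track.
import mido

midi = mido.MidiFile("prompt.mid")
print(f"{len(midi.tracks)} track(s)")

if len(midi.tracks) > 1:
    flattened = mido.MidiFile(ticks_per_beat=midi.ticks_per_beat)
    flattened.tracks.append(mido.merge_tracks(midi.tracks))
    flattened.save("prompt-single-track.mid")
```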

Due to the high representation of popular classical works (e.g., Chopin) in the training data and the difficulty of complete deduplication, the model may **memorize or closely reproduce** such pieces. For more original outputs, we suggest prompting Aria with **lesser-known works or your own compositions**.

### Inference (MIDI embeddings)

You can generate embeddings from MIDI files using the `aria.embeddings` module. This functionality is primarily exposed through the `get_global_embedding_from_midi` function, for example:

```python
from safetensors.torch import load_file

from aria.embeddings import get_global_embedding_from_midi
from aria.model import TransformerEMB, ModelConfig
from aria.config import load_model_config
from ariautils.tokenizer import AbsTokenizer

# Placeholder paths: the aria-medium-embedding weights and the MIDI file to embed
CHECKPOINT_PATH = "path/to/aria-medium-embedding/model.safetensors"
MIDI_PATH = "path/to/file.mid"

# Load model
model_config = ModelConfig(**load_model_config(name="medium-emb"))
model_config.set_vocab_size(AbsTokenizer().vocab_size)
model = TransformerEMB(model_config)
state_dict = load_file(filename=CHECKPOINT_PATH)
model.load_state_dict(state_dict=state_dict, strict=True)

# Generate embedding
embedding = get_global_embedding_from_midi(
    model=model,
    midi_path=MIDI_PATH,
    device="cpu",
)
```

Our embedding model was trained to capture composition-level and performance-level attributes, and therefore might not be appropriate for every use case.

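
As a usage example, embeddings of two performances can be compared with cosine similarity. This is a minimal sketch that assumes `model` is the `TransformerEMB` loaded above, that the returned embeddings are 1-D PyTorch tensors, and that the file paths are placeholders.

```python
# Sketch: compare two MIDI files via the cosine similarity of their global
# embeddings (assumes 1-D torch tensors are returned; paths are placeholders).
import torch.nn.functional as F

emb_a = get_global_embedding_from_midi(model=model, midi_path="performance_a.mid", device="cpu")
emb_b = get_global_embedding_from_midi(model=model, midi_path="performance_b.mid", device="cpu")

similarity = F.cosine_similarity(emb_a.unsqueeze(0), emb_b.unsqueeze(0)).item()
print(f"Cosine similarity: {similarity:.3f}")
```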

## Real-time demo

In `demo/` we provide CUDA (Linux/PyTorch) and MLX (Apple Silicon) implementations of the real-time interactive piano-continuation demo showcased in our release blog post. For the demo we used an acoustic Yamaha Disklavier piano with simultaneous MIDI input and output ports, connected via a standard MIDI interface.

❗**NOTE**: The responsiveness of the real-time demo depends on your system configuration, e.g., GPU FLOPS and memory bandwidth.

A MIDI input device is not strictly required to play around with the demo: by using the `--midi_path` and `--midi_through` arguments, you can mock real-time input by playing back a MIDI file. All that is required is a set of MIDI drivers (e.g., CoreMIDI, ALSA) and a virtual software instrument (e.g., FluidSynth, Pianoteq) to render the output.

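
The port names to pass as `--midi_through` and `--midi_out` depend on your drivers and software instrument. One way to list them is with the third-party `mido` package (plus a backend such as python-rtmidi); this is a sketch, not part of the Aria tooling.

```python
# Sketch: print the MIDI port names available on this system, which can then
# be passed to --midi_through / --midi_out (assumes mido + python-rtmidi).
import mido

print("Input ports: ", mido.get_input_names())
print("Output ports:", mido.get_output_names())
```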

Example usage (MLX):

```bash
MIDI_PATH="example-prompts/pokey_jazz.mid"

python demo/demo_mlx.py \
  --checkpoint <checkpoint-path> \
  --midi_path ${MIDI_PATH} \
  --midi_through <port-to-stream-midi-file-through> \
  --midi_out <port-to-stream-generation-over> \
  --save_path <path-to-save-result> \
  --temp 0.98 \
  --min_p 0.035
```

## Evaluation

We provide the specific files/splits we used for the Aria-MIDI-derived linear-probe and classification evaluations. These can be downloaded from HuggingFace ([direct-download](https://huggingface.co/loubb/aria-medium-base/resolve/main/eval-splits.tar.gz?download=true)). Class labels are provided in `metadata.json` with the schema:

```json
{
  "<category>": {
    "<split-name>": {
      "<relative/path/to/file.mid>": "<metadata_value_for_that_category>",
      ...
    },
    ...
  },
  ...
}
```
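
As a rough illustration of how this schema can be consumed, the sketch below reads the labels for one category and split; the key names are placeholders standing in for the actual contents of `metadata.json`.

```python
# Sketch: load metadata.json and read the {relative MIDI path: label} mapping
# for one category/split ("<category>" and "<split-name>" are placeholders).
import json

with open("metadata.json") as f:
    metadata = json.load(f)

labels = metadata["<category>"]["<split-name>"]
for midi_path, label in list(labels.items())[:5]:
    print(midi_path, "->", label)
```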

## License and Attribution

The Aria project has been kindly supported by EleutherAI and Stability AI, as well as by a compute grant from the Ministry of Science and ICT of Korea. Our models and MIDI tooling are released under the Apache-2.0 license. If you use the models or tooling in follow-up work, please cite the paper in which they were introduced:

```bibtex
@inproceedings{bradshawscaling,
  title={Scaling Self-Supervised Representation Learning for Symbolic Piano Performance},
  author={Bradshaw, Louis and Fan, Honglu and Spangher, Alex and Biderman, Stella and Colton, Simon},
  booktitle={arXiv preprint},
  year={2025},
  url={https://arxiv.org/abs/2504.15071}
}
```