Commit f017028

bump to v1 (#119)
* demo
* demo fix
* mess it all up agian
* demo finished
* undo mistake
* update demo
* add prefill compile
* add class finetuning
* add seq sep option to PretrainingDataset
* change from genre to composer
* update emb eval scripts
* add explore script
* add contrastive ft
* add missing changes
* add loop
* fix arg bug
* update eval
* fix eval hang
* add data aug
* fix data aug
* formalize eval
* eval scripts
* fix range bug
* add m3 only embeddings
* update script for m3 embeddings
* update for pianist eval
* add pianist8 dataset script
* adjust per file emb logic and update scripts
* update datasets/training/model scripts to support embedding conditioning
* add ft-dataset script
* change use embeddings train logic
* fix model ft loading
* fix arg
* fix ddp model error
* add pca
* keshav
* keshav add args
* fix keshav
* update sampling and demo
* add looping and ending to demo
* push mlx imp for test
* fix sample script
* add continuous prefill and speculative duration calculation
* add off-msg streaming and fix timing alignment
* fix early-off logic with dumb hack
* fix stream_midi logic
* port demo to mlx
* add script
* update mlx demo
* partial tree refactor for release
* add resid dropout to model
* import fix
* inference tree skeleton
* fix tree
* rm scripts
* refactor entrypoint for generate
* cfg conditioned generation refactored for torch_cuda
* add mlx backend for conditioned generation
* fix mlx backend for conditioned gen
* update cli flags to standard unix format
* migrate to pyproject.toml
* add toml
* remove old plan
* add README draft
* update README
* rmv test_dataset
* update README
* demo adjustments
* add input delay correction
* update README
1 parent fedf763 commit f017028

68 files changed: +9383 −2418 lines

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -167,3 +167,6 @@ fluidsynth/
 tests/test_results
 lightning_logs/
 .vscode/
+paper
+hf
+_scripts

Makefile

Lines changed: 0 additions & 9 deletions
This file was deleted.

README.md

Lines changed: 105 additions & 38 deletions
@@ -1,64 +1,131 @@
-# gpt-aria
+# Aria
 
-[Discord](https://discord.com/invite/zBGx3azzUn)
+This repository contains training, inference, and evaluation code for the paper [*Scaling Self-Supervised Representation Learning for Symbolic Piano Performance (ISMIR 2025)*](https://example.com/), as well as implementations of our real-time piano continuation demo. *Aria* is a pretrained autoregressive generative model for symbolic music, based on the LLaMA 3.2 (1B) architecture, which was trained on ~60k hours of MIDI transcriptions of expressive solo-piano recordings. Alongside the base model, we are releasing a checkpoint finetuned to improve generative quality, as well as a checkpoint finetuned to produce general-purpose piano MIDI embeddings using a SimCSE-style contrastive training objective.
 
-A repository containing resources for pre-training, fine-tuning, and evaluating musical (MIDI) transformer models.
+📖 Read our [release blog post](https://example.com/) and [paper](https://example.com/)
+🤗 Access our models via the [HuggingFace page](https://huggingface.co/loubb/aria-medium-base)
+📊 Get access to our training dataset [Aria-MIDI](https://huggingface.co/datasets/loubb/aria-midi) and train your own models
 
-***Note that this project is under active development***
+## Installation
 
-## Description
+Installation requires Python 3.11+. To install the package and all dependencies with pip:
 
-The main goal of the gpt-aria project is to create a suite of powerful pre-trained generative (symbolic) music models. We want to investigate how modern training (pre-training & fine-tuning) techniques can be used to improve the quality/usefulness of such models. Alongside this we are building various data (MIDI) preprocessing tools, allowing **you** to easily fine-tune our models on your own data.
+```bash
+git clone https://github.com/EleutherAI/aria
+cd aria
+pip install -e ".[all]"
+```
 
-If you are new to symbolic music models, a good place to start are the following projects/blogposts by Google Magenta and OpenAI:
+## Quickstart
 
-- [Music Transformer](https://magenta.tensorflow.org/music-transformer)
-- [MuseNet](https://openai.com/research/musenet)
+Download model weights from the official HuggingFace page for our pretrained model, as well as the checkpoints finetuned for piano continuation and for generating MIDI embeddings (for scripted downloads, see the sketch below the list):
 
-Long story short: Transformer + MIDI + GPUs = 🎵 x ∞
+- `aria-medium-base` ([huggingface](https://huggingface.co/loubb/aria-medium-base), [direct-download](https://huggingface.co/loubb/aria-medium-base/resolve/main/model.safetensors?download=true))
+- `aria-medium-gen` ([huggingface](https://huggingface.co/loubb/aria-medium-gen), [direct-download](https://huggingface.co/loubb/aria-medium-gen/resolve/main/model.safetensors?download=true))
+- `aria-medium-embedding` ([huggingface](https://huggingface.co/loubb/aria-medium-embedding), [direct-download](https://huggingface.co/loubb/aria-medium-embedding/resolve/main/model.safetensors?download=true))
 
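If you prefer to script the download, a minimal sketch using the `huggingface_hub` package (assumed to be installed separately; it is not an explicit dependency above) might look like this:

```python
# Sketch: fetch a checkpoint from the HuggingFace Hub.
# Repo IDs and the model.safetensors filename are taken from the list above.
from huggingface_hub import hf_hub_download

checkpoint_path = hf_hub_download(
    repo_id="loubb/aria-medium-base",  # or loubb/aria-medium-gen / loubb/aria-medium-embedding
    filename="model.safetensors",
)
print(checkpoint_path)  # local cache path, usable as --checkpoint_path
```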

-## Installation
+### Inference (Prompt Continuation)
 
-Make sure you are using Python 3.10+. Note that I haven't explicitly developed this project for anything other than Linux. If you are using Windows, things might not work properly. In this case I suggest installing using WSL.
+We provide optimized model implementations for PyTorch (CUDA) and MLX (Apple Silicon). You can generate continuations of a MIDI file using the CLI, e.g., with CUDA (Linux):
 
-```
-git clone https://github.com/eleutherai/aria
-cd aria
-pip install -e .
+```bash
+aria generate \
+    --backend torch_cuda \
+    --checkpoint_path <path-to-model-weights> \
+    --prompt_midi_path <path-to-midi-file-to-continue> \
+    --prompt_duration <length-in-seconds-for-prompt> \
+    --variations <number-of-variations-to-generate> \
+    --temp 0.98 \
+    --min_p 0.035 \
+    --length 2048 \
+    --save_dir <dir-to-save-results>
 ```
 
-## Inference
+Since the model has not been post-trained with instruction tuning or RLHF (similar to pre-instruct GPT models), it is very sensitive to input quality and performs best when prompted with well-played music. To get prompt MIDI files, see the `example-prompts/` directory, explore the [Aria-MIDI](https://huggingface.co/datasets/loubb/aria-midi) dataset, or transcribe your own files using our [piano-transcription model](https://github.com/EleutherAI/aria-amt). For a full list of sampling options: `aria generate -h`. If you wish to do inference on the CPU, please see the platform-agnostic implementation on our HuggingFace page [link].
 
-You can find preliminary checkpoints at the following locations
+### Intended Use and Limitations
 
-Finetuned piano-only checkpoints (improved robustness):
+Aria performs best when **continuing existing piano MIDI files** rather than generating music from scratch. While multi-track tokenization and generation are supported, the model was trained primarily on **single-track expressive piano performances**, and we recommend using single-track inputs for optimal results.
 
-```
-large - https://storage.googleapis.com/aria-checkpoints/large-abs-inst.safetensors
-```
+Due to the high representation of popular classical works (e.g., Chopin) in the training data and the difficulty of complete deduplication, the model may **memorize or closely reproduce** such pieces. For more original outputs, we suggest prompting Aria with **lesser-known works or your own compositions**.
 
-Pretrained checkpoints:
+### Inference (MIDI embeddings)
 
+You can generate embeddings from MIDI files using the `aria.embeddings` module. This is primarily exposed via the `get_global_embedding_from_midi` function, for example:
+
+```python
+from aria.embeddings import get_global_embedding_from_midi
+from aria.model import TransformerEMB, ModelConfig
+from aria.config import load_model_config
+from ariautils.tokenizer import AbsTokenizer
+from safetensors.torch import load_file
+
+# Load model (CHECKPOINT_PATH is the path to the downloaded model.safetensors)
+model_config = ModelConfig(**load_model_config(name="medium-emb"))
+model_config.set_vocab_size(AbsTokenizer().vocab_size)
+model = TransformerEMB(model_config)
+state_dict = load_file(filename=CHECKPOINT_PATH)
+model.load_state_dict(state_dict=state_dict, strict=True)
+
+# Generate embedding (MIDI_PATH is the path to the input MIDI file)
+embedding = get_global_embedding_from_midi(
+    model=model,
+    midi_path=MIDI_PATH,
+    device="cpu",
+)
 ```
-large - https://storage.googleapis.com/aria-checkpoints/large-abs-pt.bin
-medium - https://storage.googleapis.com/aria-checkpoints/medium-abs-pt.bin
-small - https://storage.googleapis.com/aria-checkpoints/small-abs-pt.bin
-```
 
-You can then sample using the cli:
+Our embedding model was trained to capture composition-level and performance-level attributes, and therefore might not be appropriate for every use case.
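As one possible downstream use, the sketch below compares two performances by the cosine similarity of their global embeddings. It assumes the model has been loaded as in the snippet above, that `get_global_embedding_from_midi` returns a 1-D `torch.Tensor`, and that the two MIDI paths are placeholders:

```python
import torch.nn.functional as F

# Assumes `model` is the TransformerEMB instance loaded above and that each
# call returns a 1-D torch tensor (the MIDI paths below are placeholders).
emb_a = get_global_embedding_from_midi(model=model, midi_path="performance_a.mid", device="cpu")
emb_b = get_global_embedding_from_midi(model=model, midi_path="performance_b.mid", device="cpu")

similarity = F.cosine_similarity(emb_a.unsqueeze(0), emb_b.unsqueeze(0)).item()
print(f"cosine similarity: {similarity:.3f}")
```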
+
+## Real-time demo
+
+In `demo/` we provide CUDA (Linux/PyTorch) and MLX (Apple Silicon) implementations of the real-time interactive piano-continuation demo showcased in our release blog post. For the demo we used an acoustic Yamaha Disklavier piano with simultaneous MIDI input and output ports connected via a standard MIDI interface.
+
+**NOTE**: Responsiveness of the real-time demo depends on your system configuration, e.g., GPU FLOPS and memory bandwidth.
 
+A MIDI input device is not strictly required to play around with the demo: by using the `--midi_path` and `--midi_through` arguments you can mock real-time input by playing from a MIDI file. All that is required is a MIDI driver (e.g., CoreMIDI, ALSA) and a virtual software instrument (e.g., FluidSynth, Pianoteq) to render the output.
+
+Example usage (MLX):
+
+```bash
+MIDI_PATH="example-prompts/pokey_jazz.mid"
+
+python demo/demo_mlx.py \
+    --checkpoint <checkpoint-path> \
+    --midi_path ${MIDI_PATH} \
+    --midi_through <port-to-stream-midi-file-through> \
+    --midi_out <port-to-stream-generation-over> \
+    --save_path <path-to-save-result> \
+    --temp 0.98 \
+    --min_p 0.035
 ```
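The values passed to `--midi_through` and `--midi_out` must name MIDI ports visible to your driver. One way to list them is with the `mido` package (an assumption here; the demo scripts may use a different MIDI library):

```python
# Sketch: list available MIDI ports to find names for --midi_through / --midi_out.
# Assumes `pip install mido python-rtmidi`.
import mido

print("input ports: ", mido.get_input_names())
print("output ports:", mido.get_output_names())
```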
-aria sample \
-    -m large \
-    -c <path-to-checkpoint> \
-    -p <path-to-midifile> \
-    -var <num-variations-to-generate> \
-    -trunc <seconds-in-to-truncate-prompt> \
-    -l <number-of-tokens-to-generate> \
-    -temp 0.95 \
-    -e
+
+## Evaluation
+
+We provide the specific files/splits we used for the Aria-MIDI-derived linear-probe and classification evaluations. These can be downloaded from HuggingFace ([direct-download](https://huggingface.co/loubb/aria-medium-base/resolve/main/eval-splits.tar.gz?download=true)). Class labels are provided in `metadata.json` with the schema:
+
+```json
+{
+    "<category>": {
+        "<split-name>": {
+            "<relative/path/to/file.mid>": "<metadata_value_for_that_category>",
+            ...
+        },
+        ...
+    },
+    ...
+}
 ```
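As a sketch of how these labels might be consumed for a linear probe, the snippet below collects `(path, label)` pairs for one category and split; `"composer"` and `"train"` are illustrative placeholder keys rather than guaranteed ones:

```python
import json
from pathlib import Path

# Assumes the archive was extracted to ./eval-splits with metadata.json at its root.
root = Path("eval-splits")
with open(root / "metadata.json") as f:
    metadata = json.load(f)

# "composer" / "train" are placeholders; inspect metadata.keys() for the real categories.
pairs = [(root / rel_path, label) for rel_path, label in metadata["composer"]["train"].items()]
print(f"{len(pairs)} labelled files")
```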
 
-You can use `aria sample -h` to see a full list of options. If you wish to sample from a pretrained checkpoint, please use the `-pt` flag.
+## License and Attribution
 
+The Aria project has been kindly supported by EleutherAI and Stability AI, as well as by a compute grant from the Ministry of Science and ICT of Korea. Our models and MIDI tooling are released under the Apache-2.0 license. If you use the models or tooling for follow-up work, please cite the paper in which they were introduced:
 
+```bibtex
+@inproceedings{bradshawscaling,
+  title={Scaling Self-Supervised Representation Learning for Symbolic Piano Performance},
+  author={Bradshaw, Louis and Fan, Honglu and Spangher, Alex and Biderman, Stella and Colton, Simon},
+  booktitle={arXiv preprint},
+  year={2025},
+  url={https://arxiv.org/abs/2504.15071}
+}
+```
