diff --git a/docs/tutorials/posttraining/lora.md b/docs/tutorials/posttraining/lora.md
new file mode 100644
index 0000000000..3306e8748d
--- /dev/null
+++ b/docs/tutorials/posttraining/lora.md
@@ -0,0 +1,164 @@

# LoRA Fine-tuning on single-host TPUs

**Low-Rank Adaptation (LoRA)** is a Parameter-Efficient Fine-Tuning (PEFT) technique designed to optimize large language models while minimizing resource consumption.

Unlike traditional full-parameter fine-tuning, LoRA:
* **Freezes the pre-trained model weights**, preserving the original knowledge.
* **Injects trainable rank-decomposition matrices** into the Transformer layers.

This approach **greatly reduces the number of trainable parameters** required for downstream tasks, making the process faster and more memory-efficient.

This tutorial provides step-by-step instructions for setting up the environment and performing LoRA fine-tuning on a Hugging Face dataset using MaxText.

We use [Tunix](https://github.com/google/tunix), a JAX-based library, to power these post-training tasks.

In this tutorial we use a single-host TPU VM such as `v6e-8` or `v5p-8`. Let's get started!

## Install MaxText and Post-Training dependencies

For instructions on installing MaxText with post-training dependencies on your VM, refer to the [official documentation](https://maxtext.readthedocs.io/en/latest/install_maxtext.html) and use the `maxtext[tpu-post-train]` installation path to include all necessary post-training dependencies.

## Set up environment variables

Set the following environment variables before running LoRA fine-tuning.
```sh
# -- Model configuration --
export PRE_TRAINED_MODEL= # e.g., 'llama3.1-8b-Instruct'

# -- MaxText configuration --
export BASE_OUTPUT_DIRECTORY= # e.g., gs://my-bucket/my-output-directory
export RUN_NAME= # e.g., $(date +%Y-%m-%d-%H-%M-%S)
export STEPS= # e.g., 1000
export PER_DEVICE_BATCH_SIZE= # e.g., 1
export HF_TOKEN= # your Hugging Face access token
export LORA_RANK= # e.g., 16
export LORA_ALPHA= # e.g., 32.0
export LEARNING_RATE= # e.g., 3e-6
export MAX_TARGET_LENGTH= # e.g., 1024
export WEIGHT_DTYPE= # e.g., bfloat16
export DTYPE= # e.g., bfloat16

# -- Dataset configuration --
export DATASET_NAME= # e.g., openai/gsm8k
export TRAIN_SPLIT= # e.g., train
export HF_DATA_DIR= # e.g., main
export TRAIN_DATA_COLUMNS= # e.g., ['question','answer']

# -- LoRA Conversion configuration (Optional) --
export HF_LORA_ADAPTER_PATH= # e.g., 'username/adapter-name'
```

## Get your model checkpoint

This section explains how to prepare your model checkpoint for use with MaxText. You have two options: using an existing MaxText checkpoint or converting a Hugging Face checkpoint.

### Option 1: Using an existing MaxText checkpoint

If you already have a MaxText-compatible model checkpoint, simply set the following environment variable and move on to the next section.

```sh
export PRE_TRAINED_MODEL_CKPT_PATH= # e.g., gs://my-bucket/my-model-checkpoint/0/items
```

### Option 2: Converting a Hugging Face checkpoint

Follow the steps in [Hugging Face to MaxText](https://maxtext.readthedocs.io/en/maxtext-v0.2.1/guides/checkpointing_solutions/convert_checkpoint.html#hugging-face-to-maxtext) to convert a Hugging Face checkpoint to MaxText format, and make sure the converted checkpoint files are saved correctly. As in Option 1, set the following environment variable and move on.
```sh
export PRE_TRAINED_MODEL_CKPT_PATH= # e.g., gs://my-bucket/my-model-checkpoint/0/items
```

## (Optional) Resume from a previous LoRA checkpoint

If you want to resume training from a previous run or further fine-tune an existing LoRA adapter, you can specify the LoRA checkpoint path.

### Step 1: Convert the HF LoRA adapter to MaxText format

If your LoRA adapter is currently in Hugging Face format, you must convert it to MaxText format before it can be loaded. Use the provided conversion script:

```sh
python3 maxtext/checkpoint_conversion/hf_lora_to_maxtext.py \
  maxtext/configs/post_train/sft.yml \
  model_name="${PRE_TRAINED_MODEL?}" \
  hf_lora_adapter_path="${HF_LORA_ADAPTER_PATH?}" \
  base_output_directory="${BASE_OUTPUT_DIRECTORY?}" \
  scan_layers=True
```

### Step 2: Set the restore path

Point `LORA_RESTORE_PATH` to the converted MaxText adapter directory (the directory containing the `0/items` or Orbax files).
- **load_parameters_path**: Points to the frozen base model weights (the original model).
- **lora_restore_path**: Points to the previous LoRA adapter weights you wish to load.

```sh
# If starting fresh, you can leave this empty or skip this variable
export LORA_RESTORE_PATH= # e.g., gs://my-bucket/run-1/checkpoints/0/items
```

## Run LoRA Fine-Tuning on a Hugging Face Dataset

Once your environment variables and checkpoints are ready, you can start the LoRA fine-tuning process.
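Before launching a run, it can help to see why LoRA with the example `LORA_RANK` of 16 is so much cheaper than full fine-tuning. The sketch below is illustrative only: the hidden size is a hypothetical placeholder, not the actual dimension of your model, and it counts the adapter for a single square weight matrix.

```python
# Illustrative only: LoRA freezes W (d_out x d_in) and trains two small
# matrices, A (rank x d_in) and B (d_out x rank), so the trainable count
# per adapted matrix is rank * (d_in + d_out) instead of d_in * d_out.

def lora_trainable_params(d_in: int, d_out: int, rank: int) -> int:
    return rank * (d_in + d_out)

d = 4096   # hypothetical hidden size, for illustration only
rank = 16  # matches the LORA_RANK example above

full_params = d * d
lora_params = lora_trainable_params(d, d, rank)
print(f"full: {full_params:,}  lora: {lora_params:,}  "
      f"fraction: {lora_params / full_params:.4%}")
```

For these placeholder dimensions the adapter trains well under 1% of the parameters of the matrix it adapts, which is why LoRA fits comfortably on a single host.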
Execute the following command to begin training:

```sh
python3 -m maxtext.trainers.post_train.sft.train_sft \
  run_name="${RUN_NAME?}" \
  base_output_directory="${BASE_OUTPUT_DIRECTORY?}" \
  model_name="${PRE_TRAINED_MODEL?}" \
  load_parameters_path="${PRE_TRAINED_MODEL_CKPT_PATH?}" \
  lora_restore_path="${LORA_RESTORE_PATH}" \
  hf_access_token="${HF_TOKEN?}" \
  hf_path="${DATASET_NAME?}" \
  train_split="${TRAIN_SPLIT?}" \
  hf_data_dir="${HF_DATA_DIR?}" \
  train_data_columns="${TRAIN_DATA_COLUMNS?}" \
  steps="${STEPS?}" \
  per_device_batch_size="${PER_DEVICE_BATCH_SIZE?}" \
  max_target_length="${MAX_TARGET_LENGTH?}" \
  learning_rate="${LEARNING_RATE?}" \
  weight_dtype="${WEIGHT_DTYPE?}" \
  dtype="${DTYPE?}" \
  profiler=xplane \
  enable_nnx=True \
  pure_nnx_decoder=True \
  enable_lora=True \
  lora_rank="${LORA_RANK?}" \
  lora_alpha="${LORA_ALPHA?}" \
  scan_layers=True
```

Your fine-tuned model checkpoints will be saved to `$BASE_OUTPUT_DIRECTORY/$RUN_NAME/checkpoints`.

## (Optional) Export Fine-tuned LoRA to Hugging Face Format

After completing the fine-tuning process, your LoRA weights are stored in MaxText/Orbax format. To use these weights with the Hugging Face ecosystem (e.g., for inference or sharing), convert them back using the `maxtext_to_hf_lora.py` script.

```sh
python3 maxtext/checkpoint_conversion/maxtext_to_hf_lora.py \
  maxtext/configs/post_train/sft.yml \
  model_name="${PRE_TRAINED_MODEL?}" \
  load_parameters_path="${BASE_OUTPUT_DIRECTORY?}/${RUN_NAME?}/checkpoints/<step>/items" \
  base_output_directory="${BASE_OUTPUT_DIRECTORY?}/hf_lora_adapter" \
  lora_rank="${LORA_RANK?}" \
  lora_alpha="${LORA_ALPHA?}"
```

- `load_parameters_path`: Point this to the specific checkpoint directory (e.g., `.../checkpoints/1000/items`) that you want to export.
- `base_output_directory`: The local or GCS directory where the Hugging Face `adapter_model.safetensors` and `adapter_config.json` will be saved.
- `lora_rank` / `lora_alpha`: Must match the values used during the training phase to ensure the `adapter_config.json` is generated correctly.
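The last point matters because the adapter update is applied as `W + (lora_alpha / r) * B @ A`, so a mismatched rank or alpha silently rescales the adapter when it is loaded. The sketch below writes a hypothetical PEFT-style `adapter_config.json` for illustration: the field names follow the Hugging Face PEFT convention, and the exact fields (and `target_modules`) emitted by the conversion script may differ.

```python
import json

# Hypothetical sketch of a PEFT-style adapter config, for illustration only.
adapter_config = {
    "peft_type": "LORA",
    "r": 16,             # must equal the LORA_RANK used during training
    "lora_alpha": 32.0,  # must equal the LORA_ALPHA used during training
    "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"],  # illustrative
}

with open("adapter_config.json", "w") as f:
    json.dump(adapter_config, f, indent=2)

# Loaders derive the update's scale from these two fields, so exporting with
# the wrong values changes the adapter's effective magnitude at load time.
scaling = adapter_config["lora_alpha"] / adapter_config["r"]
print(f"effective scaling: {scaling}")  # prints: effective scaling: 2.0
```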