Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 28 additions & 6 deletions config/data_module/emg_pretrain_data_module.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,36 @@ data_module:
train_val_split_ratio: 0.8
datasets:
demo_dataset: null
emg2pose:
emg2pose_train:
_target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset'
data_dir: "${env:DATA_PATH}/emg2pose/h5/"
db6:
hdf5_file: "${env:DATA_PATH}/emg2pose_data/h5/train.h5"
emg2pose_val:
_target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset'
data_dir: "${env:DATA_PATH}/ninapro/DB6/h5/"
hdf5_file: "${env:DATA_PATH}/emg2pose_data/h5/val.h5"
emg2pose_test:
_target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset'
hdf5_file: "${env:DATA_PATH}/emg2pose_data/h5/test.h5"
db6_train:
_target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset'
hdf5_file: "${env:DATA_PATH}/ninapro/DB6/h5/train.h5"
pad_up_to_max_chans: 16
db6_val:
_target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset'
hdf5_file: "${env:DATA_PATH}/ninapro/DB6/h5/val.h5"
pad_up_to_max_chans: 16
db6_test:
_target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset'
hdf5_file: "${env:DATA_PATH}/ninapro/DB6/h5/test.h5"
pad_up_to_max_chans: 16
db7_train:
_target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset'
hdf5_file: "${env:DATA_PATH}/ninapro/DB7/h5/train.h5"
pad_up_to_max_chans: 16
db7_val:
_target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset'
hdf5_file: "${env:DATA_PATH}/ninapro/DB7/h5/val.h5"
pad_up_to_max_chans: 16
db7:
db7_test:
_target_: 'datasets.emg_pretrain_dataset.EMGPretrainDataset'
data_dir: "${env:DATA_PATH}/ninapro/DB7/h5/"
hdf5_file: "${env:DATA_PATH}/ninapro/DB7/h5/test.h5"
pad_up_to_max_chans: 16
27 changes: 13 additions & 14 deletions config/experiment/TinyMyo_finetune.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ tag: EMG_finetune

gpus: -1
num_nodes: 1
num_workers: 8
num_workers: 4
batch_size: 32
max_epochs: 50

Expand All @@ -32,7 +32,6 @@ finetune_pretrained: True
resume: False

layerwise_lr_decay: 0.90
scheduler_type: cosine

pretrained_checkpoint_path: null
pretrained_safetensors_path: null
Expand All @@ -41,7 +40,7 @@ finetuning:
freeze_layers: False

io:
base_output_path: ${env:DATA_PATH}
base_output_path: ${env:LOG_DIR}
checkpoint_dirpath: ${env:CHECKPOINT_DIR}/checkpoints
version: 0

Expand All @@ -52,23 +51,16 @@ defaults:
- override /task: finetune_task_TinyMyo
- override /criterion: finetune_criterion

masking:
patch_size: [1, 20]
masking_ratio: 0.50
unmasked_loss_coeff: 0.1

input_normalization:
normalize: False

model:
n_layer: 8
num_classes: 6
classification_type: "ml"
task: "classification"

trainer:
accelerator: gpu
num_nodes: ${num_nodes}
devices: ${gpus}
strategy: auto
strategy: ddp_find_unused_parameters_true
max_epochs: ${max_epochs}

model_checkpoint:
Expand All @@ -89,7 +81,7 @@ optimizer:
optim: 'AdamW'
lr: 5e-4
betas: [0.9, 0.98]
weight_decay: 0.01
weight_decay: 1e-2

scheduler:
trainer: ${trainer}
Expand All @@ -98,3 +90,10 @@ scheduler:
warmup_epochs: 5
total_training_opt_steps: ${max_epochs}
t_in_epochs: True

wandb:
entity: "TinyMyo"
project: "TinyMyo"
save_dir: ${env:LOG_DIR}
run_name: "TinyMyo-Finetuning"
offline: True
45 changes: 30 additions & 15 deletions config/experiment/TinyMyo_pretrain.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ tag: EMG_pretrain
gpus: -1
num_nodes: 1
num_workers: 8
batch_size: 128
max_epochs: 50
batch_size: 512
max_epochs: 30

final_validate: True
final_test: False

pretrained_checkpoint_path: null
io:
base_output_path: ${env:DATA_PATH}
base_output_path: ${env:LOG_DIR}
checkpoint_dirpath: ${env:CHECKPOINT_DIR}/checkpoints
version: 0

Expand All @@ -49,22 +49,23 @@ masking:
input_normalization:
normalize: True

scheduler:
trainer: ${trainer}
min_lr: 1e-6
warmup_lr_init: 1e-6
warmup_epochs: 10
total_training_opt_steps: ${max_epochs}
t_in_epochs: True
model:
n_layer: 8
drop_path: 0.0 # Stochastic depth disabled for pretraining
num_classes: 0 # No classification head for pretraining
task: pretraining

criterion:
loss_type: 'smooth_l1'

trainer:
accelerator: gpu
num_nodes: ${num_nodes}
devices: ${gpus}
strategy: auto
precision: "bf16-mixed"
max_epochs: ${max_epochs}
gradient_clip_val: 3
accumulate_grad_batches: 8
gradient_clip_val: 1

model_checkpoint:
save_last: True
Expand All @@ -73,7 +74,21 @@ model_checkpoint:
save_top_k: 1

optimizer:
optim: 'AdamW'
lr: 1e-4
lr: 5e-4
betas: [0.9, 0.98]
weight_decay: 0.01
weight_decay: 1e-2

scheduler:
trainer: ${trainer}
min_lr: 1e-6
warmup_lr_init: 1e-6
warmup_epochs: 3
total_training_opt_steps: ${max_epochs}
t_in_epochs: True

wandb:
entity: "TinyMyo"
project: "TinyMyo"
save_dir: ${env:LOG_DIR}
run_name: "TinyMyo-Pretraining"
offline: True
1 change: 1 addition & 0 deletions datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Local dataset implementations for BioFoundation."""
122 changes: 27 additions & 95 deletions datasets/emg_finetune_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,130 +16,62 @@
# * *
# * Author: Matteo Fasulo *
# *----------------------------------------------------------------------------*
from collections import deque
from typing import Tuple, Union

import h5py
import numpy as np
import torch


class EMGDataset(torch.utils.data.Dataset):
"""
A PyTorch Dataset class for loading EMG (Electromyography) data from HDF5 files.
This dataset supports lazy loading of data from HDF5 files, with optional caching
to improve performance during training. It can be used for both fine-tuning (with labels)
and inference (without labels) modes. The class handles data preprocessing, such as
converting to tensors and optional unsqueezing.
"""PyTorch Dataset for loading EMG data from HDF5 files.

Attributes:
hdf5_file (str): Path to the HDF5 file containing the dataset.
unsqueeze (bool): Whether to add an extra dimension to the input data (default: False).
finetune (bool): If True, loads both data and labels; if False, loads only data (default: True).
cache_size (int): Maximum number of samples to cache in memory (default: 1500).
use_cache (bool): Whether to use caching for faster access (default: True).
regression (bool): If True, treats labels as regression targets (float); else, classification (long) (default: False).
num_samples (int): Total number of samples in the dataset, determined from HDF5 file.
data (h5py.File or None): Handle to the opened HDF5 file (lazy-loaded).
X_ds (h5py.Dataset or None): Dataset handle for input data.
Y_ds (h5py.Dataset or None): Dataset handle for labels (if finetune is True).
cache (dict): Dictionary for caching data items (if use_cache is True).
cache_queue (deque): Queue to track the order of cached items for LRU eviction.
Note:
- The HDF5 file is expected to have 'data' and 'label' datasets.
- Caching uses an LRU (Least Recently Used) eviction policy.
- Suitable for use with PyTorch DataLoader for batched loading.
hdf5_file (str): Path to the HDF5 source file.
finetune (bool): If True, returns (data, label). If False, returns data only.
regression (bool): If True, labels are treated as floats. Else, longs.
"""

def __init__(
self,
hdf5_file: str,
unsqueeze: bool = False,
finetune: bool = True,
cache_size: int = 1500,
use_cache: bool = True,
regression: bool = False,
verbose: bool = False,
):
self.hdf5_file = hdf5_file
self.unsqueeze = unsqueeze
self.cache_size = cache_size
self.finetune = finetune
self.use_cache = use_cache
self.regression = regression

self.data = None
self.X_ds = None
self.Y_ds = None

# Open once to get length, then close immediately
with h5py.File(self.hdf5_file, "r") as f:
self.num_samples = f["data"].shape[0]

if self.use_cache:
self.cache: dict[int, Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = {}
self.cache_queue = deque()

def _open_file(self) -> None:
# 'rdcc_nbytes' to increase the raw data chunk cache size
self.data = h5py.File(self.hdf5_file, "r", rdcc_nbytes=1024 * 1024 * 4)
if self.data is not None:
self.X_ds = self.data["data"]
self.Y_ds = self.data["label"]

def __len__(self) -> int:
return self.num_samples

def __getitem__(self, index):
# Check Cache
if self.use_cache and index in self.cache:
return self._process_data(self.cache[index])

# Open file (Lazy Loading for Multiprocessing)
if self.data is None:
self._open_file()

# Read Data, HDF5 slicing returns numpy array
X_np = self.X_ds[index]
X = torch.from_numpy(X_np).float()
X_np = f["data"][:]
Y_np = f["label"][:] if self.finetune else None

self.X_tensor = torch.from_numpy(X_np).float().contiguous()
if self.finetune:
Y_np = self.Y_ds[index]
if self.regression:
Y = torch.from_numpy(Y_np).float()
self.Y_tensor = torch.from_numpy(Y_np).float().contiguous()
else:
# Ensure scalar is converted properly
Y = torch.tensor(Y_np, dtype=torch.long)

data_item = (X, Y)
else:
data_item = X
self.Y_tensor = torch.from_numpy(Y_np).long().contiguous()
if verbose:
uniq, cnt = np.unique(Y_np, return_counts=True)
print(
f"[EMGDataset] {self.hdf5_file}: label min={Y_np.min()}, max={Y_np.max()}, classes={len(uniq)}"
)
print(f"[EMGDataset] {self.hdf5_file}: class hist={dict(zip(uniq.tolist(), cnt.tolist()))}")

# Update Cache
if self.use_cache:
# If cache is full, remove oldest item from dict AND queue
if len(self.cache) >= self.cache_size:
oldest_index = self.cache_queue.popleft()
del self.cache[oldest_index]
self.num_samples = self.X_tensor.shape[0] # [N, C, T]

self.cache[index] = data_item
self.cache_queue.append(index)

return self._process_data(data_item)
def __len__(self) -> int:
"""Returns the total number of samples in the dataset."""
return self.num_samples

def _process_data(self, data_item):
"""Helper to handle squeezing/returning uniformly."""
if self.finetune:
X, Y = data_item
else:
X = data_item
Y = None
def __getitem__(self, index: int) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""Retrieves the EMG data and optional label at the specified index."""

if self.unsqueeze:
X = X.unsqueeze(0)
X = self.X_tensor[index]

if self.finetune:
Y = self.Y_tensor[index]
return X, Y
else:
return X

def __del__(self):
if self.data is not None:
self.data.close()
return X
Loading