
Augmenting Continual Learning of Diseases with LLM-Generated Visual Concepts

Official repository for the BIBM'25 paper.

Jiantao Tan1   Peixian Ma2   Kanghao Chen2   Zhiming Dai1   Ruixuan Wang1,3,4  

1Sun Yat-sen University 2The Hong Kong University of Science and Technology (Guangzhou) 3Peng Cheng Laboratory 4Key Laboratory of Machine Intelligence and Advanced Computing


📖 Abstract

Continual learning is essential for medical image classification systems to adapt to dynamically evolving clinical environments. Integrating multimodal information can significantly enhance continual learning of image classes. However, existing approaches that utilize textual information rely solely on simplistic templates containing only a class name, neglecting richer semantic information. To address this limitation, we propose a novel framework that harnesses visual concepts generated by large language models (LLMs) as discriminative semantic guidance. Our method dynamically constructs a visual concept pool with a similarity-based filtering mechanism to prevent redundancy. To integrate the concepts into the continual learning process, we employ a cross-modal image-concept attention module coupled with an attention loss. Through attention, the module leverages semantic knowledge from relevant visual concepts and produces class-representative fused features for classification. Experiments on medical and natural image datasets show that our method achieves state-of-the-art performance, demonstrating its effectiveness and superiority. We will release the code publicly.
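
The cross-modal image-concept attention module is illustrated below with a minimal PyTorch sketch: an image feature attends over a pool of concept text embeddings and returns a fused feature for classification. The projections, dimensions, and handling of the attention loss are assumptions for illustration, not the paper's exact module.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ImageConceptAttention(nn.Module):
    """Illustrative sketch: image features query LLM-generated concept embeddings."""

    def __init__(self, dim: int):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)  # image feature -> query
        self.k_proj = nn.Linear(dim, dim)  # concept embeddings -> keys
        self.v_proj = nn.Linear(dim, dim)  # concept embeddings -> values

    def forward(self, img_feat, concept_feats):
        # img_feat: (B, D) image features; concept_feats: (N, D) concept text embeddings
        q = self.q_proj(img_feat)
        k = self.k_proj(concept_feats)
        v = self.v_proj(concept_feats)
        attn = F.softmax(q @ k.t() / q.shape[-1] ** 0.5, dim=-1)  # (B, N) image-concept attention
        fused = attn @ v  # (B, D) class-representative fused feature
        return fused, attn

if __name__ == "__main__":
    module = ImageConceptAttention(dim=512)
    fused, attn = module(torch.randn(4, 512), torch.randn(30, 512))
    print(fused.shape, attn.shape)  # torch.Size([4, 512]) torch.Size([4, 30])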

📝 Citation

If you find this work useful, please cite:

@article{tan2025augmenting,
  title={Augmenting Continual Learning of Diseases with LLM-Generated Visual Concepts},
  author={Tan, Jiantao and Ma, Peixian and Chen, Kanghao and Dai, Zhiming and Wang, Ruixuan},
  journal={arXiv preprint arXiv:2508.03094},
  year={2025}
}

🚀 Installation

Requirements

  • Python 3.8+ (required by PyTorch 2.x)
  • PyTorch 2.3.1
  • CUDA (for GPU acceleration)

Setup

  1. Clone the repository:

git clone https://github.com/MPX0222/VisualConcepts4CL.git
cd VisualConcepts4CL

  2. Install dependencies:

pip install -r requirements.txt

For concept generation (optional):

pip install openai scikit-learn nltk

  3. Download pretrained CLIP models (a quick loading check is shown after this list):

    • Create a pretrained_model/ directory in the project root
    • Download CLIP models and place them in pretrained_model/
    • Default path: pretrained_model/CLIP_ViT-B-16.pt
    • You can download CLIP models from OpenAI CLIP or OpenCLIP
    • Note: Pretrained models are not included in the repository due to their size

  4. Download and prepare datasets:

    • Follow the instructions in the Dataset Preparation section
    • Ensure datasets are placed in the expected locations
    • Note: Raw dataset files are not included in the repository (see .gitignore)
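
As a quick sanity check for step 3, the snippet below tries to load the checkpoint from the default path with the OpenAI clip package (assumed installed, e.g. via pip install git+https://github.com/openai/CLIP.git); adjust the path if you store the model elsewhere.

import torch
import clip

device = "cuda" if torch.cuda.is_available() else "cpu"
# clip.load accepts either a model name or a path to a downloaded checkpoint
model, preprocess = clip.load("pretrained_model/CLIP_ViT-B-16.pt", device=device)
print(model.visual.input_resolution)  # 224 for ViT-B/16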

📁 Dataset Preparation

Dataset Structure

Raw Dataset Files

Place your downloaded datasets according to the expected paths in each dataset's implementation file. For example:

  • Skin40: $HOME/Data/skin40/ with train_1.txt, val_1.txt, and images/ directory
  • CIFAR100: Automatically handled by torchvision
  • Other datasets: Check the respective dataset class in datasets/ for expected paths
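
For example, a minimal check that the Skin40 files are where the paths above expect them (adjust the root if your dataset class uses a different location):

import os

root = os.path.expanduser("~/Data/skin40")
for entry in ("train_1.txt", "val_1.txt", "images"):
    path = os.path.join(root, entry)
    print(path, "found" if os.path.exists(path) else "MISSING")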

Class Descriptions (Included in Repository)

Each dataset should have its class descriptions stored in datasets/class_descs/{DATASET_NAME}/:

  • description_pool.json: Dictionary mapping class names to lists of LLM-generated descriptions
  • unique_descriptions.txt: List of unique descriptions used for training

Example structure:

datasets/
  class_descs/
    CIFAR100/
      description_pool.json
      unique_descriptions.txt
    Skin40/
      description_pool.json
      unique_descriptions.txt

Note: The class description files are already included in this repository and do not need to be downloaded separately.
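
For reference, the two files can be inspected as follows; the CIFAR100 paths match the example structure above, and the training code loads these files internally.

import json

# Dictionary: class name -> list of LLM-generated visual concept descriptions
with open("datasets/class_descs/CIFAR100/description_pool.json") as f:
    description_pool = json.load(f)

# Plain-text list of unique descriptions, one per line
with open("datasets/class_descs/CIFAR100/unique_descriptions.txt") as f:
    unique_descriptions = [line.strip() for line in f if line.strip()]

print(len(description_pool), "classes,", len(unique_descriptions), "unique descriptions")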

Generating Class Descriptions (Optional)

If you want to generate your own class descriptions using LLMs, you can use the concept generation tool located in methods/concept_generation/. This is an independent module that:

  • Generates visual concept descriptions using LLM APIs (OpenAI GPT models)
  • Implements a similarity-based filtering mechanism to prevent redundant descriptions
  • Supports batch processing and different prompt types (medical, default, CT, etc.)
  • Converts description pools to index-based format for efficient storage

Note: The pre-generated descriptions in datasets/class_descs/ are ready to use. You only need to run this tool if you want to generate new descriptions or modify existing ones.
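
If you want to prototype a similar similarity-based filter yourself, a simple TF-IDF/cosine variant using the optional scikit-learn dependency might look like the sketch below; the actual mechanism in methods/concept_generation/ may use different embeddings and thresholds.

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

def filter_redundant(descriptions, threshold=0.8):
    """Keep a description only if it is not too similar to any already-kept one."""
    kept = []
    for desc in descriptions:
        if not kept:
            kept.append(desc)
            continue
        vec = TfidfVectorizer().fit(kept + [desc])
        sims = cosine_similarity(vec.transform([desc]), vec.transform(kept))
        if sims.max() < threshold:
            kept.append(desc)
    return kept

candidates = [
    "irregular brown patches with uneven borders",
    "brown patches with irregular, uneven borders",  # near-duplicate, likely filtered out
    "fine silvery scales on reddish plaques",
]
print(filter_redundant(candidates))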

🎯 Usage

Training

  1. Configure your experiment in a YAML file under config_yaml/CLIP_Concept/ and launch training:

python main.py --yaml_path config_yaml/CLIP_Concept/cifar100.yaml

  2. Or override parameters directly via the command line:

python main.py \
    --yaml_path config_yaml/CLIP_Concept/cifar100.yaml \
    --batch_size 32 \
    --epochs 10 \
    --lr 0.002

⚙️ Configuration

Configuration files are organized in YAML format under config_yaml/CLIP_Concept/. Key parameters include:

Basic Settings

  • method: Training method (e.g., "CLIP_Concept")
  • increment_type: Type of incremental learning ("CIL" for class-incremental)
  • increment_steps: List defining number of classes per task (e.g., [10, 10, 10, ...])

Model Settings

  • backbone: Backbone architecture ("CLIP", "OpenCLIP", or "MedCLIP")
  • pretrained_path: Path to pretrained model weights
  • alpha: Weight for combining direct and attention-based logits

Training Settings

  • batch_size: Batch size for training
  • epochs: Number of training epochs per task
  • lr: Learning rate
  • optimizer: Optimizer type (e.g., "AdamW")
  • scheduler: Learning rate scheduler (e.g., "Cosine")

Concept Settings

  • desc_path: Path to class descriptions file
  • desc_num: Number of descriptions per class
  • prompt_template: Template for text prompts (e.g., "a photo of a {}.")
  • lambd: Weight for attention loss
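
Conceptually, the prompt template and the concept descriptions are plain strings that are encoded by the CLIP text encoder, roughly as sketched below (class names here are hypothetical; this is not the training code itself).

import torch
import clip

device = "cpu"
model, _ = clip.load("pretrained_model/CLIP_ViT-B-16.pt", device=device)

prompt_template = "a photo of a {}."
class_names = ["melanoma", "psoriasis"]  # hypothetical class names
prompts = [prompt_template.format(name) for name in class_names]

with torch.no_grad():
    text_feats = model.encode_text(clip.tokenize(prompts).to(device))
    text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)  # L2-normalize, as CLIP does
print(text_feats.shape)  # (2, 512) for ViT-B/16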

Memory Settings (Optional)

  • memory_size: Total size of replay memory
  • memory_per_class: Number of exemplars per class
  • sampling_method: Method for selecting exemplars (e.g., "herding")
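
"Herding" usually refers to iCaRL-style exemplar selection, which greedily picks samples whose running mean best approximates the class-mean feature; a compact sketch is given below (the repository's implementation may differ in its details).

import torch

def herding_select(features: torch.Tensor, m: int):
    """Greedily pick m exemplar indices whose running mean stays closest to the class mean.
    features: (N, D) L2-normalized feature vectors of a single class."""
    class_mean = features.mean(dim=0)
    selected = []
    current_sum = torch.zeros_like(class_mean)
    for _ in range(min(m, features.shape[0])):
        # mean feature if each candidate were added next
        candidate_means = (current_sum.unsqueeze(0) + features) / (len(selected) + 1)
        dists = (candidate_means - class_mean).norm(dim=1)
        if selected:
            dists[selected] = float("inf")  # never pick the same sample twice
        idx = int(dists.argmin())
        selected.append(idx)
        current_sum += features[idx]
    return selected

feats = torch.nn.functional.normalize(torch.randn(100, 512), dim=1)
print(herding_select(feats, m=5))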

Stage 2 Training (Class-Aware Regularization)

  • ca_epoch: Number of epochs for stage 2 training
  • ca_lr: Learning rate for stage 2 training
  • num_sampled_pcls: Number of samples per class for stage 2
  • ca_logit_norm: Logit normalization factor (0 to disable)
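
ca_logit_norm is only described as a normalization factor; one common formulation of logit normalization divides each sample's logits by their L2 norm times the factor, as sketched below. This is an assumption for illustration, and the repository's exact formula may differ.

import torch

def normalize_logits(logits: torch.Tensor, logit_norm: float):
    # A value of 0 disables the normalization, matching the option above.
    if logit_norm == 0:
        return logits
    norms = logits.norm(p=2, dim=-1, keepdim=True) + 1e-7
    return logits / norms / logit_norm

print(normalize_logits(torch.randn(4, 100), logit_norm=0.1).shape)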

Example configuration file:

basic:
  random_seed: 1993
  version_name: "cifar100_b0i10"
  method: "CLIP_Concept"
  increment_type: "CIL"
  increment_steps: [10, 10, 10, 10, 10, 10, 10, 10, 10, 10]

usual:
  dataset_name: "cifar100"
  backbone: "CLIP"
  pretrained_path: "pretrained_model/CLIP_ViT-B-16.pt"
  batch_size: 32
  epochs: 10
  lr: 0.002
  optimizer: AdamW
  scheduler: Cosine

special:
  desc_path: "./datasets/class_descs/CIFAR100/unique_descriptions.txt"
  prompt_template: "a photo of a {}."
  desc_num: 3
  alpha: 0.5
  ca_epoch: 5
  ca_lr: 0.002
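
For reference, the configuration can be inspected with PyYAML before launching a run; main.py parses it internally, so this snippet only shows the structure.

import yaml

with open("config_yaml/CLIP_Concept/cifar100.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["basic"]["increment_steps"])  # e.g., [10, 10, ..., 10]
print(cfg["usual"]["backbone"], cfg["special"]["desc_num"])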
