Jiantao Tan1 Peixian Ma2 Kanghao Chen2 Zhiming Dai1 Ruixuan Wang1,3,4
1Sun Yat-sen University 2The Hong Kong University of Science and Technology (Guangzhou) 3Peng Cheng Laboratory 4Key Laboratory of Machine Intelligence and Advanced Computing
Continual learning is essential for medical image classification systems to adapt to dynamically evolving clinical environments. The integration of multimodal information can significantly enhance continual learning of image classes. However, while existing approaches do utilize textual modality information, they rely solely on simplistic templates containing only a class name, neglecting richer semantic information. To address this limitation, we propose a novel framework that harnesses visual concepts generated by large language models (LLMs) as discriminative semantic guidance. Our method dynamically constructs a visual concept pool with a similarity-based filtering mechanism to prevent redundancy. To integrate the concepts into the continual learning process, we employ a cross-modal image-concept attention module coupled with an attention loss. Through attention, the module leverages semantic knowledge from relevant visual concepts and produces class-representative fused features for classification. Experiments on medical and natural image datasets show that our method achieves state-of-the-art performance, demonstrating its effectiveness and superiority. We will release the code publicly.
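For intuition, the sketch below illustrates the image-concept attention idea in PyTorch. All shapes, module names, and the alpha-weighted logit combination are illustrative assumptions based on the description above and the configuration options (alpha, lambd); this is not the repository's actual implementation.

import torch
import torch.nn as nn

class ImageConceptAttention(nn.Module):
    """Illustrative sketch: an image feature attends over a pool of
    concept (text) embeddings, and the fused feature is classified."""
    def __init__(self, dim, num_classes, alpha=0.5):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads=1, batch_first=True)
        self.direct_head = nn.Linear(dim, num_classes)  # classifier on raw image features
        self.fused_head = nn.Linear(dim, num_classes)   # classifier on fused features
        self.alpha = alpha  # weight for combining direct and attention-based logits

    def forward(self, img_feat, concept_embs):
        # img_feat: (B, D), concept_embs: (K, D) from the visual concept pool
        q = img_feat.unsqueeze(1)                                        # (B, 1, D)
        kv = concept_embs.unsqueeze(0).expand(img_feat.size(0), -1, -1)  # (B, K, D)
        fused, attn_weights = self.attn(q, kv, kv)                       # cross-modal attention
        fused = fused.squeeze(1)                                         # (B, D)
        logits = self.alpha * self.direct_head(img_feat) + \
                 (1 - self.alpha) * self.fused_head(fused)
        return logits, attn_weights.squeeze(1)  # (B, num_classes), (B, K)

# Toy usage with random features (B=4 images, K=12 pooled concepts, D=512)
module = ImageConceptAttention(dim=512, num_classes=10)
logits, weights = module(torch.randn(4, 512), torch.randn(12, 512))
print(logits.shape, weights.shape)  # torch.Size([4, 10]) torch.Size([4, 12])

The returned attention weights over the concept pool are the kind of quantity an attention loss (weighted by lambd in the configs) could supervise.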
If you find this work useful, please cite:
@article{tan2024augmenting,
title={Augmenting Continual Learning of Diseases with LLM-Generated Visual Concepts},
author={Tan, Jiantao and Ma, Peixian and Chen, Kanghao and Dai, Zhiming and Wang, Ruixuan},
journal={arXiv preprint arXiv:2508.03094},
year={2024}
}

- Python 3.7+
- PyTorch 2.3.1
- CUDA (for GPU acceleration)
- Clone the repository:
git clone https://github.com/MPX0222/VisualConcepts4CL.git
cd VisualConcepts4CL

- Install dependencies:
pip install -r requirements.txt

For concept generation (optional):
pip install openai scikit-learn nltk

- Download pretrained CLIP models:
  - Create a pretrained_model/ directory in the project root
  - Download CLIP models and place them in pretrained_model/
  - Default path: pretrained_model/CLIP_ViT-B-16.pt
  - You can download CLIP models from OpenAI CLIP or OpenCLIP (a quick load check is sketched after this list)
  - Note: Pretrained models are not included in the repository due to their size
- Download and prepare datasets:
  - Follow the instructions in the Dataset Preparation section
  - Ensure datasets are placed in the expected locations
  - Note: Raw dataset files are not included in the repository (see .gitignore)
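Optionally, you can sanity-check the downloaded OpenAI CLIP weights with a quick load. This is a minimal sketch assuming the OpenAI CLIP package; OpenCLIP checkpoints are loaded differently.

import torch
import clip  # pip install git+https://github.com/openai/CLIP.git

device = "cuda" if torch.cuda.is_available() else "cpu"
# clip.load accepts either a model name or a path to a downloaded checkpoint
model, preprocess = clip.load("pretrained_model/CLIP_ViT-B-16.pt", device=device)
print(model.visual.input_resolution)  # 224 for ViT-B/16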
Place your downloaded datasets according to the expected paths in each dataset's implementation file. For example:
- Skin40: $HOME/Data/skin40/ with train_1.txt, val_1.txt, and an images/ directory
- CIFAR100: Automatically handled by torchvision
- Other datasets: Check the respective dataset class in datasets/ for expected paths
Each dataset should have its class descriptions stored in datasets/class_descs/{DATASET_NAME}/:
- description_pool.json: Dictionary mapping class names to lists of LLM-generated descriptions
- unique_descriptions.txt: List of unique descriptions used for training
Example structure:
datasets/
  class_descs/
    CIFAR100/
      description_pool.json
      unique_descriptions.txt
    Skin40/
      description_pool.json
      unique_descriptions.txt
Note: The class description files are already included in this repository and do not need to be downloaded separately.
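For reference, a minimal sketch of reading these files (assuming the schema described above; the CIFAR100 paths are used only as an example):

import json

# description_pool.json: class name -> list of LLM-generated visual concept descriptions
with open("datasets/class_descs/CIFAR100/description_pool.json") as f:
    description_pool = json.load(f)

# unique_descriptions.txt: one unique description per line, as used during training
with open("datasets/class_descs/CIFAR100/unique_descriptions.txt") as f:
    unique_descriptions = [line.strip() for line in f if line.strip()]

print(f"{len(description_pool)} classes, {len(unique_descriptions)} unique descriptions")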
If you want to generate your own class descriptions using LLMs, you can use the concept generation tool located in methods/concept_generation/. This is an independent module that:
- Generates visual concept descriptions using LLM APIs (OpenAI GPT models)
- Implements a similarity-based filtering mechanism to prevent redundant descriptions (see the sketch below)
- Supports batch processing and different prompt types (medical, default, CT, etc.)
- Converts description pools to index-based format for efficient storage
Note: The pre-generated descriptions in datasets/class_descs/ are ready to use. You only need to run this tool if you want to generate new descriptions or modify existing ones.
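As a rough illustration of similarity-based filtering (not the repository's implementation, which may use different text representations, thresholds, or LLM-side logic), a greedy TF-IDF/cosine variant looks like this:

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

def filter_redundant(descriptions, threshold=0.8):
    """Greedily keep a description only if its cosine similarity to every
    already-kept description stays below the threshold."""
    vectors = TfidfVectorizer().fit_transform(descriptions)
    kept_idx = []
    for i in range(len(descriptions)):
        if not kept_idx or cosine_similarity(vectors[i], vectors[kept_idx]).max() < threshold:
            kept_idx.append(i)
    return [descriptions[i] for i in kept_idx]

candidates = [
    "irregular brown patch with uneven, notched borders",
    "brown patch with irregular, uneven borders",   # near-duplicate, likely filtered out
    "raised nodule with a smooth, pearly surface",
]
print(filter_redundant(candidates, threshold=0.6))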
- Configure your experiment in a YAML file under config_yaml/CLIP_Concept/ and run:

python main.py --yaml_path config_yaml/CLIP_Concept/cifar100.yaml

- Or specify parameters directly via the command line:
python main.py \
--yaml_path config_yaml/CLIP_Concept/cifar100.yaml \
--batch_size 32 \
--epochs 10 \
--lr 0.002

Configuration files are organized in YAML format under config_yaml/CLIP_Concept/. Key parameters include:
- method: Training method (e.g., "CLIP_Concept")
- increment_type: Type of incremental learning ("CIL" for class-incremental)
- increment_steps: List defining the number of classes per task (e.g., [10, 10, 10, ...])

- backbone: Backbone architecture ("CLIP", "OpenCLIP", or "MedCLIP")
- pretrained_path: Path to pretrained model weights
- alpha: Weight for combining direct and attention-based logits

- batch_size: Batch size for training
- epochs: Number of training epochs per task
- lr: Learning rate
- optimizer: Optimizer type (e.g., "AdamW")
- scheduler: Learning rate scheduler (e.g., "Cosine")

- desc_path: Path to class descriptions file
- desc_num: Number of descriptions per class
- prompt_template: Template for text prompts (e.g., "a photo of a {}.")
- lambd: Weight for attention loss

- memory_size: Total size of replay memory
- memory_per_class: Number of exemplars per class
- sampling_method: Method for selecting exemplars (e.g., "herding")

- ca_epoch: Number of epochs for stage 2 training
- ca_lr: Learning rate for stage 2 training
- num_sampled_pcls: Number of samples per class for stage 2
- ca_logit_norm: Logit normalization factor (0 to disable)
Example configuration file:
basic:
  random_seed: 1993
  version_name: "cifar100_b0i10"
  method: "CLIP_Concept"
  increment_type: "CIL"
  increment_steps: [10, 10, 10, 10, 10, 10, 10, 10, 10, 10]

usual:
  dataset_name: "cifar100"
  backbone: "CLIP"
  pretrained_path: "pretrained_model/CLIP_ViT-B-16.pt"
  batch_size: 32
  epochs: 10
  lr: 0.002
  optimizer: AdamW
  scheduler: Cosine

special:
  desc_path: "./datasets/class_descs/CIFAR100/unique_descriptions.txt"
  prompt_template: "a photo of a {}."
  desc_num: 3
  alpha: 0.5
  ca_epoch: 5
  ca_lr: 0.002