SAM2UNet 是一个结合了 SAM2 Transformer 编码器和 U-Net 解码结构的高效图像分割模型,目前仅适用于二值语义分割任务。该项目基于 PyTorch Lightning 框架,支持训练、评估、ONNX 导出及推理。
src/:核心代码(模型、数据、训练、评估、工具等)configs/:配置文件(训练、评估、模型、数据等)ckpts/:预训练权重存放目录data/:数据集目录logs/:训练日志与模型保存visual_test/:推理结果可视化输出
- Python >= 3.10
- 主要依赖包(见
pyproject.toml) - 第三方依赖 submodules/sam2
推荐使用 uv 工具安装依赖
如果你尚未安装 uv,可以先运行:
pip install uv然后使用如下命令同步依赖环境:
uv sync- 数据集目录结构示例(以光伏热点分割为例):
data/$data_name ├── train/ │ ├── images/ │ └── masks/ ├── val/ │ ├── images/ │ └── masks/ └── test/ ├── images/ └── masks/ - 数据集路径可在
configs/train.yaml和configs/eval.yaml中通过data_dir字段指定。
运行训练脚本:
python src/train.py- 配置文件默认使用
configs/train.yaml,可通过命令行参数覆盖。 - 训练参数(如 batch_size、学习率等)可在
configs/data/sam2.yaml、configs/model/sam2.yaml中调整。
运行评估脚本:
python src/eval.py- 需在
configs/eval.yaml中指定待评估的模型权重路径(ckpt_path)。 - 评估结果将输出到日志中。
导出训练好的模型为ONNX格式:
python src/utils/export_onnx.py --ckpt_path $ckpt_path --onnx_path $onnx_path- 可在脚本内修改权重路径、导出路径、输入尺寸等参数。
使用ONNX模型进行批量推理与可视化:
python src/utils/onnxruntime_infer.py --onnx_model_path sam2unet.onnx --image_path $data_path --save_path visual_test/- 支持单张或文件夹批量推理,自动保存可视化结果。
- SAM2UNet:主干为冻结的SAM2 Transformer编码器,解码器为U-Net结构,支持多尺度特征融合与辅助分割输出。
- 数据处理:支持自定义数据集,包含多种数据增强(Resize、Flip、Normalize等)。
- 训练与评估:基于PyTorch Lightning,支持断点续训、自动日志、回调等。
configs/train.yaml:训练主配置configs/eval.yaml:评估主配置configs/data/sam2.yaml:数据加载参数configs/model/sam2.yaml:模型结构与优化器参数
代码参考自 WZH0120/SAM2-UNet 并做了详细注释与工程化改造。
模版使用 ashleve/lightning-hydra-template 对代码进行重构,增加工程性和可读性
环境配置使用 uv, 强大的包管理工具
如需进一步补充(如实验结果、可视化示例、FAQ等),请告知!