本项目实现了经典的 ResNet-50 模型,并在 CIFAR-10 数据集上进行了训练和测试。代码支持 GPU 加速训练,并提供了模型保存和加载功能,方便后续使用。
- Python 3.8+
- PyTorch 1.10+
- torchvision
- PIL (Pillow)
- matplotlib (可选,用于可视化)
本项目使用 CIFAR-10 数据集,包含 10 个类别的 60,000 张 32x32 彩色图片。
- 训练集:50,000 张图片
- 测试集:10,000 张图片
数据集会自动下载并保存在 data/cifar10 目录下。
运行以下命令开始训练 ResNet-50 模型:
python Res50.py- 训练过程中会保存最佳模型到
models/best_model.pth。 - 训练日志会实时打印,包括每个 epoch 的损失和准确率。
运行以下命令测试模型在测试集上的性能:
python Res50.py- 测试结果会打印测试集的损失和准确率。
运行以下命令对单张图片进行分类:
python ModelTest.py --image_path test_image.jpg- 将
test_image.jpg替换为你要分类的图片路径。 - 脚本会输出预测的类别名称。
在 CIFAR-10 数据集上,ResNet-50 模型的性能如下:
- 训练集准确率:85%
- 测试集准确率:75%
- 定义了 ResNet-50 模型结构。
- 包含数据加载、模型训练和测试的逻辑。
- 支持 GPU 加速训练。
- 训练过程中会保存最佳模型。
- 加载训练好的模型权重。
- 对输入的图片进行预处理和分类。
- 输出预测的类别名称。
如果你想在其他数据集上训练模型,可以按照以下步骤操作:
- 将数据集放置在
data/目录下。 - 修改
Res50.py中的数据加载部分,适配你的数据集格式。 - 调整模型的输出类别数(
num_classes)。
欢迎提交 Issue 或 Pull Request 改进本项目!
本项目采用 MIT 许可证。
- Deep Residual Learning for Image Recognition - ResNet 论文
- CIFAR-10 Dataset - 数据集官网
- PyTorch Documentation - PyTorch 官方文档
如有问题或建议,请联系:[email protected]
感谢 PyTorch 团队和 CIFAR-10 数据集提供者!
希望这份 README.md 文档能帮助你更好地展示和分享你的 ResNet-50 项目!如果有其他需求,欢迎继续提问!