这个项目实现了一个基于混合专家(Mixture of Experts, MoE)的多模态神经网络模型,用于图像分类任务,支持CIFAR10、Fashion-MNIST和Flickr8k数据集。
diff-MoE/
├── main.py # 主程序入口
├── model.py # 模型定义
├── train.py # 训练函数
├── test.py # 测试和评估函数
├── utils.py # 工具函数和可视化
├── data_loader.py # 数据加载器
├── datasets.py # 数据集定义和标签文本描述
├── data_utils.py # 数据集下载和预处理工具
├── config.py # 配置参数
├── configs/ # 配置文件目录
│ ├── model.yaml # 模型配置
│ ├── training.yaml # 训练配置
│ ├── cifar10.yaml # CIFAR10数据集配置
│ ├── fashion_mnist.yaml # Fashion-MNIST数据集配置
│ └── visualization.yaml # 可视化配置
├── requirements.txt # 依赖库
├── README.md # 项目说明
└── README_FLICKR8K.md # Flickr8k数据集使用说明
-
数据处理:
data_loader.py
: 提供数据加载函数和批处理datasets.py
: 定义自定义数据集类及标签到文本描述的转换data_utils.py
: 数据集下载和预处理工具
-
模型:
model.py
: 实现MultiModalMoE模型和各个组件config.py
: 模型和训练基础配置参数configs/
: 包含各数据集的具体配置YAML文件
-
训练与测试:
train.py
: 提供训练循环和模型优化test.py
: 模型评估和测试功能
-
工具与可视化:
utils.py
: 工具函数和可视化功能main.py
: 命令行接口和程序入口
-
缓存目录:
clip_cache/
: CLIP模型和tokenizer缓存bert_cache/
: BERT模型和tokenizer缓存(旧版本使用)
pip install -r requirements.txt
# 训练CIFAR10模型
python main.py --dataset cifar10 --mode train --batch_size 64 --epochs 10
# 训练Fashion-MNIST模型
python main.py --dataset fashion_mnist --mode train --batch_size 64 --epochs 10
# 训练Flickr8k模型
python main.py --dataset flickr8k --mode train --batch_size 32 --epochs 20
# 测试CIFAR10模型
python main.py --dataset cifar10 --mode test --checkpoint checkpoints/cifar10_model.pth
# 测试Fashion-MNIST模型
python main.py --dataset fashion_mnist --mode test --checkpoint checkpoints/fashion_mnist_model.pth
# 测试Flickr8k模型
python main.py --dataset flickr8k --mode test --checkpoint checkpoints/flickr8k_model.pth
# 训练并测试CIFAR10模型
python main.py --dataset cifar10 --mode train_test --batch_size 64 --epochs 10
--dataset
: 数据集名称,选项:cifar10, fashion_mnist, flickr8k--mode
: 运行模式,选项:train, test, train_test--batch_size
: 批次大小--epochs
: 训练轮次数--lr
: 学习率--weight_decay
: 权重衰减--checkpoint
: 模型检查点路径--save_dir
: 保存目录
Diff-MoE是一个基于Transformer架构的多模态混合专家模型,主要特点:
- 混合专家层: 包含共享和模态特定的专家网络,通过路由器分配不同的输入到不同的专家
- 多模态融合: 结合图像和文本特征进行预测
- Attention机制: 使用自注意力和交叉注意力进行特征交互
+------------------------------------------+
| MultiModalMoE |
+------------------------------------------+
| |
| +------------+ +---------------+ |
| | 图像 | | 文本 | |
| | 编码器 | | 编码器 | |
| +-----+------+ +-------+-------+ |
| | | |
| v v |
| +------------+ +---------------+ |
| | 图像 | | CLIP文本 | |
| | 嵌入层 | | Tokenizer | |
| +-----+------+ +-------+-------+ |
| | | |
| v v |
| +----------------------------------------+
| | 统一模态编码器 (UnifiedModalEncoder) |
| +----------------------------------------+
| | |
| | +-------------+ +---------------+ |
| | | 自注意力机制 | | 交叉注意力机制 | |
| | +------+------+ +-------+-------+ |
| | | | |
| | v v |
| | +------------------+----------------+ |
| | | 混合专家层 (MoE) | |
| | +----------------------------------+ |
| | | | |
| | | +------------+ +---------------+ | |
| | | | 共享专家 | | 模态特定专家 | | |
| | | +------------+ +---------------+ | |
| | | | |
| | | +------------+ | |
| | | | 注意力路由器| | |
| | | +------------+ | |
| | +----------------------------------+ |
| | | |
| +------------------+--------------------+
| |
| v
| +------------------+--------------------+
| | 分类头 |
| +------------------+--------------------+
| |
+---------------------|---------------------+
v
输出
-
专家模块 (Expert):
- 每个专家由两个全连接层组成,带有激活函数和Layer Normalization
- 支持的激活函数: GELU, ReLU, SiLU
- 包含Dropout机制以防止过拟合
-
注意力路由器 (AttentiveRouter):
- 基于注意力机制的路由算法
- 为每个输入token分配top-k个专家
- 使用正交初始化提高稳定性
- 支持噪声辅助路由以增强训练过程中的探索
-
统一模态编码器 (UnifiedModalEncoder):
- 多头自注意力机制
- 跨模态交叉注意力融合
- 混合专家前馈网络
- 支持梯度检查点以节省显存
-
图像处理:
- Patch嵌入层将图像分割为固定大小的块
- 位置编码提供空间信息
-
文本处理:
- 使用嵌入层将文本token转化为向量
- 支持自注意力处理序列信息
- 使用CLIP tokenizer进行文本分词,最大序列长度为77
参数 | 默认值 | 描述 |
---|---|---|
img_size | 28 | 输入图像大小 |
patch_size | 4 | 图像patch大小 |
in_channels | 1 | 输入图像通道数 |
embed_dim | 512 | 嵌入维度 |
num_shared_experts | 4 | 共享专家数量 |
num_modality_specific_experts | 2 | 模态特定专家数量 |
top_k | 2 | 每个输入选择的专家数量 |
num_heads | 8 | 注意力头数量 |
num_layers | 6 | Transformer层数 |
num_classes | 10 | 分类类别数量 |
dropout | 0.1 | Dropout比率 |
activation | 'gelu' | 激活函数类型 |
Diff-MoE模型的总损失函数由两部分组成:
-
分类损失:使用标准交叉熵损失(CrossEntropyLoss)计算预测与真实标签之间的差异
分类损失 = CrossEntropyLoss(logits, labels)
-
路由损失:基于KL散度计算的专家负载均衡损失,确保专家被均匀使用
负载目标分布 = 均匀分布(1/num_experts) 专家利用率 = 每个专家处理的token数 / 总容量 路由损失 = KL散度(专家利用率, 负载目标分布)
-
总损失:
总损失 = 分类损失 + 路由损失
- 默认优化器: AdamW
- 学习率: 0.0005
- 权重衰减: 0.01
- 梯度累积: 支持梯度累积减少显存占用
- 梯度累积:每个batch的梯度会累积几个步骤后再更新参数,有效增大batch size
- 梯度检查点:可选使用梯度检查点技术以节省显存
- 学习率调度:支持学习率调度器动态调整学习率
推理时,模型接收图像和可选的文本输入:
- 图像通过Patch嵌入层处理后得到图像特征
- 文本(如果有)通过嵌入层处理后得到文本特征
- 两种特征通过统一模态编码器处理,包括:
- 自注意力机制捕获同一模态内的关系
- 交叉注意力机制融合不同模态的信息
- 通过注意力路由器将特征分配给不同专家
- 聚合专家输出得到最终表示
- 最后通过分类头输出类别概率
模型支持以下可视化功能:
- 路由决策可视化:展示每个专家处理的token数量和分布
- 注意力图:展示模型在不同输入部分的关注程度
- 专家利用率:分析不同专家的使用情况和负载均衡
- 多模态交互:展示图像和文本特征如何相互影响
- CIFAR10: 10类通用对象图像分类
- Fashion-MNIST: 10类服装图像分类
- Flickr8k: 图像-文本对数据集,用于多模态学习
训练过程中会生成以下可视化:
- 训练曲线(损失和准确率)
- 预测结果示例
- 混淆矩阵
- 专家分配可视化
可视化结果保存在visualizations/
目录下。
本项目结合了多种先进技术:
- 混合专家模型:通过路由机制实现高效计算资源分配
- 多模态融合:同时处理图像和文本信息
- CLIP文本编码:使用OpenAI的CLIP-BPE tokenizer处理文本,相比BERT在多模态任务上表现更好
- 注意力机制:使用自注意力和交叉注意力进行深度特征提取
- 可视化分析:支持模型决策过程的可视化
本项目需要以下依赖:
torch>=1.10.0 # PyTorch深度学习框架
torchvision>=0.11.0 # PyTorch计算机视觉工具
transformers>=4.18.0 # Hugging Face Transformers库,用于CLIP和BERT
matplotlib>=3.5.0 # 数据可视化
numpy>=1.21.0 # 数值计算
tqdm>=4.62.0 # 进度条显示
scikit-learn>=1.0.0 # 机器学习工具
seaborn>=0.11.0 # 统计数据可视化
pyyaml>=6.0.0 # YAML配置文件处理
pillow>=9.0.0 # 图像处理
scipy>=1.7.0 # 科学计算
对于Flickr8k数据集,还需要额外安装:
requests # HTTP请求库,用于下载数据集
硬件要求:
- 建议使用至少8GB显存的GPU进行训练
- 对于Flickr8k数据集,建议使用至少16GB显存的GPU