PyTorch_xd(名称待定) 是一个高等级的PyTorch训练辅助库。
...
- 原生PyTorch
- 省事。
- 其他PyTorch训练辅助库
- 代码短,人类可以轻松看懂。
这个项目使用 Python 开发,请确保你本地安装了它。
建议使用pip安装本库。(我还没写setup,这里用不了)
$ pip install .
使用时,你需要定义一个训练器类,其必须继承自TrainerBase同时完成所有抽象方法的编写。
from pytorch_xd import TrainerBase
class TrainerDemo(TrainerBase):
def __init__(self, opt):
super().__init__(opt.device)
self.loss_fn = nn.CrossEntropyLoss()
self.opt = opt
def configure_optimizers(self, model):
optimizer = optim.Adam(model.parameters(), lr=self.opt.lr, betas=(0.9, 0.999))
# scheduler = StepLR(optimizer, step_size=1) # 还没实现呢
return optimizer
def train_dataloader(self):
dataset = TrainDataset(self.opt.data_path, self.opt.crop_size)
loader = DataLoader(dataset, batch_size=self.opt.batchsize, shuffle=True)
return loader
def val_dataloader(self):
dataset = TrainDataset(self.opt.data_path, self.opt.crop_size)
loader = DataLoader(dataset, batch_size=self.opt.batchsize, shuffle=True)
return loader
def test_dataloader(self):
dataset = TrainDataset(self.opt.data_path, self.opt.crop_size)
loader = DataLoader(dataset, batch_size=self.opt.batchsize, shuffle=True)
return loader
def training_step(self, batch, batch_idx):
inp, label = batch
inp = inp.to(self.device)
label = label.to(self.device)
output_label = self.model(inp)
loss = self.loss_fn(output_label, label)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
return {"loss": loss.detach()}
def validation_step(self, batch, batch_idx):
inp, label = batch
inp = inp.to(self.device)
label = label.to(self.device)
output_label = self.model(inp)
loss = self.loss_fn(output_label, label).detach()
accuracy = torch.equal(output_label.argmax(1), label) * 1.0
return {"loss": loss, "accuracy": accuracy}
def test_step(self, batch, batch_idx):
pass
其实,你也可以直接导入内置的训练器类
from pytorch_xd import TrainerDemo
下面是各种训练器类的介绍
- TrainerBase 这是训练器的基础,所有其他训练器都需要继承于它,它本身是个抽象类无法直接使用。
- LossTrainer 这是一个抽象训练器类,它与损失函数绑定,可以训练使用特定损失函数的一类模型。
- DemoTrainer 这是一个继承自 LossTrainer 的训练器类,它会截取特定训练集的前十个样本训练和验证,仅用作示例,没有实际用途。
- DatasetTrainer 这是一个继承自 LossTrainer 的训练器类,它与训练集绑定,它在实例化时需要指定一个具备自分割能力的数据集类,配合本库内置的各种数据集类可以实现交叉验证等很多功能。
- 。。。
我还没写
还没写。
- 内置tensorboard。
- (2020.05.15) v0.1.0 更新
- 第一版。
非常欢迎你的加入!提一个 Issue (假的)或者提交一个 Pull Request。
感谢以下参与项目的人:
禁止将本辅助库及使用本辅助库制作的文档上传到百度网盘。 GNU © Rongrui Zhan