lyhue1991/torchkeras

AttributeError: partially initialized module 'torchkeras' has no attribute 'KerasModel' (most likely due to a circular import)

MaybeRichard opened this issue · 0 comments

```python
from model import LeNet
import torchkeras
import torchmetrics
from torchvision import datasets
import torch.nn as nn
import torch
import torchvision.transforms as transforms
import torchvision

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 4
net = LeNet()
model = torchkeras.KerasModel(net,
                              loss_fn=nn.BCEWithLogitsLoss(),
                              optimizer=torch.optim.Adam(net.parameters(), lr=1e-4),
                              metrics_dict={"acc": torchmetrics.Accuracy(task='binary')}
                              )

trainset = torchvision.datasets.CIFAR10(root='../../../../Dataset', train=True,
                                        download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='../../../../Dataset', train=False,
                                       download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

dfhistory = model.fit(train_data=trainloader,
                      val_data=testloader,
                      epochs=20,
                      patience=3,
                      ckpt_path='checkpoint.pt',
                      monitor="val_acc",
                      mode="max",
                      plot=True,
                      )
```
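
A "partially initialized module ... (most likely due to a circular import)" error often appears when a local file or folder has the same name as the installed package, so `import torchkeras` picks up the local one instead. Whether that is the cause here is only an assumption; nothing in this report confirms it. The sketch below is a minimal diagnostic using only the standard library. The helper names `find_shadowing_candidates` and `resolved_origin` are hypothetical and are not part of torchkeras.

```python
import importlib.util
from pathlib import Path


def find_shadowing_candidates(module_name, search_dir):
    """Return local files or folders in search_dir that could shadow module_name."""
    search_dir = Path(search_dir)
    candidates = []
    # A script such as ./torchkeras.py would be imported before the installed package.
    script = search_dir / f"{module_name}.py"
    if script.is_file():
        candidates.append(script)
    # A folder such as ./torchkeras/ can shadow the installed package as well.
    package = search_dir / module_name
    if package.is_dir():
        candidates.append(package)
    return candidates


def resolved_origin(module_name):
    """Return the file Python would load for module_name, or None if it is not found."""
    try:
        spec = importlib.util.find_spec(module_name)
    except (ImportError, ValueError):
        return None
    return None if spec is None else spec.origin


if __name__ == "__main__":
    # Run this from the directory that contains the failing script.
    print("torchkeras resolves to:", resolved_origin("torchkeras"))
    for path in find_shadowing_candidates("torchkeras", Path.cwd()):
        print("possible shadowing file or folder:", path)
```

If the reported path points into the project directory rather than `site-packages`, renaming the local file or folder (and deleting any stale `__pycache__` entry for it) would be the first thing to try.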