PyKAN 0.2.6: can't load ?? with my code, but the tutorial works!
THEFLASHFORD opened this issue · 1 comment
THEFLASHFORD commented
class KAN_Regressor(Model):
    """Regressor wrapper around a pykan ``KAN`` network.

    Each call to ``fit`` trains a deep copy of a lazily built template KAN on a
    random train/test split of the data and appends the trained copy to
    ``model_list``; ``predict`` uses the most recently appended model.
    """

    def __init__(self, grid=3, k=3, steps=10, **kwargs) -> None:
        """Store the KAN hyper-parameters.

        Args:
            grid: forwarded to ``KAN(grid=...)``.
            k: forwarded to ``KAN(k=...)``.
            steps: number of optimiser steps forwarded to ``KAN.fit``.
            **kwargs: forwarded to ``Model.__init__``.
        """
        # The issue text shows ``def init`` / ``super().init``: Markdown ate the
        # double underscores. Without the real dunder the constructor never runs
        # and none of the attributes below would exist.
        super().__init__(eliminate_duplicates=False, eliminate_duplicates_eps=1e-8, **kwargs)
        self.dataset = {}  # train/test tensors of the most recent fit() call
        self.model = None  # untrained template network, built on first fit()
        self.model_list = []  # fitted models, oldest first
        self.grid = grid
        self.k = k
        self.steps = steps

    @staticmethod
    def _to_tensor(array, model):
        """Convert array-like data to a tensor matching the model's dtype/device."""
        # Matching the parameters avoids float32-vs-float64 and CPU-vs-GPU
        # mismatches between the data and the network.
        param = next(model.parameters(), None)
        if param is None:
            return torch.as_tensor(array)
        return torch.as_tensor(array, dtype=param.dtype, device=param.device)

    @classmethod
    def _to_label(cls, y, model):
        """Convert targets to a tensor, turning a 1-D target into a column."""
        labels = cls._to_tensor(y, model)
        # Only add the trailing axis for 1-D targets; an (n, 1) target is kept.
        return labels[:, None] if labels.ndim == 1 else labels

    def fit(self, X, y):
        """Train a copy of the template KAN on a random split of ``(X, y)``.

        Args:
            X: 2-D array-like of inputs; ``X.shape[1]`` sets the input width.
            y: array-like of targets; a 1-D target is reshaped to a column.

        If training raises, the most recently fitted model is reused (with a
        ``RuntimeWarning``). If no model has been fitted yet, the exception
        propagates.
        """
        import warnings

        if self.model is None:
            # NOTE(review): output width is 2 while the labels are a single
            # column; for a scalar regression target 1 looks intended - confirm.
            # NOTE(review): pykan 0.2.x KAN may auto-save checkpoints to disk
            # (``auto_save`` / ``ckpt_path`` options); if the "can't load" error
            # comes from those files, consider ``auto_save=False`` - verify
            # against the installed pykan version.
            self.model = KAN(width=[X.shape[1], 2, 2], grid=self.grid, k=self.k, seed=0, device=device)
        # Train a copy so the template stays untrained for the next fit() call.
        model = copy.deepcopy(self.model)
        # NOTE(review): test_size=0.8 trains on only 20% of the data - confirm
        # this is intended.
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.8)
        self.dataset['train_input'] = self._to_tensor(X_train, model)
        self.dataset['test_input'] = self._to_tensor(X_test, model)
        self.dataset['train_label'] = self._to_label(y_train, model)
        self.dataset['test_label'] = self._to_label(y_test, model)
        try:
            model.fit(self.dataset, opt="LBFGS", steps=self.steps)
        except Exception as err:
            if not self.model_list:
                # No previous model to fall back on: surface the real error
                # instead of an unrelated IndexError.
                raise
            warnings.warn(
                f"KAN training failed ({err!r}); reusing the previous model",
                RuntimeWarning,
                stacklevel=2,
            )
            model = self.model_list[-1]
        self.model_list.append(model)

    def predict(self, X):
        """Return predictions of the most recently fitted model as a NumPy array.

        Raises:
            IndexError: if ``fit`` has not produced a model yet.
        """
        model = self.model_list[-1]
        with torch.no_grad():
            out = model(self._to_tensor(X, model))
        # .cpu() is required before .numpy() when the model runs on a GPU.
        return out.detach().cpu().numpy()
THEFLASHFORD commented
Oh, I forgot to include a line of code, lol:
# Make float64 the default dtype for newly created floating-point tensors,
# including KAN parameters built afterwards. This makes the model match the
# float64 tensors that torch.from_numpy yields for float64 NumPy input.
# Must run before KAN_Regressor.fit() first builds the network.
torch.set_default_dtype(torch.float64)