/tinyKAN

Primary Language: Python

# Define the KAN model
class KANNet(nn.Module):
    """Network made of two chained ``SplineActivation`` layers.

    The first layer is constructed with ``(input_size, hidden_size)`` and the
    second with ``(hidden_size, output_size)``; both receive the same
    ``knot_count``, ``order`` and ``extra_knots`` values.

    NOTE(review): ``SplineActivation`` is defined outside this block; the
    sizes are presumably input/output feature widths and the remaining
    arguments spline hyper-parameters -- confirm against its definition.
    """

    def __init__(self, input_size, hidden_size, output_size, knot_count=5, order=3, extra_knots=2):
        """Build the two spline layers with a shared spline configuration."""
        super().__init__()
        # Forwarded unchanged, in this positional order, to both layers.
        spline_config = (knot_count, order, extra_knots)
        self.input_layer = SplineActivation(input_size, hidden_size, *spline_config)
        # Despite its name, this layer produces the final output.
        self.hidden_layer = SplineActivation(hidden_size, output_size, *spline_config)

    def forward(self, x):
        """Return ``hidden_layer(input_layer(x))``."""
        return self.hidden_layer(self.input_layer(x))