Unknown activation function: silu Error
Closed this issue · 1 comments
AnouarITI commented
I am trying to run the simple MNIST classification demo. First, I discovered that this code only works with NumPy versions < 1.24.
Now, I am getting this error when I try to build the first KAN model:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/tmp/ipykernel_1832/4142239080.py in <module>
8 tf.keras.layers.Softmax()
9 ])
---> 10 kan.build(input_shape=(None, 28, 28, 1))
11 kan.summary()
/usr/local/lib/python3.8/dist-packages/tensorflow/python/keras/engine/sequential.py in build(self, input_shape)
354 input_shape = tuple(input_shape)
355 self._build_input_shape = input_shape
--> 356 super(Sequential, self).build(input_shape)
357 self.built = True
358
/usr/local/lib/python3.8/dist-packages/tensorflow/python/keras/engine/training.py in build(self, input_shape)
441 'method accepts an `inputs` argument.')
442 try:
--> 443 self.call(x, **kwargs)
444 except (errors.InvalidArgumentError, TypeError):
445 raise ValueError('You cannot build your model by calling `build` '
/usr/local/lib/python3.8/dist-packages/tensorflow/python/keras/engine/sequential.py in call(self, inputs, training, mask)
392 kwargs['training'] = training
393
--> 394 outputs = layer(inputs, **kwargs)
395
396 if len(nest.flatten(outputs)) != 1:
/usr/local/lib/python3.8/dist-packages/tensorflow/python/keras/engine/base_layer.py in __call__(self, *args, **kwargs)
1028 with autocast_variable.enable_auto_cast_variables(
1029 self._compute_dtype_object):
-> 1030 outputs = call_fn(inputs, *args, **kwargs)
1031
1032 if self._activity_regularizer:
/usr/local/lib/python3.8/dist-packages/tensorflow/python/keras/engine/functional.py in call(self, *args, **kwargs)
1445 if 'mask' in kwargs and not self._expects_mask_arg:
1446 kwargs.pop('mask')
-> 1447 return getattr(self._module, self._method_name)(*args, **kwargs)
/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py in __call__(self, *args, **kwargs)
997 with tf.name_scope(name_scope):
998 if not self.built:
--> 999 self._maybe_build(inputs)
1000
1001 if self._autocast:
/usr/local/lib/python3.8/dist-packages/keras/engine/base_layer.py in _maybe_build(self, inputs)
2599 # operations.
2600 with tf_utils.maybe_init_scope(self):
-> 2601 self.build(input_shapes) # pylint:disable=not-callable
2602 # We must set also ensure that the layer is marked as built, and the build
2603 # shape is stored since user defined build functions may not be calling
/usr/local/lib/python3.8/dist-packages/tfkan/layers/convolution.py in build(self, input_shape)
51 )
52 self._in_size = self.kernel_size[0] * self.kernel_size[1] * in_channels
---> 53 self.kernel.build(self._in_size)
54
55 # create bias if needed
/usr/local/lib/python3.8/dist-packages/tfkan/layers/dense.py in build(self, input_shape)
77 # build basis activation
78 if isinstance(self.basis_activation, str):
---> 79 self.basis_activation = tf.keras.activations.get(self.basis_activation)
80 elif not callable(self.basis_activation):
81 raise ValueError(f"expected basis_activation to be str or callable, found {type(self.basis_activation)}")
/usr/local/lib/python3.8/dist-packages/tensorflow/python/util/dispatch.py in wrapper(*args, **kwargs)
204 """Call target, and fall back on dispatchers if there is a TypeError."""
205 try:
--> 206 return target(*args, **kwargs)
207 except (TypeError, ValueError):
208 # Note: convert_to_eager_tensor currently raises a ValueError, not a
/usr/local/lib/python3.8/dist-packages/tensorflow/python/keras/activations.py in get(identifier)
585 if isinstance(identifier, str):
586 identifier = str(identifier)
--> 587 return deserialize(identifier)
588 elif isinstance(identifier, dict):
589 return deserialize(identifier)
/usr/local/lib/python3.8/dist-packages/tensorflow/python/util/dispatch.py in wrapper(*args, **kwargs)
204 """Call target, and fall back on dispatchers if there is a TypeError."""
205 try:
--> 206 return target(*args, **kwargs)
207 except (TypeError, ValueError):
208 # Note: convert_to_eager_tensor currently raises a ValueError, not a
/usr/local/lib/python3.8/dist-packages/tensorflow/python/keras/activations.py in deserialize(name, custom_objects)
544 globs[key] = val
545
--> 546 return deserialize_keras_object(
547 name,
548 module_objects=globs,
/usr/local/lib/python3.8/dist-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
696 obj = module_objects.get(object_name)
697 if obj is None:
--> 698 raise ValueError(
699 'Unknown {}: {}. Please ensure this object is '
700 'passed to the `custom_objects` argument. See '
ValueError: Unknown activation function: silu. Please ensure this object is passed to the `custom_objects` argument. See https://www.tensorflow.org/guide/keras/save_and_serialize#registering_the_custom_object for details.
I assume this is a TensorFlow compatibility problem. Which version are you using?
ZPZhou-lab commented
Thank you for the feedback! I'm using TensorFlow >= 2.10. If you do not want to install a newer version or create a new environment for this, you can define the `silu` activation yourself and pass it to `DenseKAN` or `Conv2DKAN` through their kwargs:
import tensorflow as tf
from tfkan.layers import DenseKAN, Conv2DKAN
def silu(x: tf.Tensor) -> tf.Tensor:
    """Return the SiLU (a.k.a. swish) activation ``x * sigmoid(x)``, element-wise.

    Defined locally as a workaround for TensorFlow versions where
    ``tf.keras.activations.get("silu")`` raises "Unknown activation function".
    Pass this callable as ``basis_activation`` instead of the string "silu".
    """
    return x * tf.nn.sigmoid(x)
# Build layers whose basis activation is the `silu` callable defined above,
# so the string lookup in DenseKAN.build (tf.keras.activations.get) is skipped.
dense_layer = DenseKAN(
    units=10,
    basis_activation=silu,
)
conv_layer = Conv2DKAN(
    filters=10,
    kernel_size=3,
    strides=1,
    kan_kwargs={'basis_activation': silu},
)
So the CNN can be built like this:
# KAN
# Every KAN layer receives the `silu` callable explicitly. DenseKAN resolves a
# *string* basis_activation via tf.keras.activations.get, and that lookup
# fails for "silu" on older TensorFlow versions.
kan = tf.keras.models.Sequential([
    Conv2DKAN(
        filters=8, kernel_size=5, strides=2, padding='valid',
        kan_kwargs={'grid_size': 3, 'basis_activation': silu},
    ),
    tf.keras.layers.LayerNormalization(),
    Conv2DKAN(
        filters=16, kernel_size=5, strides=2, padding='valid',
        kan_kwargs={'grid_size': 3, 'basis_activation': silu},
    ),
    # GlobalAveragePooling2D is not imported by name in this snippet, so
    # reference it through tf.keras.layers to avoid a NameError.
    tf.keras.layers.GlobalAveragePooling2D(),
    # The output DenseKAN needs the callable too. Otherwise it keeps its
    # default basis_activation (presumably the string "silu", judging by the
    # error) and raises the same ValueError during build.
    DenseKAN(10, grid_size=3, basis_activation=silu),
    tf.keras.layers.Softmax(),
])
kan.build(input_shape=(None, 28, 28, 1))
kan.summary()
Hope this helps! I will add a `requirements.txt` for tfkan later. 🫡