keras_nlp Gemma error
Hello,
I get an error when I run a simple test with Gemma in keras_nlp: Input dims must be <= 4 and >=1
TensorFlow 2.17.0 (also tried 2.16.0)
Python 3.11
Script:
```python
import os

# Kaggle credentials (redacted) used to download the Gemma preset
os.environ["KAGGLE_USERNAME"] = ""
os.environ["KAGGLE_KEY"] = ""
os.environ["KERAS_BACKEND"] = "tensorflow"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"  # JAX-specific; not used by the TensorFlow backend
# Enable ZenDNN optimizations and disable oneDNN ones
os.environ["TF_ENABLE_ZENDNN_OPTS"] = "1"
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"

import keras  # 3.5.0
import keras_nlp  # 0.14.4

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
gemma_lm.generate("what is keras in 3 bullet points?", max_length=64)
```

Hi @pwipo
Thanks for reporting this issue.
We are looking into it, and it will be fixed in an upcoming release.
Until we fix and update this repo, please use the workaround below:
Disable the Softmax rewrite:
- Comment out line #72.
Set the following environment variable:
- os.environ["ZENDNN_MATMUL_ALGO"] = "FP32:3,BF16:3"
Then build zentf from source as described in the README.
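A minimal sketch of wiring the environment settings into the reporter's script. It assumes that ZenDNN, like TF_ENABLE_ONEDNN_OPTS, reads these variables when TensorFlow is first imported, so they have to be set before `import keras`. The helper name `apply_zendnn_workaround` and its import-order check are hypothetical and not part of zentf. The Softmax-rewrite change and the source build still have to be done separately.

```python
import os
import sys

# Settings from this thread: TensorFlow backend, ZenDNN on, oneDNN off,
# plus the suggested ZENDNN_MATMUL_ALGO value.
ZENDNN_ENV = {
    "KERAS_BACKEND": "tensorflow",
    "TF_ENABLE_ZENDNN_OPTS": "1",
    "TF_ENABLE_ONEDNN_OPTS": "0",
    "ZENDNN_MATMUL_ALGO": "FP32:3,BF16:3",
}


def apply_zendnn_workaround(env=None):
    """Hypothetical helper: write the ZenDNN settings into `env` (default: os.environ).

    Assumes the variables are only read at import time, so it refuses to run
    once TensorFlow or Keras has already been imported.
    """
    loaded = [name for name in ("tensorflow", "keras") if name in sys.modules]
    if loaded:
        raise RuntimeError(f"set ZenDNN variables before importing {', '.join(loaded)}")
    if env is None:
        env = os.environ
    env.update(ZENDNN_ENV)
    return env


if __name__ == "__main__":
    apply_zendnn_workaround()
    # Only now import the libraries, e.g. `import keras, keras_nlp`.
    print(os.environ["ZENDNN_MATMUL_ALGO"])
```

Call the helper at the very top of the script, before the `import keras` and `import keras_nlp` lines; if either library has already been loaded, the helper raises instead of silently setting variables that may be ignored.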