amd/ZenDNN-tensorflow-plugin

keras_nlp Gemma error

Closed this issue · 2 comments

Hello,
I get an error when running a simple test with Gemma on keras_nlp: Input dims must be <= 4 and >=1
TensorFlow 2.17.0 (also tried 2.16.0)
python 3.11

script:
import os
os.environ["KAGGLE_USERNAME"] = ""
os.environ["KAGGLE_KEY"] = ""
os.environ["KERAS_BACKEND"] = "tensorflow"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"
os.environ["TF_ENABLE_ZENDNN_OPTS"] = "1"
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
import keras #3.5.0
import keras_nlp #0.14.4
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
gemma_lm.generate("what is keras in 3 bullet points?", max_length=64)

Hi @pwipo
Thanks for reporting this issue.

We are looking into it, and it will be fixed in an upcoming release.

Until we fix this and update the repo, please use the workaround below:

Disable the Softmax rewrite:

  • Comment out line #72

Set the following environment variable:

  • os.environ["ZENDNN_MATMUL_ALGO"] = "FP32:3,BF16:3"

Build zentf from source as described in the README.
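
For the environment-variable step, here is a minimal Python sketch. It assumes the variables need to be set before TensorFlow is imported, as in the script from the original report. The helper name `apply_workaround_env` is hypothetical and only for illustration. The sketch covers the environment settings only; the Softmax rewrite still has to be disabled in the source and zentf rebuilt.

```python
import os

# Values taken from this thread. Setting them before TensorFlow is imported
# is an assumption based on the reporter's script, not a documented rule.
WORKAROUND_ENV = {
    "TF_ENABLE_ZENDNN_OPTS": "1",
    "TF_ENABLE_ONEDNN_OPTS": "0",
    "ZENDNN_MATMUL_ALGO": "FP32:3,BF16:3",
}


def apply_workaround_env(env=os.environ):
    """Write the workaround variables into env and return what was set.

    Hypothetical helper, not part of zentf or TensorFlow.
    """
    for name, value in WORKAROUND_ENV.items():
        env[name] = value
    return {name: env[name] for name in WORKAROUND_ENV}


applied = apply_workaround_env()
for name, value in applied.items():
    print(f"{name}={value}")

# Import keras / keras_nlp only after this point, as in the original script.
```

Passing a plain dict instead of `os.environ` lets you check the values without touching the current process environment.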

@pwipo
Thank you for your patience. We have fixed the issue in 81bbf17. Please give it a try with v5.0.