premAI-io/benchmarks

JAX

Since JAX is becoming very popular, it would be great if we could also benchmark the performance of Llama 2 written in JAX.

Here is the implementation

@Anindyadeep can you check whether they support Llama 2/Mistral? Otherwise, let's close the issue.

It seems Hugging Face does have an implementation of both:

Llama implementation: huggingface/transformers#24587
Mistral implementation: huggingface/transformers#26943

We can come back to this once we are done with the initial ones.