jina-ai/GSoC

JAX support in DocArray v2

Nick17t opened this issue · 7 comments

Project idea 6: JAX support in DocArray v2

Info details
Skills needed: Python, deep learning, JAX
Project size: 175 hours
Difficulty level: Hard
Mentors: @Sami Jaghouar

Project Description

  • DocArray is a library for representing, sending, and storing multi-modal data, with a focus on applications in ML and Neural Search. It currently supports several deep learning frameworks, including PyTorch and TensorFlow. JAX is becoming increasingly popular for deep learning, so we want to integrate it into DocArray.

  • The project we propose is to add JAX as a backend for DocArray, alongside PyTorch and TensorFlow. The first part would involve rewriting all of DocArray's computational backend functions with JAX. Then, we would battle-test the implementation against a real JAX use case, such as using DocArray with JAX support for model training and serving.
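To make the first part more concrete, here is a minimal sketch of what a JAX computational backend could look like, assuming a backend is simply a collection of static array operations implemented once per framework. The class name `JaxCompBackend` and its method set are hypothetical and illustrative only; they are not DocArray's actual interface. Every operation delegates to `jax.numpy`, which follows the NumPy API closely.

```python
import numpy as np
import jax.numpy as jnp


class JaxCompBackend:
    """Hypothetical JAX computational backend (illustrative sketch only).

    Each static method is one array operation a per-framework backend
    would typically need, implemented with jax.numpy.
    """

    @staticmethod
    def to_numpy(array):
        # Copy a JAX array back to host memory as a NumPy array.
        return np.asarray(array)

    @staticmethod
    def from_numpy(array):
        # Convert a NumPy array into a JAX array on the default device.
        return jnp.asarray(array)

    @staticmethod
    def shape(array):
        return tuple(array.shape)

    @staticmethod
    def n_dim(array):
        return array.ndim

    @staticmethod
    def reshape(array, shape):
        return jnp.reshape(array, shape)

    @staticmethod
    def stack(arrays, dim=0):
        # Stack per-document tensors into one batch tensor along `dim`.
        return jnp.stack(arrays, axis=dim)


# Usage: round-trip a NumPy array through the backend.
x = JaxCompBackend.from_numpy(np.arange(6, dtype=np.float32))
y = JaxCompBackend.reshape(x, (2, 3))
batch = JaxCompBackend.stack([y, y])
```

Because `jnp` mirrors NumPy's semantics, much of such a port could plausibly be a near line-for-line translation of a NumPy backend; the differences tend to show up around immutability and device placement.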

Expected outcomes

  • We aim to provide JAX with the same level of support in DocArray as we do for PyTorch, NumPy, and TensorFlow. The integration should be thoroughly tested and documented.

Desired skills

  • Python proficiency is expected, since the DocArray codebase is quite complex. Additionally, experience with JAX and familiarity with the scientific Python ecosystem (e.g. NumPy, PyTorch, scikit-learn) are required.

More details:

This project targets DocArray, specifically the ongoing rewrite, DocArray v2, which is a new codebase.

We currently support three computational frameworks in DocArray v2 (PyTorch, NumPy, and TensorFlow), and we would like to add JAX support.

More info about JAX can be found here, but in short, it is a deep learning framework backed by Google that is gaining a lot of traction, especially among researchers.

Concretely, what is expected in this project:

Hello @Nick17t @samsja! I am Pranjal.

While browsing GSoC projects, I came across this one today. Having multi-modal data structures compatible with JAX modules sounds really cool to me. I took a quick look through the DocArray codebase and found ArrayType and AnyDNN as the framework-agnostic types. I believe their uses in the codebase, such as in .embed() and docarray.math.distance, will need to be looked into for the JAX port. We will also need to decide between Flax and Haiku modules for DNNs. Overall, the project seems very exciting to work on!
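As an illustration of what porting a function like those in `docarray.math.distance` might involve, here is a sketch of a pairwise cosine distance written only with `jax.numpy`. The function name and signature are hypothetical and chosen for illustration, not copied from DocArray; the `eps` guard against zero-norm rows is an assumption about how degenerate vectors should be handled.

```python
import jax.numpy as jnp


def cosine_distance(x_mat, y_mat, eps=1e-7):
    """Pairwise cosine distance between the rows of x_mat (n, d) and y_mat (m, d).

    Returns an (n, m) array whose entry [i, j] is 1 - cos(x_i, y_j).
    """
    # Normalize rows to unit length; eps avoids dividing by zero for all-zero rows.
    x_unit = x_mat / jnp.maximum(jnp.linalg.norm(x_mat, axis=-1, keepdims=True), eps)
    y_unit = y_mat / jnp.maximum(jnp.linalg.norm(y_mat, axis=-1, keepdims=True), eps)
    # Dot products of unit vectors are cosine similarities.
    return 1.0 - x_unit @ y_unit.T
```

Since the function uses only `jnp` operations, it could also be wrapped in `jax.jit` without changes.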

I would love to know more and contribute to the project.

samsja commented

@DevPranjal I added more info to the issue description. Be aware that this project targets DocArray v2.

@Nick17t @samsja
Based on the given information, here is what I understood:

Project Description:

DocArray is a library for representing, sending, and storing multi-modal data, with a focus on applications in ML and Neural Search. It currently supports PyTorch, Numpy, and TensorFlow as computational backends. We want to extend the backend support to include JAX.
The project goal is to add JAX support as a computational backend in DocArray v2.

Here are the specific tasks involved:

  1. Add a JAX backend to the computational backends, relying as much as possible on JAX NumPy (jnp) as a NumPy-like interface for JAX.

  2. Create a new JAX tensor type, along with JAX variants of ImageTensor and the other tensor types.

  3. Ensure compatibility of DocumentArrayStack with JAX, with unit testing for each function in the computational backend.

  4. Thoroughly test the implementation through the following:

      • Unit tests for each function in the computational backend, using predefined tensors and DocumentArrayStack.
      • Integration tests to check the coherence of the entire implementation, with an emphasis on training a small neural network using DocArray + JAX.
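The integration test in the last item could start out roughly as below. This is a sketch under stated assumptions: plain JAX arrays stand in for per-document tensors, `jnp.stack` stands in for building a DocumentArrayStack, and the "small neural network" is reduced to a linear model trained with `jax.value_and_grad`. None of the names here are DocArray APIs.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # assumed step size for plain gradient descent

# Stand-in for 32 documents, each holding one 3-dimensional feature tensor.
doc_keys = jax.random.split(jax.random.PRNGKey(0), 32)
doc_tensors = [jax.random.normal(k, (3,)) for k in doc_keys]

# Stack per-document tensors into one (32, 3) batch, as a stacked document array would.
features = jnp.stack(doc_tensors)
true_w = jnp.array([1.0, -2.0, 0.5])
targets = features @ true_w  # synthetic, noise-free regression targets


def loss_fn(w, x, y):
    # Mean squared error of a linear model y_hat = x @ w.
    return jnp.mean((x @ w - y) ** 2)


@jax.jit
def train_step(w, x, y):
    # One gradient-descent step; returns updated weights and the pre-update loss.
    loss, grads = jax.value_and_grad(loss_fn)(w, x, y)
    return w - LEARNING_RATE * grads, loss


w = jnp.zeros(3)
initial_loss = float(loss_fn(w, features, targets))
for _ in range(300):
    w, _ = train_step(w, features, targets)
final_loss = float(loss_fn(w, features, targets))
```

A real test would replace the stand-ins with DocArray documents and a Flax or Haiku model, but the shape of the check (loss decreases, weights approach a known target) would stay the same.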

Expected outcomes:

Upon successful completion of this project, DocArray v2 will support JAX as a computational backend alongside PyTorch, Numpy, and TensorFlow. The implementation will be thoroughly tested and documented.

I would like to work on this project.

Hey @samsja, I would like to contribute to the project. Please guide me on how to get started. I am proficient in Python, as well as in machine learning with TensorFlow.

Hi @samsja @Nick17t, based on what I understood, I tried implementing it. Please let me know if I am on the right path.
[Screenshot 2023-03-13 at 5:08:16 PM]

Hi @DevPranjal @Arnav131003 @tehami02

I am delighted to hear that you are interested in contributing to the Jina AI community! 🎉

To get started, please take a moment to fill out our survey so that we can learn more about you and your skills.

Also, don't forget to mark your calendars for the GSoC x Jina AI webinar on March 23rd at 2 pm (CET). This is an excellent opportunity to learn more about the projects and ask any questions you have about the requirements and expectations.

Our mentors will provide an in-depth overview of the projects and answer any questions you may have. So please don't hesitate to ask any questions or seek clarification on any aspect of the project.

Is there anything specific you would like to learn from the webinar? Do you have any questions about the JAX support in DocArray v2 project that you would like to see clarified during the Q&A session? Let me know, and I'll be happy to help!

Looking forward to seeing you at the webinar, and thank you for your interest in the Jina AI community! 😊

Hi @Nick17t, this is a very interesting project. I have worked on a similar project where we had to create a new backend module for JAX. To make DocumentArrayStack compatible with JAX, we need to ensure that it works seamlessly with the JAX backend. This will involve testing the existing DocumentArrayStack code with the new JAX backend and resolving any compatibility issues that arise. I would love to work on this project!
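One concrete compatibility issue of this kind: JAX arrays are immutable, so code that writes a single document's tensor into a stacked batch in place (which works for NumPy arrays and PyTorch tensors) raises a `TypeError` with JAX. The helper below is a hypothetical sketch of how such a write could dispatch on the array type, using JAX's functional `.at[...].set(...)` update; `set_row` is an illustrative name, not a DocArray function.

```python
import jax
import jax.numpy as jnp
import numpy as np


def set_row(batch, index, value):
    """Hypothetical helper: write one document's tensor into a stacked batch.

    Returns the updated batch. For JAX this is a new array; for NumPy the
    input is modified in place and returned.
    """
    if isinstance(batch, jax.Array):
        # JAX arrays are immutable: .at[...].set(...) builds an updated copy.
        return batch.at[index].set(value)
    # NumPy (and PyTorch) support in-place item assignment.
    batch[index] = value
    return batch


jax_batch = jnp.zeros((4, 3))
jax_batch = set_row(jax_batch, 1, jnp.ones(3))  # the result must be reassigned for JAX

np_batch = np.zeros((4, 3))
set_row(np_batch, 1, np.ones(3))  # modified in place
```

The catch is the calling convention: any stacked-array code that relies on in-place mutation would need to reassign the returned batch when the backend is JAX.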