bananaml/serverless-template

0 GPUs with TensorFlow


s1223 commented

Problem:

  1. The GPU is not detected when using the latest TensorFlow release: tf.config.list_physical_devices('GPU') returns an empty list, so the endpoint reports 0 GPUs.

How can I solve this?

My code:
app.py:

import tensorflow as tf

# Init is run on server startup
# Load your model to GPU as a global variable here using the variable name "model"
def init():
    pass

# Inference is run for every server call
# Reference your preloaded global model variable here.
def inference(model_inputs:dict) -> dict:
    # Parse out your arguments
    prompt = model_inputs.get('url', None)
    if prompt is None:
        return {'message': "No url provided"}
    
    # Count the GPUs visible to TensorFlow
    physical_devices = tf.config.list_physical_devices('GPU')
    result = f"Number of GPU : {len(physical_devices)}"

    # Return the results as a dictionary
    return result
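One way to narrow down a 0-GPU result is to have inference report how the installed TensorFlow was built, not just the device count. The sketch below is an illustrative diagnostic, not a fix: gpu_report is a hypothetical helper name, and it assumes TensorFlow 2.3 or later, where tf.sysconfig.get_build_info() is available. It also returns a dict, matching the handler's -> dict annotation, whereas the code above returns a bare string.

```python
import importlib.util


def gpu_report() -> dict:
    """Summarize what TensorFlow can see; safe to call even if TF is absent.

    gpu_report is a hypothetical helper, not part of the Banana template.
    """
    if importlib.util.find_spec("tensorflow") is None:
        return {"tensorflow_installed": False, "gpu_count": 0}

    import tensorflow as tf

    # On CPU-only builds the CUDA/cuDNN keys may be missing, hence .get().
    build = tf.sysconfig.get_build_info()
    return {
        "tensorflow_installed": True,
        "version": tf.__version__,
        "built_with_cuda": tf.test.is_built_with_cuda(),
        "cuda_version": build.get("cuda_version"),
        "cudnn_version": build.get("cudnn_version"),
        "gpu_count": len(tf.config.list_physical_devices("GPU")),
    }


def inference(model_inputs: dict) -> dict:
    # Ignores model_inputs; this variant only reports the GPU diagnosis.
    return gpu_report()
```

If built_with_cuda is False, the installed package cannot use a CUDA GPU at all; if it is True but gpu_count is still 0, comparing cuda_version and cudnn_version against the CUDA libraries present in the container image is a reasonable next step.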

download.py:

# In this file, we define download_model
# It runs during container build time to get model weights built into the container

# In this example: a Hugging Face BERT model

import tensorflow as tf

def download_model():
    # Do a dry run of loading the Hugging Face model, which downloads the weights
    pass

if __name__ == "__main__":
    download_model()
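The comments in download.py describe a build-time dry run that downloads the weights into the image, while download_model itself is only a stub. A minimal sketch of that step, assuming the transformers package from requirements.txt and the bert-base-uncased checkpoint (the checkpoint name is an assumption; the original does not name one):

```python
def download_model():
    # Imported lazily so defining this function does not require transformers.
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    checkpoint = "bert-base-uncased"  # assumed checkpoint, not from the original
    # from_pretrained downloads and caches the files on first call, so running
    # this at build time stores the weights inside the container image.
    AutoTokenizer.from_pretrained(checkpoint)
    AutoModelForMaskedLM.from_pretrained(checkpoint)
```

It would still be invoked from the existing if __name__ == "__main__": block.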

requirements.txt:

sanic
transformers
accelerate
torch
tensorflow
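Because requirements.txt leaves every package unpinned, each container build can pull a different TensorFlow release. A small sketch that records which versions were actually installed; installed_versions is a hypothetical helper, and it relies only on the standard-library importlib.metadata (Python 3.8+):

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# The five entries from requirements.txt above.
REQUIREMENTS = ["sanic", "transformers", "accelerate", "torch", "tensorflow"]
```

Printing installed_versions(REQUIREMENTS) during the build (for example from download.py) shows which TensorFlow release the image ended up with; once a working combination is known, pinning it in requirements.txt (tensorflow==<version>) keeps later builds reproducible.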

test2.py:

import banana_dev as banana
import time

api_key = "api_key" # "YOUR_API_KEY"
model_key = "model_key" # "YOUR_MODEL_KEY"
model_inputs = {'url': 'Hello I am a [MASK] model.'}

startTime = time.time()
out = banana.run(api_key, model_key, model_inputs)
print(out)
endTime = time.time()
print("Time: ", endTime - startTime)
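The client above times the call with time.time(). For measuring durations, time.perf_counter() is the more suitable clock because it is monotonic and high-resolution. A minimal sketch of a reusable wrapper (timed is a hypothetical helper name):

```python
import time


def timed(fn, *args, **kwargs):
    """Call fn(*args, **kwargs) and return (result, elapsed_seconds)."""
    start = time.perf_counter()  # monotonic, high-resolution clock
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start
```

In test2.py this would read out, elapsed = timed(banana.run, api_key, model_key, model_inputs). The measured time covers the whole client-side round trip, including network and any server-side startup, not just the GPU check.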

Result:

{'id': '7fb0bec8-8875-408e-99ea-12c867de3e19', 'message': '', 'created': 1658303998, 'apiVersion': '12 May 2022', 'modelOutputs': ['Number of GPU : 0']}
Time:  8.601353645324707
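In the response above, the string returned by inference arrives wrapped in the modelOutputs list. A small sketch for pulling the count back out, assuming the "Number of GPU : N" format produced by app.py (gpu_count_from_response is a hypothetical helper):

```python
def gpu_count_from_response(out: dict) -> int:
    """Extract N from a response whose first model output is 'Number of GPU : N'."""
    text = out["modelOutputs"][0]
    # Split on the last colon; int() tolerates the surrounding whitespace.
    return int(text.rsplit(":", 1)[1])
```

For the response shown, this yields 0.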

Hey! We do support TensorFlow and haven't seen this error before. We'll debug it and update you.