Aleph-Alpha/magma

Improved inference interface

CoEich opened this issue · 1 comment

Implement an interface like the one Mayukh suggested:

```python
from magma import Magma
from magma.image import Image, ImageFromURL  # to easily load/use images

model, tokenizer = Magma(checkpoint='model.pt', config='config.yml', device='cuda:0')

inputs = [
    Image('path/to/image.jpg'),
    'Where is this ? A: Egypt',
    ImageFromURL('url/to/image.jpg'),
    'Where is this ? A:',
]

embeddings = tokenizer.tokenize(inputs).to(model.device)

output = model.forward(embeddings, output_attentions=True)

logits = output.logits  # tensor of shape [1, len_seq, len_vocab]
attentions = output.attentions  # list of tensors

# this already exists: https://gitlab.aleph-alpha.de/research/multimodal_fewshot/-/blob/master/multimodal_fewshot/model.py#L442
generated_text = model.generate(embeddings, n_steps=10, *args)
```
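To make the mixed `inputs` list concrete, here is a rough sketch of how a tokenizer might first label each item as image or text while keeping the interleaved order. Everything in it is an assumption for illustration: `ImageRef` is a hypothetical stand-in for the proposed `Image` / `ImageFromURL` wrappers, and `tag_prompt_items` is not part of magma.

```python
from dataclasses import dataclass
from typing import List, Tuple, Union


@dataclass(frozen=True)
class ImageRef:
    """Hypothetical stand-in for the proposed Image / ImageFromURL wrappers."""
    source: str


PromptItem = Union[ImageRef, str]


def tag_prompt_items(items: List[PromptItem]) -> List[Tuple[str, str]]:
    """Label each prompt item as 'image' or 'text', preserving the input order."""
    tagged = []
    for position, item in enumerate(items):
        if isinstance(item, ImageRef):
            tagged.append(("image", item.source))
        elif isinstance(item, str):
            tagged.append(("text", item))
        else:
            # Fail early so a bad item is reported with its position in the list.
            raise TypeError(
                f"Unsupported input at position {position}: {type(item).__name__}"
            )
    return tagged


inputs = [
    ImageRef("path/to/image.jpg"),
    "Where is this ? A: Egypt",
    ImageRef("url/to/image.jpg"),
    "Where is this ? A:",
]
tagged_inputs = tag_prompt_items(inputs)
```

A real tokenizer would then embed each segment with the matching encoder and concatenate the results in this order; rejecting unknown types up front keeps that step simple.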

We could also use a single image wrapper that handles both local images and URLs, like this:

import requests
from io import BytesIO
import PIL.Image as PilImage


class ImageInput:
    """Wrapper to handle image inputs from both local paths and URLs.

    Args:
        path_or_url (str): path or link to image.
    """

    def __init__(self, path_or_url):
        self.path_or_url = path_or_url
        if path_or_url.startswith(("http://", "https://")):
            try:
                response = requests.get(path_or_url)
                response.raise_for_status()  # treat HTTP error codes as failures
                self.pil_image = PilImage.open(BytesIO(response.content))
            except Exception as err:
                # Chain the original error so the root cause stays visible.
                raise Exception(
                    f'Could not retrieve image from url:\n{self.path_or_url}'
                ) from err
        else:
            self.pil_image = PilImage.open(path_or_url)

    def get_image(self):  # to be called internally
        return self.pil_image
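
One possible refinement of the path-versus-URL check, sketched with the standard library only: parsing the string with `urllib.parse.urlparse` accepts an upper-case scheme and rejects a bare `https://` with no host. The helper name `is_remote_image` is hypothetical and not part of the proposal above.

```python
from urllib.parse import urlparse


def is_remote_image(path_or_url: str) -> bool:
    """Return True when the string parses as an http(s) URL that has a host."""
    parsed = urlparse(path_or_url)
    # urlparse lower-cases the scheme, so "HTTP://..." is handled as well.
    return parsed.scheme in ("http", "https") and bool(parsed.netloc)
```

`ImageInput.__init__` could call this instead of the `startswith` test; anything that is not recognised as a remote URL would still fall through to `PilImage.open` as a local path.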

If you want, I can open a PR with this 🙂