databricks/lilac

Add `dataset.map(embeddings=True)`

Opened this issue · 0 comments

Reading Lilac's embeddings inside a `map` function is hard today, and it relies on a private API (`_get_vector_db_index`):

# Dataset handle shared by get_similarity and the ds.map call below.
ds = ll.get_dataset('local', 'glave-coder-sample')
# numpy supplies the dot product used in get_similarity.
import numpy as np


def get_similarity(x):
  """Return the dot product of a row's 'question' and 'answer' vectors.

  Args:
    x: A row passed in by `ds.map`; only `x[ll.ROWID]` is read.

  Returns:
    The dot product of the two stored vectors, converted to a Python float.

  Note:
    Both vectors are fetched through the private `ds._get_vector_db_index`
    method, using the 'jina-v2-small' index for each column.
  """
  # The vector index is keyed by a one-element tuple holding the row id.
  key = (x[ll.ROWID],)

  def _first_vector(column):
    # `index.get` is consumed with `next`; only the first yielded item is used,
    # and the vector is read from its first entry.
    index = ds._get_vector_db_index('jina-v2-small', (column,))
    first_result = next(index.get(keys=[key]))
    return first_result[0]['vector']

  question_vector = _first_vector('question')
  answer_vector = _first_vector('answer')
  return float(np.dot(question_vector, answer_vector))

# Run get_similarity through ds.map, targeting the 'similarity' output column.
# NOTE(review): overwrite=True presumably replaces an existing column — confirm.
ds.map(get_similarity, output_column='similarity', overwrite=True)