Azure/azureml-examples

Loading pre-built local Faiss Index causes ValueError (allow_dangerous_deserialization)

daviddwlee84 opened this issue · 0 comments

Operating System

Windows

Version Information

Python Version: 3.11.5

promptflow                               1.6.0
promptflow-tools                         1.3.0
promptflow_vectordb                      0.2.5

langchain                                0.1.12
langchain-community                      0.0.28
langchain-core                           0.1.32
langchain-experimental                   0.0.43
langchain-openai                         0.0.8
langchain-text-splitters                 0.0.1

faiss-cpu                                1.7.4

Steps to reproduce

This can be reproduced by following the tutorial notebook and then rerunning the cell below, which reloads the saved Faiss index:

MODEL_API_VERSION = "2023-05-15"
MODEL_DEPLOYMENT_NAME = "text-embedding-ada-002"
DIMENSION = 1536

# Configure an embedding store to store index file.
store_path = os.path.join(os.getcwd(), "faiss_index_store")
config = StoreCoreConfig.create_config(
    storage_type=StorageType.LOCAL,
    store_identifier=store_path,
    model_type=EmbeddingModelType.AOAI,
    model_api_base=os.environ["Azure_OpenAI_MODEL_ENDPOINT"],
    model_api_key=os.environ["Azure_OpenAI_MODEL_API_KEY"],
    model_api_version=MODEL_API_VERSION,
    model_name=MODEL_DEPLOYMENT_NAME,
    dimension=DIMENSION,
    create_if_not_exists=True,
)
store = EmbeddingStoreCore(config)

Expected behavior

The pre-built Faiss index should load without error.

Actual behavior

FAISS.load_local raises a ValueError because allow_dangerous_deserialization is not set. The check is here:

https://github.com/langchain-ai/langchain/blob/40f846e65da37a1c00d72da9ea64ebb0f295b016/libs/community/langchain_community/vectorstores/faiss.py#L1054-L1089
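
The linked check behaves roughly like the stand-in below. This is a simplified sketch, not the langchain source: `load_local_sketch`, the `index.pkl` file name, and the error message are hypothetical, and only the `allow_dangerous_deserialization` parameter name is taken from the real API. It illustrates why a pickle-backed index needs an explicit opt-in.

```python
import pickle
import tempfile
from pathlib import Path


def load_local_sketch(folder: str, allow_dangerous_deserialization: bool = False):
    """Hypothetical stand-in for a loader that guards pickle deserialization."""
    if not allow_dangerous_deserialization:
        # Unpickling can execute arbitrary code, so refuse unless the caller opts in.
        raise ValueError(
            "Refusing to unpickle: pass allow_dangerous_deserialization=True "
            "only if you trust the source of this file."
        )
    with open(Path(folder) / "index.pkl", "rb") as f:
        return pickle.load(f)


def demo() -> tuple[bool, dict]:
    """Save a small pickle, then load it without and with the opt-in flag."""
    with tempfile.TemporaryDirectory() as folder:
        with open(Path(folder) / "index.pkl", "wb") as f:
            pickle.dump({"doc-0": "hello"}, f)
        try:
            load_local_sketch(folder)
            raised = False
        except ValueError:
            raised = True
        loaded = load_local_sketch(folder, allow_dangerous_deserialization=True)
    return raised, loaded
```

Calling `demo()` exercises both paths: the default call is rejected, and the opt-in call returns the stored data.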

Additional information

promptflow_vectordb needs a way to pass allow_dangerous_deserialization=True so that a locally saved, pickle-based vector DB checkpoint can be loaded.

I was able to work around the error by editing promptflow_vectordb/core/engine/langchain_engine.py:

# From
self.__langchain_faiss = FAISS.load_local(path, LangchainEmbedding(self.__embedding))
# To
self.__langchain_faiss = FAISS.load_local(path, LangchainEmbedding(self.__embedding), allow_dangerous_deserialization=True)
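
A library-side fix might also need to tolerate langchain-community releases that predate the flag, where passing it may fail. One possible approach, sketched here without any knowledge of promptflow_vectordb internals, is to forward the keyword only when the installed loader names it explicitly. `call_with_supported_kwargs` and the two `load_local_*` functions are hypothetical stand-ins, not real APIs.

```python
import inspect
from typing import Any, Callable


def call_with_supported_kwargs(func: Callable[..., Any], *args: Any, **optional: Any) -> Any:
    """Call func, forwarding only the optional keywords it names explicitly."""
    params = inspect.signature(func).parameters
    supported = {k: v for k, v in optional.items() if k in params}
    return func(*args, **supported)


# Hypothetical stand-in for a newer loader that requires the opt-in flag.
def load_local_new(path: str, embeddings: object,
                   allow_dangerous_deserialization: bool = False) -> str:
    if not allow_dangerous_deserialization:
        raise ValueError("deserialization not allowed")
    return f"loaded {path}"


# Hypothetical stand-in for an older loader that has no such parameter.
def load_local_old(path: str, embeddings: object) -> str:
    return f"loaded {path}"
```

With this helper, the same call site works against both stand-ins: `call_with_supported_kwargs(load_local_new, "p", None, allow_dangerous_deserialization=True)` and the equivalent call with `load_local_old`. Whether the real `FAISS.load_local` can be wrapped this way in the engine is an assumption to verify, and the flag should only be enabled for index files from a trusted source.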