Persist downloaded model parameters long-term
cyrushu opened this issue · 0 comments
cyrushu commented
Dear developers,
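Here is a proposal for persisting downloaded model parameters long-term. Currently the weights are cached under the system temporary directory, which the OS may clear, forcing them to be re-downloaded. Storing them in a persistent directory (configurable via `CHROMA_ROOT_DIR`, defaulting to `~/.config/chroma`) saves that network traffic. It would also be worthwhile to add a SHA-256 integrity check to the cache-check step, so that a corrupted or partially written cached file is not reused.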
diff --git a/chroma/utility/api.py b/chroma/utility/api.py
index 902b776..ce996c8 100644
--- a/chroma/utility/api.py
+++ b/chroma/utility/api.py
@@ -21,7 +21,11 @@ import requests
import chroma
-ROOT_DIR = os.path.dirname(os.path.dirname(chroma.__file__))
+# SETTING CHROMA_ROOT_DIR or use default directory: ~/.config/chroma
+ROOT_DIR = os.environ.get(
+ "CHROMA_ROOT_DIR",
+ os.path.join(os.path.expanduser("~"), ".config", "chroma"))
+os.makedirs(ROOT_DIR, exist_ok=True)
def register_key(key: str, key_directory=ROOT_DIR) -> None:
@@ -92,11 +96,8 @@ def download_from_generate(
# Create a hash of the URL + weight name to determine the path for the cached/temporary file
url_hash = hashlib.md5((base_url + weights_name).encode()).hexdigest()
- temp_dir = os.path.join(tempfile.gettempdir(), "chroma_weights", url_hash)
- destination = os.path.join(temp_dir, "weights.pt")
-
- # Ensure the directory exists
- os.makedirs(temp_dir, exist_ok=True)
+ os.makedirs(os.path.join(ROOT_DIR, "weights"), exist_ok=True)
+ destination = os.path.join(ROOT_DIR, "weights", f"{url_hash}.pt")
# Check if cache exists
cache_exists = os.path.exists(destination)
@@ -117,8 +118,14 @@ def download_from_generate(
response = requests.get(base_url, params=params)
response.raise_for_status() # Raise an error for HTTP errors
- with open(destination, "wb") as file:
- file.write(response.content)
+ # Write into temp_file
+ temp_file = tempfile.TemporaryFile()
+ temp_file.write(response.content)
+
+ # Write into cached destination
+ with open(destination, "wb") as f:
+ temp_file.seek(0)
+ f.write(temp_file.read())
print(f"Data saved to {destination}")
return destination