huggingface/diffusion-fast

LoRA support for optimization

Sandeep-Narahari opened this issue · 3 comments

I was using the torch.compile optimization to speed up inference time.

I am using a DreamBooth LoRA model that was trained on Juggernaut.

When I run inference, it's not compiling:

pipe.load_lora_weights(prj_path, weight_name="pytorch_lora_weights.safetensors")

Is there any way I can use this optimization with DreamBooth LoRA models? A sketch of my setup follows.
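For context, here is roughly what I'm running. This is a minimal sketch; the base checkpoint ID and prj_path are placeholders for my actual Juggernaut checkpoint and LoRA directory.

```python
import torch
from diffusers import DiffusionPipeline

# Placeholders: swap in the actual Juggernaut base checkpoint and LoRA path.
prj_path = "path/to/my/dreambooth-lora"

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
).to("cuda")

pipe.load_lora_weights(prj_path, weight_name="pytorch_lora_weights.safetensors")

# torch.compile step from the diffusion-fast recipe -- this is where it fails for me.
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

image = pipe("a photo of sks person", num_inference_steps=30).images[0]
```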

Packages I am using:

Package Version
absl-py 2.1.0
accelerate 0.26.1
aiofiles 23.2.1
aiohttp 3.9.3
aiosignal 1.3.1
albumentations 1.3.1
alembic 1.13.1
altair 5.2.0
annotated-types 0.6.0
anyio 3.7.1
arrow 1.3.0
async-timeout 4.0.3
attrs 23.2.0
Authlib 1.3.0
autotrain-advanced 0.6.92
bitsandbytes 0.42.0
Brotli 1.1.0
cachetools 5.3.2
certifi 2023.11.17
cffi 1.16.0
charset-normalizer 3.3.2
click 8.1.7
cmaes 0.10.0
cmake 3.28.3
codecarbon 2.2.3
colorlog 6.8.2
contourpy 1.1.1
cryptography 42.0.3
cycler 0.12.1
datasets 2.14.7
diffusers 0.21.4
dill 0.3.8
docstring-parser 0.15
einops 0.6.1
evaluate 0.3.0
exceptiongroup 1.2.0
fastapi 0.104.1
ffmpy 0.3.1
filelock 3.13.1
fonttools 4.47.2
frozenlist 1.4.1
fsspec 2023.10.0
fuzzywuzzy 0.18.0
google-auth 2.27.0
google-auth-oauthlib 1.0.0
GPUtil 1.4.0
gradio 3.41.0
gradio-client 0.5.0
greenlet 3.0.3
grpcio 1.60.0
h11 0.14.0
hf-transfer 0.1.5
httpcore 1.0.2
httpx 0.26.0
huggingface-hub 0.20.3
idna 3.6
imageio 2.33.1
importlib-metadata 7.0.1
importlib-resources 6.1.1
inflate64 1.0.0
install 1.3.5
invisible-watermark 0.2.0
ipadic 1.0.0
itsdangerous 2.1.2
Jinja2 3.1.3
jiwer 3.0.2
joblib 1.3.1
jsonschema 4.21.1
jsonschema-specifications 2023.12.1
kiwisolver 1.4.5
lazy-loader 0.3
loguru 0.7.0
Mako 1.3.2
Markdown 3.5.2
markdown-it-py 3.0.0
MarkupSafe 2.1.4
matplotlib 3.7.4
mdurl 0.1.2
mpmath 1.3.0
multidict 6.0.4
multiprocess 0.70.16
multivolumefile 0.2.3
networkx 3.1
nltk 3.8.1
numpy 1.24.4
nvidia-cublas-cu11 11.11.3.6
nvidia-cublas-cu12 12.1.3.1
nvidia-cuda-cupti-cu11 11.8.87
nvidia-cuda-cupti-cu12 12.1.105
nvidia-cuda-nvrtc-cu11 11.8.89
nvidia-cuda-nvrtc-cu12 12.1.105
nvidia-cuda-runtime-cu11 11.8.89
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu11 8.7.0.84
nvidia-cudnn-cu12 8.9.2.26
nvidia-cufft-cu11 10.9.0.58
nvidia-cufft-cu12 11.0.2.54
nvidia-curand-cu11 10.3.0.86
nvidia-curand-cu12 10.3.2.106
nvidia-cusolver-cu11 11.4.1.48
nvidia-cusolver-cu12 11.4.5.107
nvidia-cusparse-cu11 11.7.5.86
nvidia-cusparse-cu12 12.1.0.106
nvidia-nccl-cu11 2.19.3
nvidia-nccl-cu12 2.19.3
nvidia-nvjitlink-cu12 12.3.101
nvidia-nvtx-cu11 11.8.86
nvidia-nvtx-cu12 12.1.105
oauthlib 3.2.2
opencv-python 4.9.0.80
opencv-python-headless 4.9.0.80
optuna 3.3.0
orjson 3.9.12
packaging 23.1
pandas 2.0.3
peft 0.8.2
Pillow 10.0.0
pip 20.0.2
pkg-resources 0.0.0
pkgutil-resolve-name 1.3.10
protobuf 4.23.4
psutil 5.9.8
py-cpuinfo 9.0.0
py7zr 0.20.6
pyarrow 15.0.0
pyarrow-hotfix 0.6
pyasn1 0.5.1
pyasn1-modules 0.3.0
pybcj 1.0.2
pycparser 2.21
pycryptodomex 3.20.0
pydantic 2.4.2
pydantic-core 2.10.1
pydub 0.25.1
pygments 2.17.2
pyngrok 7.0.3
pynvml 11.5.0
pyparsing 3.1.1
pyppmd 1.0.0
python-dateutil 2.8.2
python-dotenv 1.0.1
python-multipart 0.0.6
pytorch-triton 3.0.0+901819d2b6
pytz 2023.4
PyWavelets 1.4.1
PyYAML 6.0.1
pyzstd 0.15.9
qudida 0.0.4
rapidfuzz 2.13.7
referencing 0.33.0
regex 2023.12.25
requests 2.31.0
requests-oauthlib 1.3.1
responses 0.18.0
rich 13.7.0
rouge-score 0.1.2
rpds-py 0.17.1
rsa 4.9
sacremoses 0.0.53
safetensors 0.4.2
scikit-image 0.21.0
scikit-learn 1.3.0
scipy 1.10.1
semantic-version 2.10.0
sentencepiece 0.1.99
setuptools 44.0.0
shtab 1.6.5
six 1.16.0
sniffio 1.3.0
SQLAlchemy 2.0.25
starlette 0.27.0
sympy 1.12
tensorboard 2.14.0
tensorboard-data-server 0.7.2
texttable 1.7.0
threadpoolctl 3.2.0
tifffile 2023.7.10
tiktoken 0.5.1
tokenizers 0.15.1
toolz 0.12.1
torch 2.3.0.dev20240221+cu118
torchaudio 2.2.0+cu118
torchtriton 2.0.0+f16138d447
torchvision 0.17.0
tqdm 4.65.0
transformers 4.37.0
triton 2.2.0
trl 0.7.11
types-python-dateutil 2.8.19.20240106
typing-extensions 4.9.0
tyro 0.7.0
tzdata 2023.4
urllib3 2.2.0
uvicorn 0.22.0
websockets 11.0.3
Werkzeug 2.3.6
wheel 0.34.2
xformers 0.0.24
xgboost 1.7.6
xxhash 3.4.1
yarl 1.9.4
zipp 3.17.0

[screenshot attached]

You should call fuse_lora() after loading the LoRA checkpoint, and then call torch.compile(). See the sketch below.
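A minimal sketch of the order of operations (the base checkpoint ID and prj_path are placeholders for your own):

```python
import torch
from diffusers import DiffusionPipeline

# Placeholders: use your Juggernaut base checkpoint and LoRA directory.
prj_path = "path/to/your/dreambooth-lora"

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
).to("cuda")

# 1) Load the LoRA, 2) fuse it into the base weights, 3) compile.
pipe.load_lora_weights(prj_path, weight_name="pytorch_lora_weights.safetensors")
pipe.fuse_lora()

pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)

image = pipe("a photo of sks person", num_inference_steps=30).images[0]
```

Fusing first means the compiled UNet sees ordinary static weights rather than separate adapter modules, so torch.compile can trace the graph once instead of breaking or retracing on the LoRA layers.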

Closing it?

Awesome, working fine now.

Thanks