microsoft/LLaVA-Med

Can't we train and fine-tune the LLaVA-Med model?

liucheny opened this issue · 9 comments


@thedaffodil Can you explain a bit more?

Do you mean you cloned the original LLaVA repo and ran the training specified in the link you provided, but changed the weights to the LLaVA-Med weights?

@thedaffodil Do you mean modifying the weights and directly fine-tuning the model?

@thedaffodil Is the environment LLaVA-Med or LLaVA?

name: llava
channels:
  - xformers
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=2_gnu
  - asttokens=2.4.1=pyhd8ed1ab_0
  - blas=1.0=mkl
  - bzip2=1.0.8=h5eee18b_6
  - ca-certificates=2024.7.4=hbcca054_0
  - cffi=1.16.0=py310h5eee18b_1
  - comm=0.2.2=pyhd8ed1ab_0
  - cudatoolkit=11.3.1=h2bc3f7f_2
  - debugpy=1.6.7=py310h6a678d5_0
  - decorator=5.1.1=pyhd8ed1ab_0
  - exceptiongroup=1.2.2=pyhd8ed1ab_0
  - executing=2.0.1=pyhd8ed1ab_0
  - future=0.18.3=py310h06a4308_0
  - importlib-metadata=8.2.0=pyha770c72_0
  - importlib_metadata=8.2.0=hd8ed1ab_0
  - intel-openmp=2021.4.0=h06a4308_3561
  - ipykernel=6.29.5=pyh3099207_0
  - ipython=8.26.0=pyh707e725_0
  - jedi=0.19.1=pyhd8ed1ab_0
  - jupyter_client=8.6.2=pyhd8ed1ab_0
  - jupyter_core=5.7.2=py310hff52083_0
  - ld_impl_linux-64=2.38=h1181459_1
  - libffi=3.4.4=h6a678d5_1
  - libgcc-ng=14.1.0=h77fa898_0
  - libgomp=14.1.0=h77fa898_0
  - libprotobuf=3.20.3=he621ea3_0
  - libsodium=1.0.18=h36c2ea0_1
  - libstdcxx-ng=11.2.0=h1234567_1
  - libuuid=1.41.5=h5eee18b_0
  - matplotlib-inline=0.1.7=pyhd8ed1ab_0
  - mkl=2021.4.0=h06a4308_640
  - mkl-service=2.4.0=py310h7f8727e_0
  - mkl_fft=1.3.1=py310hd6ae3a3_0
  - mkl_random=1.2.2=py310h00e6091_0
  - ncurses=6.4=h6a678d5_0
  - nest-asyncio=1.6.0=pyhd8ed1ab_0
  - ninja-base=1.10.2=hd09550d_5
  - numpy-base=1.24.3=py310h8e6c178_0
  - openssl=3.3.1=h4bc722e_2
  - packaging=24.1=pyhd8ed1ab_0
  - parso=0.8.4=pyhd8ed1ab_0
  - pexpect=4.9.0=pyhd8ed1ab_0
  - pickleshare=0.7.5=py_1003
  - platformdirs=4.2.2=pyhd8ed1ab_0
  - prompt-toolkit=3.0.47=pyha770c72_0
  - psutil=6.0.0=py310hc51659f_0
  - ptyprocess=0.7.0=pyhd3deb0d_0
  - pure_eval=0.2.3=pyhd8ed1ab_0
  - pycparser=2.21=pyhd3eb1b0_0
  - pygments=2.18.0=pyhd8ed1ab_0
  - python=3.10.14=h955ad1f_1
  - python_abi=3.10=2_cp310
  - pyyaml=6.0.1=py310h5eee18b_0
  - pyzmq=25.1.2=py310h6a678d5_0
  - readline=8.2=h5eee18b_0
  - setuptools=69.5.1=py310h06a4308_0
  - six=1.16.0=pyhd3eb1b0_1
  - sqlite=3.45.3=h5eee18b_0
  - stack_data=0.6.2=pyhd8ed1ab_0
  - tk=8.6.14=h39e8969_0
  - tornado=6.4.1=py310hc51659f_0
  - traitlets=5.14.3=pyhd8ed1ab_0
  - typing_extensions=4.11.0=py310h06a4308_0
  - wcwidth=0.2.13=pyhd8ed1ab_0
  - wheel=0.43.0=py310h06a4308_0
  - xformers=0.0.22=py310_cu11.6.2_pyt1.12.1
  - xz=5.4.6=h5eee18b_1
  - yaml=0.2.5=h7b6447c_0
  - zeromq=4.3.5=h6a678d5_0
  - zipp=3.19.2=pyhd8ed1ab_0
  - zlib=1.2.13=h5eee18b_1
  - pip:
    - accelerate==0.21.0
    - aiofiles==23.2.1
    - altair==5.3.0
    - annotated-types==0.7.0
    - anyio==4.4.0
    - attrs==23.2.0
    - bitsandbytes==0.43.2
    - certifi==2024.7.4
    - charset-normalizer==3.3.2
    - click==8.1.7
    - contourpy==1.2.1
    - cycler==0.12.1
    - deepspeed==0.14.4
    - dnspython==2.6.1
    - docker-pycreds==0.4.0
    - einops==0.6.1
    - einops-exts==0.0.4
    - email-validator==2.2.0
    - fastapi==0.111.1
    - fastapi-cli==0.0.4
    - ffmpy==0.3.2
    - filelock==3.15.4
    - flash-attn==2.6.3
    - fonttools==4.53.1
    - fsspec==2024.6.1
    - gitdb==4.0.11
    - gitpython==3.1.43
    - gradio==4.16.0
    - gradio-client==0.8.1
    - h11==0.14.0
    - hjson==3.1.0
    - httpcore==0.17.3
    - httptools==0.6.1
    - httpx==0.24.0
    - huggingface-hub==0.24.2
    - idna==3.7
    - importlib-resources==6.4.0
    - jinja2==3.1.4
    - joblib==1.4.2
    - jsonschema==4.23.0
    - jsonschema-specifications==2023.12.1
    - kiwisolver==1.4.5
    - latex2mathml==3.77.0
    - llava==1.2.2.post1
    - markdown-it-py==3.0.0
    - markdown2==2.5.0
    - markupsafe==2.1.5
    - matplotlib==3.9.1
    - mdurl==0.1.2
    - mpmath==1.3.0
    - networkx==3.3
    - ninja==1.11.1.1
    - numpy==1.26.4
    - nvidia-cublas-cu11==11.10.3.66
    - nvidia-cublas-cu12==12.1.3.1
    - nvidia-cuda-cupti-cu12==12.1.105
    - nvidia-cuda-nvrtc-cu11==11.7.99
    - nvidia-cuda-nvrtc-cu12==12.1.105
    - nvidia-cuda-runtime-cu11==11.7.99
    - nvidia-cuda-runtime-cu12==12.1.105
    - nvidia-cudnn-cu11==8.5.0.96
    - nvidia-cudnn-cu12==8.9.2.26
    - nvidia-cufft-cu12==11.0.2.54
    - nvidia-curand-cu12==10.3.2.106
    - nvidia-cusolver-cu12==11.4.5.107
    - nvidia-cusparse-cu12==12.1.0.106
    - nvidia-ml-py==12.555.43
    - nvidia-nccl-cu12==2.18.1
    - nvidia-nvjitlink-cu12==12.5.82
    - nvidia-nvtx-cu12==12.1.105
    - orjson==3.10.6
    - pandas==2.2.2
    - peft==0.12.0
    - pillow==10.4.0
    - pip==24.1.2
    - protobuf==5.27.2
    - py-cpuinfo==9.0.0
    - pydantic==2.8.2
    - pydantic-core==2.20.1
    - pydub==0.25.1
    - pyparsing==3.1.2
    - python-dateutil==2.9.0.post0
    - python-dotenv==1.0.1
    - python-multipart==0.0.9
    - pytz==2024.1
    - referencing==0.35.1
    - regex==2024.7.24
    - requests==2.32.3
    - rich==13.7.1
    - rpds-py==0.19.1
    - ruff==0.5.5
    - safetensors==0.4.3
    - scikit-learn==1.2.2
    - scipy==1.14.0
    - semantic-version==2.10.0
    - sentencepiece==0.1.99
    - sentry-sdk==2.12.0
    - setproctitle==1.3.3
    - shellingham==1.5.4
    - shortuuid==1.0.13
    - smmap==5.0.1
    - sniffio==1.3.1
    - starlette==0.37.2
    - svgwrite==1.4.3
    - sympy==1.13.1
    - threadpoolctl==3.5.0
    - tiktoken==0.7.0
    - timm==0.6.13
    - tokenizers==0.15.1
    - tomlkit==0.12.0
    - toolz==0.12.1
    - torch==2.1.2
    - torchaudio==2.1.2+cu118
    - torchvision==0.16.2
    - tqdm==4.66.4
    - transformers==4.37.2
    - triton==2.1.0
    - typer==0.12.3
    - typing-extensions==4.12.2
    - tzdata==2024.1
    - urllib3==2.2.2
    - uvicorn==0.30.3
    - uvloop==0.19.0
    - wandb==0.17.5
    - watchfiles==0.22.0
    - wavedrom==2.0.3.post3
    - websockets==11.0.3
prefix: /home/user/miniconda3/envs/llava

My YAML file is shown above. I use the LLaVA repo with LLaVA-Med weights.
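(For anyone reproducing this setup: assuming the exported YAML above is saved as `llava.yaml` — the filename is an assumption — the environment can be recreated with conda.)

```bash
# Recreate and activate the environment from the exported YAML above.
# "llava.yaml" is a hypothetical filename; use whatever you saved it as.
conda env create -f llava.yaml
conda activate llava
```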

I am trying to fine-tune LLaVA, initialized with LLaVA-Med weights, on my own task.

So far I have tried running llava/train/train_mem.py with these parameters:

--deepspeed ./scripts/zero3.json
--model_name_or_path microsoft/llava-med-v1.5-mistral-7b
--data_path ./playground/data/llava_v1_5_mix665k.json
--image_folder ./playground/data
--vision_tower openai/clip-vit-large-patch14-336
--mm_vision_select_layer -2
--mm_use_im_start_end True
--mm_use_im_patch_token False
--image_aspect_ratio pad
--group_by_modality_length True
--bf16 False
--output_dir ./checkpoints/llava-v1.5-13b
--num_train_epochs 1
--per_device_train_batch_size 16
--per_device_eval_batch_size 4
--gradient_accumulation_steps 1
--evaluation_strategy "no"
--save_strategy "steps"
--save_steps 50000
--save_total_limit 1
--learning_rate 2e-5
--weight_decay 0.
--warmup_ratio 0.03
--lr_scheduler_type "cosine"
--logging_steps 1
--tf32 False
--model_max_length 2048
--gradient_checkpointing True
--dataloader_num_workers 4
--lazy_preprocess True
--report_to wandb

But I notice that the model is loaded with the LlavaLlama class instead of a Mistral one, and I can't figure out where to modify this.

Any ideas? And more generally, where can I find more info on how to fine-tune LLaVA-Med?
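(Aside: in the upstream LLaVA repo, the model class is chosen in llava/train/train.py by matching on the model name; the sketch below is a paraphrase, not verbatim code, and the `mistral` branch is a hypothetical addition that assumes your checkout ships `LlavaMistralForCausalLM`, as recent LLaVA versions and the LLaVA-Med v1.5 code do.)

```python
# Paraphrased sketch of the class-selection logic in llava/train/train.py;
# only the branching is shown, the real call passes more keyword arguments.
if model_args.vision_tower is not None:
    if 'mpt' in model_args.model_name_or_path:
        model = LlavaMptForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
        )
    elif 'mistral' in model_args.model_name_or_path.lower():
        # Hypothetical extra branch for Mistral-based checkpoints such as
        # llava-med-v1.5-mistral-7b; requires llava_mistral.py in the repo.
        model = LlavaMistralForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
        )
    else:
        # Default branch: any other name falls through to the LLaMA-based
        # class, which is why llava-med-v1.5-mistral-7b loads as LlavaLlama.
        model = LlavaLlamaForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
        )
```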

I use this command:
#!/bin/bash

deepspeed llava/train/train_mem.py \
    --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \
    --deepspeed ./scripts/zero3.json \
    --model_name_or_path ./llava-med-v1.5-mistral-7b \
    --version v1 \
    --data_path ./dataSlake/train.json \
    --image_folder ./dataSlake/imgs \
    --vision_tower openai/clip-vit-large-patch14-336 \
    --mm_projector_type mlp2x_gelu \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --image_aspect_ratio pad \
    --group_by_modality_length True \
    --bf16 True \
    --output_dir ./checkpoints/llava-version1 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 1 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 50000 \
    --save_total_limit 1 \
    --learning_rate 2e-4 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --dataloader_num_workers 2 \
    --lazy_preprocess True \
    --report_to wandb

After that, I merged the LoRA output with the base model to get the final weights, using the script linked below:
https://github.com/haotian-liu/LLaVA/blob/main/scripts/merge_lora_weights.py
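(For reference, a typical invocation of that merge script; the three paths below are placeholders, not necessarily the ones I used.)

```bash
# Merge the LoRA adapter produced by fine-tuning into the base model.
# All three paths are hypothetical examples.
python scripts/merge_lora_weights.py \
    --model-path ./checkpoints/finetune-slake-lora \
    --model-base ./llava-med-v1.5-mistral-7b \
    --save-model-path ./checkpoints/llava-med-slake-merged
```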

Then I could use the fine-tuned model for evaluation.

You can ask further questions via my email if you need help.

@thedaffodil I downloaded the model from https://huggingface.co/microsoft/llava-med-v1.5-mistral-7b/tree/main and fine-tuned it with your script. The run warns: "You are using a model of type llava_mistral to instantiate a model of type llava_llama. This is not supported for all configurations of models and can yield errors."
It also warns: "You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with model.to('cuda')."
And then it fails with torch.distributed.elastic.multiprocessing.errors.ChildFailedError:

llava/train/train_mem.py FAILED

Failures:
[1]:
  time      : 2024-08-20_07:47:25
  host      : 9c813e5131ac
  rank      : 1 (local_rank: 1)
  exitcode  : -7 (pid: 1044)
  error_file: <N/A>
  traceback : Signal 7 (SIGBUS) received by PID 1044
[2]:
  time      : 2024-08-20_07:47:25
  host      : 9c813e5131ac
  rank      : 2 (local_rank: 2)
  exitcode  : -7 (pid: 1045)
  error_file: <N/A>
  traceback : Signal 7 (SIGBUS) received by PID 1045
[3]:
  time      : 2024-08-20_07:47:25
  host      : 9c813e5131ac
  rank      : 3 (local_rank: 3)
  exitcode  : -7 (pid: 1046)
  error_file: <N/A>
  traceback : Signal 7 (SIGBUS) received by PID 1046

Root Cause (first observed failure):
[0]:
  time      : 2024-08-20_07:47:25
  host      : 9c813e5131ac
  rank      : 0 (local_rank: 0)
  exitcode  : -7 (pid: 1043)
  error_file: <N/A>
  traceback : Signal 7 (SIGBUS) received by PID 1043

While you are fine-tuning, your output model folder name should contain "finetune", not "llava".

While you are merging, your output folder name should contain "llava".
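(Illustrative, hypothetical folder names that follow this naming rule:)

```bash
# During LoRA fine-tuning: folder name contains "finetune", not "llava".
--output_dir ./checkpoints/finetune-slake-lora
# After merging: folder name contains "llava".
--save-model-path ./checkpoints/llava-med-slake-merged
```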


Hi, I'm excited to see your training success. Can you provide a sample of the training data JSON? I see that the script you are using trains with the v1 template (--version v1). Could you please tell me whether you are formatting your JSON file according to the format of https://hanoverprod.z21.web.core.windows.net/med_llava/instruct/llava_med_instruct_10k.json?
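(For reference, LLaVA's v1-style training JSON is a list of records like the sample below; the field values here are made up for illustration, not taken from the linked file.)

```json
[
  {
    "id": "example-0001",
    "image": "imgs/example.jpg",
    "conversations": [
      { "from": "human", "value": "<image>\nWhat modality is shown in this image?" },
      { "from": "gpt", "value": "This is a chest X-ray." }
    ]
  }
]
```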