gaopengcuhk/Tip-Adapter

Inference on new Image


Any pointers or a code snippet to run inference on a new image would be helpful

Thanks!

I wrote an inference script for it, hope it helps~~~

# %%
import clip
import torch
import torch.nn as nn
import yaml
from PIL import Image

from utils import *
from datasets import build_dataset

# %%
# 1. Load the pre-trained CLIP model
cfg = yaml.load(open("./configs/cathouse.yaml", 'r'), Loader=yaml.Loader)
clip_model, preprocess = clip.load(cfg['backbone'])
clip_model.eval()

# 2. Create an adapter and load the trained (Tip-Adapter-F) weights into it
adapter_weight_path = "./caches/cat_house/best_F_16shots.pt"
adapter_weight = torch.load(adapter_weight_path)
adapter = nn.Linear(adapter_weight.shape[1], adapter_weight.shape[0], bias=False).to(clip_model.dtype).cuda()
adapter.weight = nn.Parameter(adapter_weight)
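# Note: for Tip-Adapter-F, best_F_16shots.pt stores the fine-tuned cache keys, so
# adapter_weight should have shape [num_classes * shots, feature_dim]; adapter(feature)
# then returns one affinity score per cached few-shot training image.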

# %%
from pathlib import Path 

image_dir = "./images/cat_house"
classnames = [sub_dir.name for sub_dir in Path(image_dir).iterdir() if sub_dir.is_dir()]
classnames.sort()

from datasets.cathouse import template

# Build the zero-shot CLIP text classifier from the dataset's prompt templates
clip_weights = clip_classifier(classnames, template, clip_model)
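# clip_weights has shape [feature_dim, num_classes]: one prompt-ensembled text embedding per class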

# %%
val_features_path = "./caches/cat_house/val_f.pt"
val_labels_path   = "./caches/cat_house/val_l.pt"

val_features = torch.load(val_features_path)
val_labels   = torch.load(val_labels_path)

cache_keys_path   = "./caches/cat_house/keys_16shots.pt"
cache_values_path = "./caches/cat_house/values_16shots.pt"

cache_keys   = torch.load(cache_keys_path)
cache_values = torch.load(cache_values_path)
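# Expected shapes (as saved by the repo's build_cache_model / pre_load_features):
#   cache_keys   : [feature_dim, num_classes * shots]   (normalized training features)
#   cache_values : [num_classes * shots, num_classes]   (one-hot training labels)
#   val_features : [num_val, feature_dim],  val_labels : [num_val]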

# Hyper-parameter search for (beta, alpha) on the validation set
best_beta, best_alpha = search_hp(
    cfg, cache_keys, cache_values,
    val_features, val_labels, clip_weights, adapter=adapter)
print(best_beta, best_alpha)  # e.g. 8.7275 4.755; record these values for later inference

# %%
def extract_image_feature(img_path, preprocess):
    img_arr = preprocess(Image.open(img_path).convert('RGB')).unsqueeze(0)

    # Extract the CLIP image feature and L2-normalize it
    with torch.no_grad():
        image = img_arr.cuda()
        image_feature = clip_model.encode_image(image)
        image_feature /= image_feature.norm(dim=-1, keepdim=True)

    return image_feature

# %%
import glob 
from sklearn.metrics import classification_report

y_pred, y_true = [], []
for i, image_path in enumerate(glob.glob(image_dir+"/**/*.jpg")):

    label = classnames.index(Path(image_path).parents[0].name)

    image_feature = extract_image_feature(
        img_path=image_path,
        preprocess=preprocess
    )
    clip_logits = 100. * image_feature @ clip_weights
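
    # Cache-branch lookup (Tip-Adapter): the adapter weights are the cached training
    # features, so `affinity` scores this image against every cached sample, and
    #   cache_logits = exp(-beta * (1 - affinity)) @ one_hot_labels
    #   tip_logits   = clip_logits + alpha * cache_logits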

    affinity = adapter(image_feature)
    cache_logits = ((-1) * (best_beta - best_beta * affinity)).exp() @ cache_values  # cache_values holds the one-hot labels of the cached training samples
    tip_logits = clip_logits + cache_logits * best_alpha
    pred = torch.argmax(tip_logits).detach().cpu().numpy()

    y_pred.append(int(pred))
    y_true.append(label)

print(classification_report(y_true, y_pred))
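
# %%
# Single-image inference helper for the actual question in this issue: classify one new
# image that is not in the evaluation folders. This is just a sketch reusing the objects
# built above (clip_weights, adapter, cache_values, best_beta, best_alpha); the function
# name and the example path are mine, not from the repo.
def predict_image(img_path):
    image_feature = extract_image_feature(img_path=img_path, preprocess=preprocess)
    with torch.no_grad():
        clip_logits = 100. * image_feature @ clip_weights
        affinity = adapter(image_feature)
        cache_logits = ((-1) * (best_beta - best_beta * affinity)).exp() @ cache_values
        tip_logits = clip_logits + cache_logits * best_alpha
        probs = tip_logits.softmax(dim=-1).squeeze(0)
        idx = int(torch.argmax(probs).item())
    return classnames[idx], float(probs[idx])

# Example with a hypothetical path:
# print(predict_image("./images/new_cat.jpg"))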
