Inference on new Image
Opened this issue · 1 comment
sindhuatquadrant commented
Any pointers or a code snippet for running inference on a new image would be helpful.
Thanks!
Votess4All commented
I wrote an inference script for it; I hope it helps.
# %%
import clip
from utils import *
from datasets import build_dataset
import yaml
from PIL import Image
# %%
# 1. Load the pretrained CLIP model named by the config's 'backbone' entry.
# The config file is opened with a context manager so the handle is closed
# as soon as parsing finishes.
# NOTE(review): yaml.Loader can construct arbitrary Python objects; only use
# it on config files you trust.
with open("./configs/cathouse.yaml", 'r') as cfg_file:
    cfg = yaml.load(cfg_file, Loader=yaml.Loader)
clip_model, preprocess = clip.load(cfg['backbone'])
clip_model.eval()
# 2. Create an adapter and load the trained weights into it.
adapter_weight_path = "./caches/cat_house/best_F_16shots.pt"
adapter_weight = torch.load(adapter_weight_path)
# The linear layer's in/out sizes are taken from the saved weight's shape.
adapter = nn.Linear(adapter_weight.shape[1], adapter_weight.shape[0], bias=False).to(clip_model.dtype).cuda()
# nn.Module only accepts an nn.Parameter (or None) for `weight`; wrap a plain
# tensor so the assignment does not raise TypeError.
if not isinstance(adapter_weight, nn.Parameter):
    adapter_weight = nn.Parameter(adapter_weight)
adapter.weight = adapter_weight
# %%
from pathlib import Path
from datasets.cathouse import template

image_dir = "./images/cat_house"
# Class names are the names of the sub-directories of image_dir, sorted.
classnames = sorted(
    entry.name for entry in Path(image_dir).iterdir() if entry.is_dir()
)
# Build the classifier weights from the class names and the dataset's prompt
# template using the project's clip_classifier helper.
clip_weights = clip_classifier(classnames, template, clip_model)
# %%
# Load the cached validation features/labels and the 16-shot cache keys/values.
# The paths are relative to the repository root, matching the adapter weight
# path ("./caches/cat_house/..."), so the script is not tied to one user's
# home directory.
cache_dir = "./caches/cat_house"
val_features_path = f"{cache_dir}/val_f.pt"
val_values_path = f"{cache_dir}/val_l.pt"
val_features = torch.load(val_features_path)
val_labels = torch.load(val_values_path)
cache_keys_path = f"{cache_dir}/keys_16shots.pt"
cache_values_path = f"{cache_dir}/values_16shots.pt"
cache_keys = torch.load(cache_keys_path)
cache_values = torch.load(cache_values_path)
# Search the hyperparameters (beta, alpha) on the validation set.
best_beta, best_alpha = search_hp(
    cfg, cache_keys, cache_values,
    val_features, val_labels, clip_weights, adapter=adapter)
# Example output from the author's run: 8.7275 4.755. These values should be
# recorded so they can be reused for later inference.
print(best_beta, best_alpha)
# %%
def extract_image_feature(img_path, preprocess):
    """Return the normalised CLIP image feature for a single image file.

    Args:
        img_path: Path of the image file to open.
        preprocess: Callable applied to the RGB PIL image; its result must
            support ``.unsqueeze(0)`` (presumably the transform returned by
            ``clip.load`` -- confirm against the caller).

    Returns:
        The output of ``clip_model.encode_image`` for the one-image batch,
        divided by its norm along the last dimension.
    """
    # Use a context manager so the image file handle is closed right after
    # preprocessing instead of lingering until garbage collection.
    with Image.open(img_path) as img:
        img_arr = preprocess(img.convert('RGB')).unsqueeze(0)
    # Extract the image feature without tracking gradients.
    with torch.no_grad():
        image = img_arr.cuda()
        image_feature = clip_model.encode_image(image)
        image_feature /= image_feature.norm(dim=-1, keepdim=True)
    return image_feature
# %%
import glob
from sklearn.metrics import classification_report

y_pred, y_true = [], []
# Pure inference: disable autograd so the adapter forward pass and the logit
# arithmetic do not build computation graphs for every image.
with torch.no_grad():
    # Without recursive=True, "**" behaves like "*", so this matches .jpg
    # files exactly one directory level below image_dir (the class folders).
    for image_path in glob.glob(image_dir+"/**/*.jpg"):
        # The ground-truth label is the index of the parent folder's name.
        label = classnames.index(Path(image_path).parent.name)
        image_feature = extract_image_feature(
            img_path=image_path,
            preprocess=preprocess
        )
        clip_logits = 100. * image_feature @ clip_weights
        affinity = adapter(image_feature)
        # cache_values are also one-hot encoded.
        cache_logits = ((-1) * (best_beta - best_beta * affinity)).exp() @ cache_values
        tip_logits = clip_logits + cache_logits * best_alpha
        # argmax over the flattened logits; .item() yields a plain Python int.
        y_pred.append(int(torch.argmax(tip_logits).item()))
        y_true.append(label)
print(classification_report(y_true, y_pred))
# %%
# Print the classification report for y_true / y_pred again in its own cell.
print(classification_report(y_true, y_pred))
# %%
# Bare expression: presumably here to display the class-name list when the
# cell is run interactively; it has no effect when run as a plain script.
classnames
# %%