cydonia999/VGGFace2-pytorch

Feature Extraction on own pictures

kevinisbest opened this issue · 4 comments

Hi, thank you for posting this implementation.

I wonder whether it is possible to use the pre-trained model as a feature extractor and compare the distances between the features of my own pictures?

I used the following command:
python demo.py extract --arch_type resnet50_ft --weight_file ./weights/resnet50_ft_weight.pkl --test_img_list_file ./test_images/130.jpg --feature_dir ./output/ --meta_file identity_meta.csv

but got the error:
AssertionError: root: /path/to/dataset_directory not found.

It seems I need to download the VGGFace2 dataset, right?

I'd appreciate it if anyone could give me a hint, thanks!

I don't think it is required for feature extraction...

At first glance, the error appears because you did not pass a --dataset_dir argument. Its value becomes the variable root, which seems to be required by the dataset created at line 101 of demo.py:

    dv = datasets.VGG_Faces2(root, test_img_list_file, id_label_dict, split='valid',
        horizontal_flip=args.horizontal_flip)

After looking through the file, it seems root is not needed for anything except training and evaluation, so...

maybe try changing:

if args.cmd == 'train':
    dt = datasets.VGG_Faces2(root, train_img_list_file, id_label_dict, split='train')
    train_loader = torch.utils.data.DataLoader(dt, batch_size=args.batch_size, shuffle=True, **kwargs)

dv = datasets.VGG_Faces2(root, test_img_list_file, id_label_dict, split='valid',
    horizontal_flip=args.horizontal_flip)
val_loader = torch.utils.data.DataLoader(dv, batch_size=args.batch_size, shuffle=False, **kwargs)

to

if args.cmd == 'train':
    dt = datasets.VGG_Faces2(root, train_img_list_file, id_label_dict, split='train')
    train_loader = torch.utils.data.DataLoader(dt, batch_size=args.batch_size, shuffle=True, **kwargs)

    dv = datasets.VGG_Faces2(root, test_img_list_file, id_label_dict, split='valid',
        horizontal_flip=args.horizontal_flip)
    val_loader = torch.utils.data.DataLoader(dv, batch_size=args.batch_size, shuffle=False, **kwargs)

Note: untried, untested, left to someone else to verify.
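Separately, the command in the question passes a single image (./test_images/130.jpg) to --test_img_list_file. Judging by the flag name and by how demo.py hands it to datasets.VGG_Faces2 together with root, it most likely expects a text file listing image paths, one per line, resolved against --dataset_dir. A hedged sketch that writes such a list for a folder of your own pictures; the helper name write_image_list is hypothetical, and the exact per-line format demo.py needs (for example, whether each path must contain an identity folder) is an assumption to check against datasets/vgg_face2.py:

```python
import os


def write_image_list(image_dir, list_path, extensions=('.jpg', '.jpeg', '.png')):
    """Write every image under image_dir to list_path, one path per line.

    Paths are relative to image_dir and use forward slashes, on the
    assumption that the list is resolved against --dataset_dir.
    Returns the sorted list of paths that were written.
    """
    paths = []
    for dirpath, _, filenames in os.walk(image_dir):
        for name in filenames:
            # Case-insensitive extension check, so '.JPG' also matches
            if name.lower().endswith(extensions):
                rel = os.path.relpath(os.path.join(dirpath, name), image_dir)
                paths.append(rel.replace(os.sep, '/'))
    paths.sort()  # os.walk order is not guaranteed, so sort for reproducibility
    with open(list_path, 'w') as f:
        f.write(''.join(p + '\n' for p in paths))
    return paths
```

The resulting file would then be passed as --test_img_list_file, with the image folder itself passed as --dataset_dir.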

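import pickle

import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

# ResNet-50 with one output per VGGFace2 training identity (8631)
model = models.resnet50(num_classes=8631)
with open('resnet50_ft_weight.pkl', 'rb') as f:
    obj = f.read()
# The .pkl holds a dict of NumPy arrays keyed by parameter name
weights = {key: torch.from_numpy(arr) for key, arr in pickle.loads(obj, encoding='latin1').items()}
model.load_state_dict(weights)

# transforms.Scale was deprecated in favor of transforms.Resize
scaler = transforms.Resize((224, 224))
# ImageNet statistics; note that the repo's own data loader instead converts
# to BGR and subtracts a per-channel mean in the 0-255 range
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
to_tensor = transforms.ToTensor()

model.eval()

imagepath = 'Database/10_a.JPG'
image = Image.open(imagepath).convert('RGB')
# Variable is no longer needed in modern PyTorch; plain tensors work
imgblob = normalize(to_tensor(scaler(image))).unsqueeze(0)

# Drop the final fully connected layer so the network returns pooled features
tf_last_layer_chopped = nn.Sequential(*list(model.children())[:-1])
with torch.no_grad():
    output = tf_last_layer_chopped(imgblob)
print(np.shape(output))  # torch.Size([1, 2048, 1, 1])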

This worked for me :)
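If load_state_dict complains about missing or unexpected keys, it can help to look inside the .pkl first. A minimal sketch, assuming (as the snippet above does) that the file unpickles to a dict mapping parameter names to NumPy arrays; the helper name describe_weights is hypothetical:

```python
import pickle

import numpy as np


def describe_weights(path):
    """Return {parameter name: array shape} for a pickled dict of NumPy arrays.

    encoding='latin1' mirrors the snippet above; it is the usual setting
    for reading NumPy arrays that were pickled under Python 2.
    """
    with open(path, 'rb') as f:
        weights = pickle.load(f, encoding='latin1')
    return {name: np.asarray(arr).shape for name, arr in weights.items()}

# Usage: compare describe_weights('resnet50_ft_weight.pkl') against
# {k: tuple(v.shape) for k, v in model.state_dict().items()}
```

Comparing the two dicts key by key shows whether a mismatch comes from naming or from tensor shapes.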

This worked for me, too :) Thanks to VictorVarela for his code.

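import numpy as np
import torch
from PIL import Image
from torchvision.transforms import CenterCrop, Resize

# Assumes the script runs from the root of the VGGFace2-pytorch repo,
# where models/senet.py and utils.py provide these helpers
from models.senet import senet50
from utils import load_state_dict

# Per-channel mean in BGR order (0-255 range), as used by the repo's data loader
mean_bgr = np.array([91.4953, 103.8827, 131.0912])

def transform(img):
    img = img[:, :, ::-1]  # RGB -> BGR
    img = img.astype(np.float32)
    img -= mean_bgr
    img = img.transpose(2, 0, 1)  # C x H x W
    img = torch.from_numpy(img).float()
    return img

# include_top=False drops the classifier so the model returns pooled features
model = senet50(num_classes=8631, include_top=False)
load_state_dict(model, "models/senet50_ft_weight.pkl")
model.eval()

imagepath = 'samples_0.png'
image = Image.open(imagepath).convert('RGB')  # drop any alpha channel
image = Resize(256)(image)
image = CenterCrop(224)(image)
image = np.array(image, dtype=np.uint8)
image = transform(image).unsqueeze(0)

with torch.no_grad():
    output = model(image)
output = output.view(output.size(0), -1).numpy()
print(output.shape)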
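To get back to the original question of comparing distances: once each picture has been turned into a flat feature vector (such as the flattened output above), faces can be compared with cosine similarity, or with the Euclidean distance between L2-normalized features. A minimal NumPy sketch; the function names are hypothetical, and no particular "same person" threshold is implied, since that would have to be tuned on your own pictures:

```python
import numpy as np


def l2_normalize(features, eps=1e-10):
    """Scale each row of an (N, D) feature matrix to unit length."""
    features = np.asarray(features, dtype=np.float64)
    norms = np.linalg.norm(features, axis=1, keepdims=True)
    return features / np.maximum(norms, eps)


def cosine_similarity_matrix(features):
    """Pairwise cosine similarities between the rows of an (N, D) matrix."""
    unit = l2_normalize(features)
    return unit @ unit.T


def euclidean_distance_matrix(features):
    """Pairwise Euclidean distances between L2-normalized rows.

    For unit vectors d^2 = 2 - 2 * cos, so both measures rank pairs the same way.
    """
    cos = cosine_similarity_matrix(features)
    # Clip tiny negative values caused by floating-point rounding
    return np.sqrt(np.clip(2.0 - 2.0 * cos, 0.0, None))

# Usage: stack one flattened feature vector per image, e.g.
# feats = np.vstack([output_a, output_b])  # shape (2, 2048)
# print(cosine_similarity_matrix(feats)[0, 1])
```

Normalizing first means the distance depends only on the direction of each feature vector, not on its magnitude.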

Your code looks incomplete; some functions are missing.