yinboc/prototypical-network-pytorch

Single image prediction


Hi,
I am trying to create a demo version that takes a single image and predicts its class. I am a bit new to PyTorch coding and am facing a few issues.
I modified the code as below. I load the query image into data_query and the 7 different support images into data_shot.

import os

import torch
from PIL import Image
from torchvision import transforms

from convnet import Convnet          # model definition from this repo
from utils import euclidean_metric   # distance helper from this repo

# Same preprocessing for support and query images (84x84, ImageNet mean/std).
img_transform = transforms.Compose([
    transforms.Resize(84),
    transforms.CenterCrop(84),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

def load_support(support_dir="/home/abc/data/support/"):
    """Load every .jpg in support_dir as one batch of shape [N, 3, 84, 84]."""
    img_data = []
    # Sort so that support image i always corresponds to class label i.
    for file in sorted(os.listdir(support_dir)):
        if file.endswith(".jpg"):
            path = os.path.join(support_dir, file)
            img_data.append(img_transform(Image.open(path).convert('RGB')))
    # Stack the [3, 84, 84] tensors into a single [N, 3, 84, 84] batch on the GPU.
    return torch.stack(img_data).cuda()

if __name__ == '__main__':
    model = Convnet().cuda()
    model.load_state_dict(torch.load('./save/proto-17-5/max-acc.pth'))
    model.eval()

    data_shot = load_support()
    path = "/home/abc/data/val/test-15.jpg"
    data_query = img_transform(Image.open(path).convert('RGB'))  # shape [3, 84, 84]

    data_query = data_query.cuda()

    print("support", data_shot)
    print("query", data_query)

    x = model(data_shot)  # one embedding per support image
    print("shape", x.shape)

    logits = euclidean_metric(model(data_query), x)

    # Class labels 0..6, one per support image (7-way, 1-shot).
    label = torch.arange(7).repeat(1)
    label = label.type(torch.cuda.LongTensor)
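For context, the `logits` line in the code above scores the query embedding against each support embedding. Below is a minimal sketch of a negative squared Euclidean distance metric; `neg_sq_euclidean` is a hypothetical name written to illustrate the idea, not copied from the repo's `utils.py`. Both inputs are assumed to be 2-D: `a` is `[n_query, dim]` and `b` is `[n_classes, dim]`.

```python
import torch

def neg_sq_euclidean(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: logits[i, j] = -||a[i] - b[j]||^2.

    a: [n, d] query embeddings, b: [m, d] class prototypes.
    Returns an [n, m] tensor; a larger value means a closer prototype.
    """
    n, m = a.shape[0], b.shape[0]
    a = a.unsqueeze(1).expand(n, m, -1)   # [n, m, d]
    b = b.unsqueeze(0).expand(n, m, -1)   # [n, m, d]
    return -((a - b) ** 2).sum(dim=2)     # [n, m]

query = torch.tensor([[0.0, 0.0]])                            # one query, d = 2
protos = torch.tensor([[1.0, 0.0], [0.0, 2.0], [3.0, 4.0]])   # three prototypes
logits = neg_sq_euclidean(query, protos)
print(logits.tolist())  # [[-1.0, -4.0, -25.0]]
```

The query row is closest to the first prototype, so `logits.argmax(dim=1)` picks class 0.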

The logits = euclidean_metric(model(data_query), x) line throws this error:
Expected 4-dimensional input for 4-dimensional weight 64 3 3, but got 3-dimensional input of size [3, 84, 84] instead.
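The message is about the input's rank: a `Conv2d` weight is 4-D (`[out_channels, in_channels, kH, kW]`), and the layer expects a batch of images shaped `[N, C, H, W]`, while `ToTensor()` returns a single image as `[C, H, W]`. A minimal sketch, using a hypothetical stand-in conv block rather than the repo's `Convnet`, of adding the missing batch dimension with `unsqueeze(0)`:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for one conv block (3 input channels -> 64 filters);
# not the repo's Convnet, just enough to show the expected input shape.
encoder = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.MaxPool2d(2),
)
encoder.eval()

image = torch.randn(3, 84, 84)   # one image as returned by ToTensor(): [C, H, W]
batch = image.unsqueeze(0)       # add a batch dimension: [1, C, H, W]

with torch.no_grad():
    out = encoder(batch)

print(batch.shape)  # torch.Size([1, 3, 84, 84])
print(out.shape)    # torch.Size([1, 64, 42, 42])
```

In the script from the question, the equivalent change would be `data_query = data_query.unsqueeze(0).cuda()`, so that `model(data_query)` sees a batch of one image.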

What additional parameter am I missing? Please guide.
Thanks

I had a problem with the data_query tensor. After fixing it, the error was resolved! :)
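Putting the pieces together, one way to turn the logits into a prediction for a single query is sketched below. It assumes a 7-way task with `shot` support images per class, ordered class-major (all images of class 0, then class 1, and so on), and it uses a hypothetical tiny encoder and random tensors in place of the trained `Convnet` and real images. Each class prototype is the mean of its support embeddings, and the predicted class is the one whose prototype has the highest negative squared distance to the query (computed here with `torch.cdist`).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

way, shot = 7, 1  # 7 classes, 1 support image each (as in the question)

# Hypothetical stand-in encoder: maps [N, 3, 84, 84] -> [N, 64] embeddings.
encoder = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
)
encoder.eval()

# Random tensors stand in for preprocessed images.
# Support images are assumed class-major: class 0's shots first, then class 1's, ...
support = torch.randn(way * shot, 3, 84, 84)
query = torch.randn(3, 84, 84)  # a single image, [C, H, W]

with torch.no_grad():
    emb = encoder(support)                               # [way * shot, 64]
    prototypes = emb.reshape(way, shot, -1).mean(dim=1)  # [way, 64], one per class
    q = encoder(query.unsqueeze(0))                      # [1, 64], batch of one
    logits = -torch.cdist(q, prototypes) ** 2            # [1, way], negative squared distance
    pred = logits.argmax(dim=1).item()                   # index of the nearest prototype

print(logits.shape, pred)
```

With the real script, `pred` indexes into the support images in the order they were loaded, so the support loader should read files in a fixed order for the index to map to a known class.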