LeiJ-USTB/MAF-Stereo

你好,我看见你的名字上面有北京科技大学的缩写,所以我使用中文和你交流

Closed this issue · 3 comments

我在测试你的模型的时候,有一个问题,在模型MAFStereo模型中没有变量'act1'和'bn1s',图1和图2是问题,图三是代码中的地方!
1
2
3

我的代码,在这里

from __future__ import print_function, division
import time
import argparse
import os
import numpy as np
import cv2
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
from PIL import Image
from torchvision import transforms
from tqdm import trange
import math
from datasets import middlebury_loader as mb
from datasets import readpfm as rp
from models import __models__
from utils import *

# 图片转化
def read_img(filename):

    img = cv2.cvtColor(filename, cv2.COLOR_BGR2RGB)

    return img

# 设定初始值
model = "MAF_Stereo"
maxdisp = 192
loadckpt = "./log/sceneflow.ckpt"

# 加载模型
model = __models__[model](maxdisp)
model.cuda()

# 加载预训练模型
state_dict = torch.load(loadckpt)
model.load_state_dict(state_dict['model'])
model.eval()

# 主函数
def main():
 
    cap = cv2.VideoCapture(0)

    width:int = 640
    hight:int = 240
    width_2:int = int(width / 2)

    print("cap", cap)
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)     # 视频流中的帧宽
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, hight)    # 视频流中的帧高



    with torch.no_grad():
       
        try:
            while True:
                t0 = time.time()
                ret, frame = cap.read()

                cv2.imshow("frame", frame)
                
                left_img = frame[:, 0:width_2, :]
                right_img = frame[:, width_2:width, :]

                limg = read_img(left_img)
                rimg = read_img(right_img)
                w, h = limg.size
                wi, hi = (w // 32 + 1) * 32, (h // 32 + 1) * 32

                limg = limg.crop((w - wi, h - hi, w, h))
                rimg = rimg.crop((w - wi, h - hi, w, h))

                limg_tensor = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])(limg)
                rimg_tensor = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])(rimg)
                limg_tensor = limg_tensor.unsqueeze(0).cuda()
                rimg_tensor = rimg_tensor.unsqueeze(0).cuda()
                pred_disp = model(limg_tensor, rimg_tensor)[-1]

                pred_disp = pred_disp[:, hi - h:, wi - w:]

                pred_np = pred_disp.squeeze().cpu().numpy()

                print("pred_np:",pred_np)
                print(pred_np.shape)

                t1 = time.time()
                print("FPS:",math.floor(1/(t1-t0)))


        except KeyboardInterrupt:
            pass
    cap.release()
    cv2.destroyAllWindows()

if __name__ == '__main__':
    main()



您好,这部分的问题可能是由于您的timm版本不同导致的,请使用我首页提到的timm版本再次测试一下,后期timm版本中对预训练的mobilenetv2内部名称进行了调整。期待您的回复

好的,谢谢你,问题已经解决了

不客气