predict.py
CHSLAM opened this issue · 3 comments
CHSLAM commented
I have downloaded fpn_inception.h5.
Why does nothing happen when I run `python predict.py 0.png`?
ulucsahin commented
Because the code is not correct.
In predict.py, replace the `if __name__ == "__main__"`
code block with:
if __name__ == "__main__":
    # Placeholder: replace with the path of the image to process.
    image_path = "path-to-your-image.jpg"
    main(image_path)
fighting666777 commented
I have downloaded fpn_inception.h5. Why does nothing happen when I run `python predict.py 0.png`?
I have the same problem. Do you know how to solve it?
baselqt commented
Here is my fix: place all the .jpg files you want to process in the test_img folder,
replace the PATH\TO\TEST_IMG placeholder at the bottom of the script with the actual path, and then just run `python predict.py`.
import os
from glob import glob
from typing import Optional
import cv2
import numpy as np
import torch
import yaml
from tqdm import tqdm
from aug import get_normalize
from models.networks import get_generator
class Predictor:
    """Run a pretrained generator network on single images.

    The network is created by ``get_generator`` and its weights are read from
    the ``'model'`` entry of the checkpoint at ``weights_path``. Inference
    runs on CUDA when it is available and on the CPU otherwise.
    """

    def __init__(self, weights_path: str, model_name: str = ''):
        """Build the generator, load its weights and move it to the device.

        Args:
            weights_path: Checkpoint file readable by ``torch.load``; the
                state dict is expected under its ``'model'`` key.
            model_name: Value passed to ``get_generator``. When empty, the
                ``model`` entry of ``config/config.yaml`` is used instead.
        """
        with open('config/config.yaml', encoding='utf-8') as cfg:
            # NOTE(review): FullLoader can construct more than plain data;
            # only point this at a trusted, project-owned config file.
            config = yaml.load(cfg, Loader=yaml.FullLoader)
        # Pick the device once so the model and every input end up together.
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        model = get_generator(model_name or config['model'])
        # map_location lets a checkpoint saved on a GPU load on a CPU host.
        checkpoint = torch.load(weights_path, map_location=self.device)
        model.load_state_dict(checkpoint['model'])
        self.model = model.to(self.device)
        # NOTE(review): the model is left in training mode for inference;
        # presumably intentional -- confirm before switching to eval().
        self.model.train(True)
        self.normalize_fn = get_normalize()

    @staticmethod
    def _array_to_batch(x):
        """Turn an ``(H, W, C)`` array into a ``(1, C, H, W)`` tensor."""
        x = np.transpose(x, (2, 0, 1))
        x = np.expand_dims(x, 0)
        return torch.from_numpy(x)

    def _preprocess(self, x: np.ndarray, mask: Optional[np.ndarray]):
        """Normalize and zero-pad an image (and mask) for the network.

        Args:
            x: Image array with three dimensions ``(H, W, C)``.
            mask: Optional array with the same three-dimensional layout,
                binarized as ``round(mask / 255)``. ``None`` produces an
                all-ones mask shaped like the normalized image.

        Returns:
            ``((image_batch, mask_batch), h, w)`` where both batches are
            padded tensors and ``h``/``w`` are the sizes before padding, to
            be used for cropping the prediction.
        """
        x, _ = self.normalize_fn(x, x)
        if mask is None:
            mask = np.ones_like(x, dtype=np.float32)
        else:
            mask = np.round(mask.astype('float32') / 255)

        h, w, _ = x.shape
        block_size = 32
        # Rounds up to the next multiple of block_size that is strictly
        # greater than the size, so an already aligned side still gains one
        # full block of padding.
        min_height = (h // block_size + 1) * block_size
        min_width = (w // block_size + 1) * block_size

        pad_params = {
            'mode': 'constant',
            'constant_values': 0,
            # Pad only at the bottom and right; channels are untouched.
            'pad_width': ((0, min_height - h), (0, min_width - w), (0, 0)),
        }
        x = np.pad(x, **pad_params)
        mask = np.pad(mask, **pad_params)

        return (self._array_to_batch(x), self._array_to_batch(mask)), h, w

    @staticmethod
    def _postprocess(x: torch.Tensor) -> np.ndarray:
        """Convert a one-element ``(1, C, H, W)`` batch to a uint8 image.

        Values are mapped from ``[-1, 1]`` to ``[0, 255]``; anything outside
        that range is clipped instead of wrapping around in the uint8 cast.
        """
        x, = x  # the batch must hold exactly one image
        x = x.detach().cpu().float().numpy()
        x = (np.transpose(x, (1, 2, 0)) + 1) / 2.0 * 255.0
        return np.clip(x, 0, 255).astype('uint8')

    def __call__(self, img: np.ndarray, mask: Optional[np.ndarray], ignore_mask=True) -> np.ndarray:
        """Run the model on one image and return the uint8 result.

        Args:
            img: Input image array ``(H, W, C)``.
            mask: Optional mask, see ``_preprocess``.
            ignore_mask: When true (default) only the image is passed to the
                model; otherwise the mask is passed as a second argument.

        Returns:
            The prediction cropped back to the input's height and width.
        """
        (img, mask), h, w = self._preprocess(img, mask)
        with torch.no_grad():
            inputs = [img.to(self.device)]
            if not ignore_mask:
                # The mask must live on the same device as the image.
                inputs.append(mask.to(self.device))
            pred = self.model(*inputs)
        return self._postprocess(pred)[:h, :w, :]
def main(img_pattern: str,
         weights_path='fpn_inception.h5',
         out_dir='submit/',
         side_by_side: bool = False):
    """Process every image matching ``img_pattern`` and save the results.

    Args:
        img_pattern: Glob pattern selecting the input images. A path to an
            existing directory is also accepted and means "every file
            directly inside that directory".
        weights_path: Checkpoint handed to ``Predictor``.
        out_dir: Directory the results are written to (created if missing).
            Each result is saved under the base name of its input file.
        side_by_side: When true, the saved image is the input and the
            prediction stacked horizontally.
    """
    # A bare directory would glob to the directory itself, not its files.
    if os.path.isdir(img_pattern):
        img_pattern = os.path.join(img_pattern, '*')
    imgs = sorted(path for path in glob(img_pattern) if os.path.isfile(path))

    os.makedirs(out_dir, exist_ok=True)
    print(f"Total images to process: {len(imgs)}")
    if not imgs:
        # Nothing to do, so skip loading the model entirely.
        print(f"No files matched: {img_pattern}")
        return

    predictor = Predictor(weights_path=weights_path)

    for img_path in tqdm(imgs, total=len(imgs)):
        # Derive the name from the path itself so the two can never get
        # out of step when the pattern spans several directories.
        name = os.path.basename(img_path)
        print(f"Processing: {name}")
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread returns None for files it cannot decode.
            print(f"Skipping unreadable image: {img_path}")
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        pred = predictor(img, None)
        if side_by_side:
            pred = np.hstack((img, pred))
        pred = cv2.cvtColor(pred, cv2.COLOR_RGB2BGR)
        cv2.imwrite(os.path.join(out_dir, name), pred)
if __name__ == '__main__':
    # Command-line entry point, e.g. ``python predict.py "test_img/*.jpg"``.
    import argparse

    parser = argparse.ArgumentParser(
        description='Run the predictor on every image matching a pattern.')
    parser.add_argument(
        'img_pattern', nargs='?',
        # Raw string: same placeholder text as before, without relying on
        # the invalid "\T" escape sequence being passed through unchanged.
        default=r"PATH\TO\TEST_IMG",
        help='glob pattern of the images to process')
    parser.add_argument('--weights_path', default='fpn_inception.h5',
                        help='checkpoint file to load')
    parser.add_argument('--out_dir', default='submit/',
                        help='directory the results are written to')
    parser.add_argument('--side_by_side', action='store_true',
                        help='save input and prediction next to each other')
    args = parser.parse_args()

    main(img_pattern=args.img_pattern,
         weights_path=args.weights_path,
         out_dir=args.out_dir,
         side_by_side=args.side_by_side)