jeonggg119/DL_paper

[CV_Pose Estimation] Deep High-Resolution Representation Learning for Human Pose Estimation

jeonggg119 opened this issue


Basic

  • Trade-off : Global information vs High resolution (original size)
    • More global information (larger receptive field) -> lower resolution -> heavier reliance on up-sampling -> worse pixel-wise prediction
    • Need : learning both global and local features while preserving (or recovering) high resolution


1. Introduction

  • Most existing methods

    • Recover high-resolution representations from low-resolution ones
    • Low-resolution representations come from a high-to-low resolution network (subnetworks connected in series)
    • ex) Hourglass, SimpleBaseline, Dilated conv
  • High-Resolution Net (HR-net)

    • Maintains high-resolution representations through the whole process
    • First stage : a high-resolution subnetwork --> Next stages : gradually add high-to-low resolution subnetworks as parallel streams
    • Repeated multi-scale fusions across parallel multi-resolution subnetworks : high-resolution representations are boosted by low-resolution representations of the same depth and similar level
    • Result : rich high-resolution representations -> more accurate and spatially precise heatmaps
    • Dataset : COCO keypoint detection dataset, MPII Human Pose dataset, PoseTrack dataset

2. Related Work

  • Traditional solutions to single-person pose estimation : probabilistic graphical models, pictorial structure models
  • Current mainstream DNN methods : regressing keypoint positions & estimating keypoint heatmaps
    • Regressing (x, y) : ex) (2013) DeepPose : Human Pose Estimation via Deep Neural Networks
    • Estimating Heatmap [loc = (x, y)] : ex) (2015) Efficient Object Localization Using Convolutional Networks
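  • Most CNNs for keypoint heatmap estimation consist of
    • a stem subnetwork similar to a classification network, which decreases the resolution
    • a main body producing representations with the same resolution as its input, followed by a regressor estimating heatmaps
    • main body : mostly a high-to-low and low-to-high framework, possibly augmented with multi-scale fusion + intermediate supervision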


  • (a) Hourglass : symmetric low-to-high and high-to-low
  • (b) Cascaded pyramid networks (CPN)
  • (c) SimpleBaseline : Transposed conv for low-to-high
  • (d) Combination with Dilated conv

2.1. High-to-low and Low-to-high

  • Symmetric high-to-low and low-to-high
  • Heavy high-to-low (classification network = strided conv or pooling) and Light low-to-high (bilinear-upsampling or transposed conv)
  • Combination with dilated conv
  • Bad for small objects and detailed spatial information -> bad for pixel-wise prediction
    • Serialized network : extraction and learning of local and global features rely excessively on up-sampling

2.2. Multi-scale fusion

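  • Straightforward way : feeding multi-resolution images separately into multiple networks and aggregating the output response maps
  • (a) Hourglass : combines low-level features of the high-to-low process into same-resolution high-level features of the low-to-high process through skip-connections
  • (b) Cascaded pyramid network : globalnet (combines low-to-high level features of the high-to-low process into the low-to-high process) + refinenet (combines those features after further convs)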

2.3. Intermediate supervision

  • Helps train deep networks and improves heatmap estimation quality
  • ex) Hourglass, Convolutional Pose Machines : intermediate heatmaps are used as (part of) the input of the remaining subnetworks

HR-net


  • High-to-low subnetworks in Parallel + Fusing multi-scale representations
  • No intermediate supervision
  • Result : superior detection accuracy + efficient in computational complexity and # of params

3. Approach

  • Human pose estimation task : detecting the locations of K keypoints or parts from an image I of size W x H x 3
  • SOTA methods : estimating K heatmaps of size W' x H', {H_1, H_2, ..., H_K}, where H_k is the location confidence of the kth keypoint
  • HR-net : uses a CNN consisting of 3 parts
    • Two strided conv decreasing resolution
    • Main body outputting feature maps with same resolution as its input feature maps
    • Regressor estimating heatmaps where keypoint positions are chosen and transformed to full resolution
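A quick shape check for the stem: with the standard convolution output-size formula, two strided convs reduce the input to 1/4 resolution. This is a minimal sketch assuming each stem conv is 3x3 with stride 2 and padding 1 (kernel and padding values are assumptions, not stated above); `conv_out` and `stem_output_size` are hypothetical helpers.

```python
def conv_out(size: int, kernel: int = 3, stride: int = 2, padding: int = 1) -> int:
    """Output length of a convolution along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def stem_output_size(height: int, width: int, num_strided_convs: int = 2):
    """Apply the stem's strided convs (assumed 3x3, stride 2, padding 1) to an H x W input."""
    for _ in range(num_strided_convs):
        height, width = conv_out(height), conv_out(width)
    return height, width


if __name__ == "__main__":
    for h, w in [(256, 192), (384, 288)]:
        print((h, w), "->", stem_output_size(h, w))
```

Under these assumptions a 256 x 192 crop gives 64 x 48 feature maps (and heatmaps), and a 384 x 288 crop gives 96 x 72.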

3.1. Sequential multi-resolution subnetworks

  • Existing networks : connecting high-to-low resolution subnetworks in Series
  • Each subnetwork (one stage) is a sequence of convs; a down-sample layer between adjacent subnetworks halves the resolution
  • N_sr : subnetwork (s : s-th stage, r : resolution index) -> resolution : 1/2^(r-1) of first subnetwork
    • ex) High-to-low network : N_11 -> N_22 -> N_33 -> N_44

3.2. Parallel multi-resolution subnetworks

  • ex) 4 Parallel sub-networks
      N_11 -> N_21 -> N_31 -> N_41
            ↘ N_22 -> N_32 -> N_42
                    ↘ N_33 -> N_43
                            ↘ N_44
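The difference between the sequential (3.1) and parallel (3.2) designs can be made concrete by listing which N_sr exist at each stage. A small sketch with hypothetical helper names; it only enumerates subnetworks and their relative resolution 1/2^(r-1), no features are computed.

```python
def sequential_layout(num_stages: int):
    """High-to-low network in series: stage s uses only resolution index r = s."""
    return [f"N_{s}{s}" for s in range(1, num_stages + 1)]


def parallel_layout(num_stages: int):
    """Parallel layout: stage s keeps every stream r = 1..s."""
    return [[f"N_{s}{r}" for r in range(1, s + 1)] for s in range(1, num_stages + 1)]


def resolution_fraction(r: int) -> float:
    """Resolution of stream r relative to the first (high-resolution) stream: 1 / 2^(r-1)."""
    return 1.0 / 2 ** (r - 1)


if __name__ == "__main__":
    print(" -> ".join(sequential_layout(4)))
    for stage in parallel_layout(4):
        print(stage, [resolution_fraction(r) for r in range(1, len(stage) + 1)])
```

In the parallel layout the high-resolution stream N_s1 is present at every stage, which is the point of 3.2.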

3.3. Repeated multi-scale fusion


  • Exchange units (Fusion) across parallel subnetworks
  • Input : X = {X_1, X_2, ..., X_s}
  • Output : Y = {Y_1, Y_2, ..., Y_s}, whose sizes are the same as the inputs'
    • Each output is an aggregation of the input maps : Y_k = ∑ a(X_i, k), i = 1, ..., s
    • Exchange unit across stages : extra output map Y_(s+1) = a(Y_s, s+1)
  • Function a(X_i, k) : up-samples or down-samples X_i from resolution i to resolution k (identity if i = k)
    • Down-sampling (halve) : strided 3x3 conv (stride = 2, padding = 1); consecutive strided convs for larger factors (e.g. two for 4x)
    • Up-sampling : 1x1 conv to align the # of channels, followed by simple nearest-neighbor up-sampling
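A shape-level sketch of one exchange unit in NumPy. It assumes single-sample (C, H, W) arrays, channel counts that double per resolution index, and random, untrained weights, so it only illustrates the data flow of Y_k = ∑ a(X_i, k). For multi-step down-sampling it chains stride-2 3x3 convs and keeps the input channel count until the last step (an assumption about how channels change along the chain). The names `conv1x1`, `conv3x3_s2`, `a` and `exchange_unit` are hypothetical.

```python
import numpy as np

RNG = np.random.default_rng(0)  # random, untrained weights: shapes only


def conv1x1(x, c_out):
    """1x1 conv with random weights: (C_in, H, W) -> (c_out, H, W)."""
    w = RNG.standard_normal((c_out, x.shape[0])) * 0.1
    return np.einsum("oc,chw->ohw", w, x)


def conv3x3_s2(x, c_out):
    """3x3 conv, stride 2, zero padding 1, random weights: halves H and W."""
    c_in, h, wd = x.shape
    w = RNG.standard_normal((c_out, c_in, 3, 3)) * 0.1
    xp = np.pad(x, ((0, 0), (1, 1), (1, 1)))
    h_out, w_out = (h - 1) // 2 + 1, (wd - 1) // 2 + 1
    out = np.zeros((c_out, h_out, w_out))
    for i in range(h_out):
        for j in range(w_out):
            patch = xp[:, 2 * i:2 * i + 3, 2 * j:2 * j + 3]  # (C_in, 3, 3)
            out[:, i, j] = np.einsum("ocuv,cuv->o", w, patch)
    return out


def a(x, i, k, c_out):
    """Bring a map from resolution index i to index k with c_out channels."""
    if i == k:
        return x  # identity connection
    if i < k:  # down-sample: (k - i) consecutive stride-2 3x3 convs
        for step in range(k - i):
            x = conv3x3_s2(x, c_out if step == k - i - 1 else x.shape[0])
        return x
    factor = 2 ** (i - k)  # up-sample: 1x1 conv to align channels, then nearest neighbor
    return conv1x1(x, c_out).repeat(factor, axis=1).repeat(factor, axis=2)


def exchange_unit(xs, extra_output=False):
    """Y_k = sum_i a(X_i, k); optionally append Y_(s+1) = a(Y_s, s+1)."""
    s = len(xs)
    ys = [sum(a(xs[i - 1], i, k, xs[k - 1].shape[0]) for i in range(1, s + 1))
          for k in range(1, s + 1)]
    if extra_output:
        ys.append(a(ys[-1], s, s + 1, 2 * ys[-1].shape[0]))
    return ys


if __name__ == "__main__":
    xs = [np.ones((4, 16, 12)), np.ones((8, 8, 6)), np.ones((16, 4, 3))]
    for y in exchange_unit(xs, extra_output=True):
        print(y.shape)
```

Each output keeps the shape of the input at the same index, and the optional extra output adds the next lower resolution, as across-stage exchange units do.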

3.4. Heatmap estimation

  • Regressing heatmaps from the high-resolution representation output by the last exchange unit
  • Loss function : MSE between predicted and GT heatmaps
    • GT heatmaps : 2D Gaussian with a standard deviation of 1 pixel, centered on the GT location of each keypoint
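A sketch of the training target and loss, assuming an unnormalized Gaussian (peak value 1) on a heatmap grid indexed [y, x] and a keypoint location already expressed in heatmap coordinates. `gaussian_heatmap` and `heatmap_mse` are hypothetical names; any per-joint weighting or visibility masking used in practice is left out.

```python
import numpy as np


def gaussian_heatmap(height, width, center_x, center_y, sigma=1.0):
    """Target heatmap: unnormalized 2D Gaussian (peak 1) centered on the GT keypoint."""
    ys, xs = np.mgrid[0:height, 0:width]
    d2 = (xs - center_x) ** 2 + (ys - center_y) ** 2
    return np.exp(-d2 / (2.0 * sigma ** 2))


def heatmap_mse(pred, target):
    """Mean squared error between predicted and target heatmaps, e.g. shape (K, H, W)."""
    return float(np.mean((pred - target) ** 2))


if __name__ == "__main__":
    # one target per keypoint, stacked into (K, H, W); here K = 2 on a 64 x 48 grid
    targets = np.stack([gaussian_heatmap(64, 48, 10, 20), gaussian_heatmap(64, 48, 30, 5)])
    print(targets.shape, heatmap_mse(np.zeros_like(targets), targets))
```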

3.5. Network instantiation

  • Follows the ResNet design rule to distribute depth across stages and # of channels across resolutions
  • Main body : HR-net : 4 stages with 4 parallel subnetworks
    • Resolution is gradually decreased (halved) -> Width (# of channels) is increased (doubled) : C, 2C, 4C, 8C
    • 1st stage : 4 Residual units (same as ResNet-50 : bottleneck with width 64), followed by one 3x3 conv reducing the width of feature maps to C
    • 2nd, 3rd, 4th stages : 1, 4, 3 Exchange blocks -> 8 Exchange blocks in total (-> 8 multi-scale fusions)
      • one Exchange block = 4 Residual units (each unit contains two 3x3 convs in each resolution) + one Exchange unit across resolutions
  • Experiments : HRNet-W32 (small net), HRNet-W48 (big net)
    • 32 and 48 : widths (C) of the high-resolution subnetworks in the last 3 stages
    • HRNet-W32 : widths of the other 3 parallel subnetworks = 64, 128, 256
    • HRNet-W48 : widths of the other 3 parallel subnetworks = 96, 192, 384
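The width and block bookkeeping above, as a small config sketch. It assumes the C, 2C, 4C, 8C width rule for the four resolutions and counts one multi-scale fusion per exchange block; `hrnet_config` and `total_fusions` are hypothetical names.

```python
def hrnet_config(c: int):
    """Widths C, 2C, 4C, 8C for resolution indices 1..4 and exchange blocks per stage."""
    widths = [c * 2 ** (r - 1) for r in range(1, 5)]
    # stage -> (number of parallel streams, number of exchange blocks); stage 1 has none
    stages = {2: (2, 1), 3: (3, 4), 4: (4, 3)}
    return widths, stages


def total_fusions(stages) -> int:
    """Each exchange block ends with one exchange unit, i.e. one multi-scale fusion."""
    return sum(blocks for _, blocks in stages.values())


if __name__ == "__main__":
    for c in (32, 48):
        widths, stages = hrnet_config(c)
        print(f"HRNet-W{c}: widths={widths}, fusions={total_fusions(stages)}")
```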

4. Experiments

4.1. COCO Keypoint Detection

Dataset

  • COCO dataset : 200K imgs, 250K person instances labeled with 17 Keypoints
    • COCO train2017 dataset : 57K imgs + 150K person instances
    • COCO val2017 : 5K imgs
    • COCO test-dev2017 set : 20K imgs
    • [Annotation] 17 Keypoints : (x, y, v)
      • x, y : 2D image coordinates
      • v : visibility flag (0 : not labeled / 1 : labeled but not visible / 2 : labeled and visible)
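COCO stores each person's keypoints as one flat list of 17 x 3 = 51 numbers. A minimal sketch that regroups them into (x, y, v) triplets and counts labeled keypoints (v > 0); the function names are hypothetical.

```python
def parse_coco_keypoints(flat):
    """Split a flat COCO list [x1, y1, v1, x2, y2, v2, ...] into (x, y, v) triplets."""
    if len(flat) % 3:
        raise ValueError("keypoint list length must be a multiple of 3")
    return [tuple(flat[i:i + 3]) for i in range(0, len(flat), 3)]


def count_labeled(triplets):
    """Keypoints with v > 0 are labeled (v = 1: not visible, v = 2: visible)."""
    return sum(1 for _, _, v in triplets if v > 0)


if __name__ == "__main__":
    flat = [10, 20, 2, 0, 0, 0, 30, 40, 1] + [0, 0, 0] * 14
    kps = parse_coco_keypoints(flat)
    print(len(kps), kps[:3], count_labeled(kps))
```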

Evaluation metric

  • Similarity Metric : OKS (Object Keypoint Similarity)
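    $$\mathrm{OKS} = \frac{\sum_i \exp\left(-\frac{d_i^2}{2 s^2 k_i^2}\right) \delta(v_i > 0)}{\sum_i \delta(v_i > 0)}$$
    • d_i : Euclidean distance between detected keypoint and GT
    • v_i : visibility flag of GT
    • s : object scale (in COCO, s² is the object's area)
    • k_i : per-keypoint constant that controls falloff
    • OKS = 0 (worst) ~ 1 (best)
  • Evaluation Metric : AP (Average Precision) based on OKS
    • AP : mean of AP over OKS thresholds 0.50, 0.55, ..., 0.95
    • AP^50, AP^75 : AP at OKS = 0.50, 0.75
    • AP^M, AP^L : AP for medium, large objects
    • AR : average recall over the same OKS thresholds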
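OKS averages, over labeled keypoints (v_i > 0), a Gaussian of the distance d_i scaled by s and k_i. A minimal NumPy sketch, assuming `area` is s² (the object area, as in COCO) and k_i = 2σ_i with the per-keypoint σ values used by pycocotools. pycocotools treats ground truths without labeled keypoints differently; this sketch just returns 0 for them. `oks` is a hypothetical name.

```python
import numpy as np

# Per-keypoint sigmas of the COCO keypoint evaluation (pycocotools); k_i = 2 * sigma_i.
COCO_SIGMAS = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72,
                        .62, .62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0


def oks(pred, gt, vis, area, k=2 * COCO_SIGMAS):
    """pred, gt: (17, 2) keypoint arrays; vis: (17,) GT visibility flags; area: s^2."""
    d2 = np.sum((np.asarray(pred, dtype=float) - np.asarray(gt, dtype=float)) ** 2, axis=1)
    labeled = np.asarray(vis) > 0
    if not labeled.any():
        return 0.0
    sim = np.exp(-d2 / (2.0 * area * k ** 2))  # per-keypoint similarity in (0, 1]
    return float(sim[labeled].sum() / labeled.sum())


if __name__ == "__main__":
    gt = np.zeros((17, 2))
    vis = np.full(17, 2)
    print(oks(gt, gt, vis, area=1000.0), oks(gt + 5.0, gt, vis, area=1000.0))
```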

Training

  • Human detection box extended in height or width to a fixed aspect ratio (h : w = 4 : 3), then cropped and resized to a fixed size ... ex) 256 x 192 or 384 x 288
  • Data Augmentation : random rotation, random scale, flipping, half body data augmentation
  • Adam optimizer
  • lr schedule : 1e-3 (base) -> 1e-4 (from the 170th epoch) -> 1e-5 (from the 200th epoch), 210 epochs in total
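The step schedule as a function of the epoch, assuming 0-indexed epochs and that each new rate takes effect at the listed epoch. In PyTorch the same shape is usually expressed with `torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[170, 200], gamma=0.1)`; the plain-Python `step_lr` helper below is hypothetical and looks rates up from a table instead of multiplying by 0.1.

```python
import bisect


def step_lr(epoch, milestones=(170, 200), lrs=(1e-3, 1e-4, 1e-5)):
    """Learning rate for `epoch`: lrs[j] is used once epoch reaches milestones[j-1]."""
    return lrs[bisect.bisect_right(milestones, epoch)]


if __name__ == "__main__":
    # COCO schedule (210 epochs). The PoseTrack fine-tuning schedule would be
    # step_lr(e, milestones=(10, 15), lrs=(1e-4, 1e-5, 1e-6)) for 20 epochs.
    for epoch in (0, 169, 170, 199, 200, 209):
        print(epoch, step_lr(epoch))
```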

Testing

  • Top-down : detect person instances with a person detector --> predict keypoints inside each detected box
    • person detector : the same as in SimpleBaseline
  • Averaging heatmaps of the original and flipped images
  • Predicted keypoint location : highest heatvalue location, shifted by a quarter offset in the direction from the highest response to the second-highest response
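A sketch of the two test-time steps above on NumPy heatmaps of shape (K, H, W): un-flipping the heatmaps predicted on a horizontally flipped image (mirroring and swapping left/right joint channels) so they can be averaged with the original, and decoding each keypoint as the argmax shifted 0.25 pixel toward the larger neighboring response along each axis. The helper names and the COCO left/right pairs (indices as in `COCO_KEYPOINT_INDEXES` in the Code section) are assumptions of this sketch; the reference implementation may add alignment details not shown here, and mapping back to image coordinates is omitted.

```python
import numpy as np

# COCO left/right keypoint pairs (eyes, ears, shoulders, elbows, wrists, hips, knees, ankles)
FLIP_PAIRS = [(1, 2), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12), (13, 14), (15, 16)]


def unflip_heatmaps(flipped):
    """Map heatmaps predicted on a horizontally flipped image back to the original frame."""
    hm = flipped[:, :, ::-1].copy()  # undo the horizontal flip
    for a, b in FLIP_PAIRS:
        hm[[a, b]] = hm[[b, a]]  # a "left" channel on the flipped image is the right joint
    return hm


def decode(heatmaps):
    """Argmax per keypoint, shifted 0.25 px toward the higher neighboring response."""
    num_kpts, h, w = heatmaps.shape
    coords = np.zeros((num_kpts, 2))
    for j in range(num_kpts):
        y, x = np.unravel_index(np.argmax(heatmaps[j]), (h, w))
        coords[j] = (x, y)
        if 1 <= x < w - 1 and 1 <= y < h - 1:  # neighbors exist on both sides
            dx = heatmaps[j, y, x + 1] - heatmaps[j, y, x - 1]
            dy = heatmaps[j, y + 1, x] - heatmaps[j, y - 1, x]
            coords[j] += 0.25 * np.sign([dx, dy])
    return coords
```

Averaging would then be `0.5 * (heatmaps + unflip_heatmaps(flipped_heatmaps))` before calling `decode`.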

Results on validation set


  • [Red] AP : HRNet = 73.4 > Others
  • [Red] #Params, GFLOPs : HRNet > CPN model
  • [Red] #Params, GFLOPs : HRNet < SimpleBaseline model
  • [Blue] Pre-training on ImageNet classification is better : AP 1.0 point ↑
  • [Green] Larger width (HRNet-W48) -> AP ↑ : 0.7, 0.5 ↑
  • [Orange] Input size ↑ (384 x 288) -> AP ↑ : 1.4, 1.2 ↑

Results on test-dev set


  • HR-net (top-down) is better than bottom-up methods
  • HRNet-W32 : 74.9 AP > other top-down methods
    • More efficient in model size (#Params) and computational complexity (GFLOPs)
  • HRNet-W48 : highest 75.5 AP > SimpleBaseline
  • +) Additional data from AI Challenger for training : best 77.0 AP

4.2. MPII Human Pose Estimation

Dataset

  • MPII Human Pose dataset (real-world / full-body pose) : 25K imgs with 40K subjects
    • 12K subjects for testing, the remaining subjects for training

Training

  • Same as COCO, except that the input size is 256 x 256

Testing

  • Same as COCO, except that the provided person boxes are used (instead of detected person boxes)
  • six-scale pyramid testing procedure

Evaluation metric

  • PCKh (head-normalized probability of correct keypoint) score -> PCKh@0.5 (α=0.5)
    • Joint is correct if it falls within α * ℓ pixels of GT position
    • α : constant
    • ℓ : head size that corresponds to 60% of diagonal length of GT head bbox
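PCKh@α as a small NumPy function, assuming arrays of shape (N, J, 2) for predictions and ground truth, one head size ℓ per person, a per-joint visibility mask, and "within" meaning distance ≤ α·ℓ. The function names are hypothetical.

```python
import numpy as np


def head_size_from_box(x1, y1, x2, y2):
    """ℓ = 60% of the diagonal of the GT head bounding box."""
    return float(0.6 * np.hypot(x2 - x1, y2 - y1))


def pckh(pred, gt, head_size, visible, alpha=0.5):
    """Fraction of visible joints whose error is within alpha * head_size pixels."""
    pred, gt = np.asarray(pred, dtype=float), np.asarray(gt, dtype=float)
    dist = np.linalg.norm(pred - gt, axis=-1)                        # (N, J)
    thresh = alpha * np.asarray(head_size, dtype=float)[:, None]     # (N, 1)
    visible = np.asarray(visible, dtype=bool)
    correct = (dist <= thresh) & visible
    return float(correct.sum() / visible.sum())


if __name__ == "__main__":
    pred = [[[3.0, 4.0], [30.0, 40.0]]]   # errors of 5 px and 50 px
    gt = np.zeros((1, 2, 2))
    print(pckh(pred, gt, head_size=[20.0], visible=[[True, True]]))
```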

Results on test set


  • HRNet-W32 : smaller model size (#Params = 28.5M) and lower computational complexity (GFLOPs = 9.5), with 92.3 PCKh@0.5
  • HRNet-W48 : same result 92.3 PCKh@0.5

4.3. Application to Pose Tracking

Dataset

  • PoseTrack (benchmark for human pose estimation and articulated tracking in video, built on raw videos from the MPII Human Pose dataset) : 550 video sequences with 66,374 frames
    • video seq are split into 292(train) + 50(val) + 208(test)
      • train : length ranges between 41~151 frames / 30 frames from the center of the video are densely annotated
      • val/test : 65~298 frames / 30 frames around keyframe are densely annotated + afterwards every fourth frame is annotated

Evaluation metric

  • [1] Frame-wise Multi-person Pose Estimation : mAP (mean Average Precision)
  • [2] Multi-person Pose Tracking : MOTA (multi-object tracking accuracy)
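MOTA follows the CLEAR-MOT definition, 1 - (misses + false positives + identity switches) / number of ground-truth objects, with all counts summed over frames. A minimal sketch with a hypothetical `mota` helper; to my understanding the PoseTrack toolkit applies this at the keypoint level and averages over joints, which is not modelled here.

```python
def mota(num_misses, num_false_positives, num_id_switches, num_gt):
    """CLEAR-MOT accuracy: 1 - (FN + FP + IDSW) / GT, counts summed over all frames."""
    if num_gt == 0:
        raise ValueError("MOTA is undefined without ground-truth objects")
    return 1.0 - (num_misses + num_false_positives + num_id_switches) / num_gt


if __name__ == "__main__":
    print(mota(10, 5, 5, 100))  # can become negative when errors exceed the GT count
```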

Training

  • network : HRNet-W48 (pre-trained on COCO), fine-tuned for single-person pose estimation on the PoseTrack2017 training set
  • Input : person box extracted from the annotated keypoints in training frames by extending the bbox of all keypoints (of one person) by 15% in length
  • Training setup, data augmentation : almost the same as COCO, except the lr schedule : 1e-4 -> 1e-5 (10th epoch) -> 1e-6 (15th epoch), 20 epochs in total
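A sketch of the person-box extraction from annotated keypoints. The text only says the box of all keypoints is extended by 15%; this sketch assumes the 15% is added to the total width and height, keeping the box centered, and that only labeled keypoints are passed in. `box_from_keypoints` is a hypothetical name.

```python
def box_from_keypoints(points, extend=0.15):
    """Tight box around (x, y) keypoints, grown by `extend` of its width/height, kept centered."""
    xs = [x for x, _ in points]
    ys = [y for _, y in points]
    x1, y1, x2, y2 = min(xs), min(ys), max(xs), max(ys)
    cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
    w, h = (x2 - x1) * (1 + extend), (y2 - y1) * (1 + extend)
    return cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2


if __name__ == "__main__":
    print(box_from_keypoints([(0, 0), (100, 200), (50, 80)]))
```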

Testing

  • 1) Person box Detection and Propagation
    • Same detector in SimpleBaseline
    • Propagating boxes into nearby frames by propagating predicted keypoints according to optical flow + NMS for removing redundant boxes
  • 2) Human Pose Estimation on each box
  • 3) Pose association across nearby frames
    • OKS (Object Keypoint Similarity) between instances in one frame and instances propagated from nearby frames by optical flow
    • Greedy matching algorithm on these similarities to compute correspondences between keypoints in nearby frames
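A sketch of greedy association: given a similarity matrix (e.g. OKS between people in the current frame and people propagated from a nearby frame), repeatedly take the highest remaining pair so each person is matched at most once. The `threshold` cutoff and the helper name `greedy_match` are assumptions; unmatched people would typically start new tracks.

```python
def greedy_match(similarity, threshold=0.0):
    """Greedily pair rows (current-frame people) with columns (propagated people).

    Takes the highest remaining similarity first, so each row and column is used
    at most once; pairs with similarity <= threshold are left unmatched.
    """
    candidates = sorted(
        ((s, i, j) for i, row in enumerate(similarity) for j, s in enumerate(row)),
        reverse=True,
    )
    used_rows, used_cols, matches = set(), set(), []
    for s, i, j in candidates:
        if s <= threshold:
            break  # remaining candidates are even less similar
        if i in used_rows or j in used_cols:
            continue
        matches.append((i, j))
        used_rows.add(i)
        used_cols.add(j)
    return matches


if __name__ == "__main__":
    print(greedy_match([[0.9, 0.8], [0.85, 0.1]]))
```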

Results on PoseTrack2017 test set


  • HRNet-W48 : 74.9 mAP score, 57.9 MOTA score

4.4. Ablation Study

Repeated multi-scale fusion

  • (a) Without intermediate exchange units (1 fusion : only the last exchange unit)
  • (b) With only Across-stage Exchange (3 fusions)
  • (c) With both Across-stage and Within-stage Exchange (8 fusions) = HR-Net
  • All networks are trained from scratch
  • Result on COCO val set : More fusions lead to better performance (AP : c>b>a)

Resolution maintenance

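  • HRNet-W32 : 73.4 AP > Variant (all 4 high-to-low resolution subnetworks added at the beginning, with the same depth) : 72.5 AP
  • Reason : low-level features extracted from early stages over low-resolution subnetworks are less helpful
  • A simple high-resolution network (similar #Params and GFLOPs) without low-resolution parallel subnetworks shows much lower performance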

Representation resolution

  • (1) Heatmaps regressed from higher-resolution representations -> AP ↑ (better keypoint heatmap quality)

  • (2) Input size

    • HRNet's AP gain over SimpleBaseline is larger for the smaller input size (128 x 96) than for the larger one (256 x 192)
    • Input size ↑ -> AP ↑
    • Intuition : Maintaining high resolution is important!

5. Conclusion and Future Works

  • Maintaining high resolution through the whole process, without the need to recover it
  • Fusing multi-resolution representations repeatedly
  • Result : reliable high-resolution representations

Code

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import csv
import os
import shutil

from PIL import Image
import torch
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision
import cv2
import numpy as np
import time


import _init_paths
import models
from config import cfg
from config import update_config
from core.function import get_final_preds
from utils.transforms import get_affine_transform

COCO_KEYPOINT_INDEXES = {
    0: 'nose',
    1: 'left_eye',
    2: 'right_eye',
    3: 'left_ear',
    4: 'right_ear',
    5: 'left_shoulder',
    6: 'right_shoulder',
    7: 'left_elbow',
    8: 'right_elbow',
    9: 'left_wrist',
    10: 'right_wrist',
    11: 'left_hip',
    12: 'right_hip',
    13: 'left_knee',
    14: 'right_knee',
    15: 'left_ankle',
    16: 'right_ankle'
}

COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]

SKELETON = [
    [1,3],[1,0],[2,4],[2,0],[0,5],[0,6],[5,7],[7,9],[6,8],[8,10],[5,11],[6,12],[11,12],[11,13],[13,15],[12,14],[14,16]
]

CocoColors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0],
              [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
              [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]

NUM_KPTS = 17

CTX = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def draw_pose(keypoints,img):
    """draw the keypoints and the skeletons.
    :params keypoints: the shape should be equal to [17,2]
    :params img:
    """
    assert keypoints.shape == (NUM_KPTS,2)
    for i in range(len(SKELETON)):
        kpt_a, kpt_b = SKELETON[i][0], SKELETON[i][1]
        x_a, y_a = keypoints[kpt_a][0],keypoints[kpt_a][1]
        x_b, y_b = keypoints[kpt_b][0],keypoints[kpt_b][1] 
        cv2.circle(img, (int(x_a), int(y_a)), 6, CocoColors[i], -1)
        cv2.circle(img, (int(x_b), int(y_b)), 6, CocoColors[i], -1)
        cv2.line(img, (int(x_a), int(y_a)), (int(x_b), int(y_b)), CocoColors[i], 2)

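def draw_bbox(box,img):
    """draw the detected bounding box on the image.
    :param box: [(x1, y1), (x2, y2)], top-left and bottom-right corners
    :param img: BGR image, modified in place
    """
    # cv2.rectangle expects integer pixel coordinates
    pt1 = (int(box[0][0]), int(box[0][1]))
    pt2 = (int(box[1][0]), int(box[1][1]))
    cv2.rectangle(img, pt1, pt2, color=(0, 255, 0), thickness=3)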


def get_person_detection_boxes(model, img, threshold=0.5):
    """Run the detector and return [(x1, y1), (x2, y2)] boxes of persons scoring above threshold."""
    with torch.no_grad():
        pred = model(img)
    pred_classes = [COCO_INSTANCE_CATEGORY_NAMES[i]
                    for i in list(pred[0]['labels'].cpu().numpy())]  # class names
    pred_boxes = [[(i[0], i[1]), (i[2], i[3])]
                  for i in list(pred[0]['boxes'].detach().cpu().numpy())]  # bounding boxes
    pred_score = list(pred[0]['scores'].detach().cpu().numpy())

    # keep every detection above the threshold that is classified as a person
    # (filtering directly avoids relying on the scores being sorted)
    person_boxes = [box for box, cls, score in zip(pred_boxes, pred_classes, pred_score)
                    if score > threshold and cls == 'person']
    return person_boxes


def get_pose_estimation_prediction(pose_model, image, center, scale):
    rotation = 0

    # pose estimation transformation
    trans = get_affine_transform(center, scale, rotation, cfg.MODEL.IMAGE_SIZE)
    model_input = cv2.warpAffine(
        image,
        trans,
        (int(cfg.MODEL.IMAGE_SIZE[0]), int(cfg.MODEL.IMAGE_SIZE[1])),
        flags=cv2.INTER_LINEAR)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # pose estimation inference
    model_input = transform(model_input).unsqueeze(0).to(CTX)
    # switch to evaluate mode
    pose_model.eval()
    with torch.no_grad():
        # compute output heatmap
        output = pose_model(model_input)
        preds, _ = get_final_preds(
            cfg,
            output.clone().cpu().numpy(),
            np.asarray([center]),
            np.asarray([scale]))

        return preds


def box_to_center_scale(box, model_image_width, model_image_height):
    """convert a box to center,scale information required for pose transformation
    Parameters
    ----------
    box : list of tuple
        list of length 2 with two tuples of floats representing
        bottom left and top right corner of a box
    model_image_width : int
    model_image_height : int

    Returns
    -------
    (numpy array, numpy array)
        Two numpy arrays, coordinates for the center of the box and the scale of the box
    """
    center = np.zeros((2), dtype=np.float32)

    bottom_left_corner = box[0]
    top_right_corner = box[1]
    box_width = top_right_corner[0]-bottom_left_corner[0]
    box_height = top_right_corner[1]-bottom_left_corner[1]
    bottom_left_x = bottom_left_corner[0]
    bottom_left_y = bottom_left_corner[1]
    center[0] = bottom_left_x + box_width * 0.5
    center[1] = bottom_left_y + box_height * 0.5

    aspect_ratio = model_image_width * 1.0 / model_image_height
    pixel_std = 200  # scale is expressed in units of pixel_std (200 px)

    # grow the shorter side so the box matches the model's input aspect ratio
    if box_width > aspect_ratio * box_height:
        box_height = box_width * 1.0 / aspect_ratio
    elif box_width < aspect_ratio * box_height:
        box_width = box_height * aspect_ratio
    scale = np.array(
        [box_width * 1.0 / pixel_std, box_height * 1.0 / pixel_std],
        dtype=np.float32)
    # enlarge the box by 25% to keep some context around the person
    if center[0] != -1:
        scale = scale * 1.25

    return center, scale

def parse_args():
    parser = argparse.ArgumentParser(description='Pose estimation demo (image, video or webcam)')
    # general
    parser.add_argument('--cfg', type=str, default='demo/inference-config.yaml')
    parser.add_argument('--video', type=str)
    parser.add_argument('--webcam',action='store_true')
    parser.add_argument('--image',type=str)
    parser.add_argument('--write',action='store_true')
    parser.add_argument('--showFps',action='store_true')

    parser.add_argument('opts',
                        help='Modify config options using the command-line',
                        default=None,
                        nargs=argparse.REMAINDER)

    args = parser.parse_args()

    # args expected by supporting codebase  
    args.modelDir = ''
    args.logDir = ''
    args.dataDir = ''
    args.prevModelDir = ''
    return args


def main():
    args = parse_args()
    update_config(cfg, args)

    # cudnn related setting (applied after update_config so the YAML values take effect)
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    box_model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    box_model.to(CTX)
    box_model.eval()

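    # resolve models.<cfg.MODEL.NAME>.get_pose_net without eval()
    pose_model = getattr(models, cfg.MODEL.NAME).get_pose_net(
        cfg, is_train=False
    )

    if cfg.TEST.MODEL_FILE:
        print('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
        # map_location lets GPU-saved weights load on a CPU-only machine
        pose_model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE, map_location=CTX), strict=False)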
    else:
        print('expected model defined in config at TEST.MODEL_FILE')

    pose_model = torch.nn.DataParallel(pose_model, device_ids=cfg.GPUS)
    pose_model.to(CTX)
    pose_model.eval()

    # Loading an video or an image or webcam 
    if args.webcam:
        vidcap = cv2.VideoCapture(0)
    elif args.video:
        vidcap = cv2.VideoCapture(args.video)
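    elif args.image:
        image_bgr = cv2.imread(args.image)
        if image_bgr is None:  # cv2.imread returns None instead of raising
            print('cannot load the image {}.'.format(args.image))
            return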
    else:
        print('please use --video or --webcam or --image to define the input.')
        return 

    if args.webcam or args.video:
        if args.write:
            save_path = 'output.avi'
            fourcc = cv2.VideoWriter_fourcc(*'XVID')
            out = cv2.VideoWriter(save_path,fourcc, 24.0, (int(vidcap.get(3)),int(vidcap.get(4))))
        while True:
            ret, image_bgr = vidcap.read()
            if ret:
                last_time = time.time()
                image = image_bgr[:, :, [2, 1, 0]]

                input = []
                img = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
                img_tensor = torch.from_numpy(img/255.).permute(2,0,1).float().to(CTX)
                input.append(img_tensor)

                # object detection box
                pred_boxes = get_person_detection_boxes(box_model, input, threshold=0.9)

                # pose estimation
                if len(pred_boxes) >= 1:
                    for box in pred_boxes:
                        center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1])
                        image_pose = image.copy() if cfg.DATASET.COLOR_RGB else image_bgr.copy()
                        pose_preds = get_pose_estimation_prediction(pose_model, image_pose, center, scale)
                        if len(pose_preds)>=1:
                            for kpt in pose_preds:
                                draw_pose(kpt,image_bgr) # draw the poses

                if args.showFps:
                    fps = 1/(time.time()-last_time)
                    img = cv2.putText(image_bgr, 'fps: '+ "%.2f"%(fps), (25, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 255, 0), 2)

                if args.write:
                    out.write(image_bgr)

                cv2.imshow('demo',image_bgr)
                if cv2.waitKey(1) & 0XFF==ord('q'):
                    break
            else:
                print('cannot load the video.')
                break

        cv2.destroyAllWindows()
        vidcap.release()
        if args.write:
            print('video has been saved as {}'.format(save_path))
            out.release()

    else:
        # estimate on the image
        last_time = time.time()
        image = image_bgr[:, :, [2, 1, 0]]

        input = []
        img = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
        img_tensor = torch.from_numpy(img/255.).permute(2,0,1).float().to(CTX)
        input.append(img_tensor)

        # object detection box
        pred_boxes = get_person_detection_boxes(box_model, input, threshold=0.9)

        # pose estimation
        if len(pred_boxes) >= 1:
            for box in pred_boxes:
                center, scale = box_to_center_scale(box, cfg.MODEL.IMAGE_SIZE[0], cfg.MODEL.IMAGE_SIZE[1])
                image_pose = image.copy() if cfg.DATASET.COLOR_RGB else image_bgr.copy()
                pose_preds = get_pose_estimation_prediction(pose_model, image_pose, center, scale)
                if len(pose_preds)>=1:
                    for kpt in pose_preds:
                        draw_pose(kpt,image_bgr) # draw the poses
        
        if args.showFps:
            fps = 1/(time.time()-last_time)
            img = cv2.putText(image_bgr, 'fps: '+ "%.2f"%(fps), (25, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 255, 0), 2)
        
        if args.write:
            save_path = 'output.jpg'
            cv2.imwrite(save_path,image_bgr)
            print('the result image has been saved as {}'.format(save_path))

        cv2.imshow('demo',image_bgr)
        cv2.waitKey(0)  # wait for any key, then close the window
        cv2.destroyAllWindows()
        
if __name__ == '__main__':
    main()