Qidian213/deep_sort_yolov3

I want to limit the GPU memory fraction, but it doesn't work: the GPU is always fully occupied

zfs1993 opened this issue · 0 comments

I added the following code to set the GPU memory fraction.
The YOLO part (yolo.py):
class YOLO(object):
    """YOLOv3 detector wrapper (only the constructor is shown in this snippet)."""

    def __init__(self, gpu_memory_fraction=0.3):
        """Load class names/anchors and bind a memory-capped TF session to Keras.

        Args:
            gpu_memory_fraction: upper bound on the fraction of GPU memory the
                process may allocate; written to
                ``gpu_options.per_process_gpu_memory_fraction``.
        """
        self.model_path = 'model_data/yolo.h5'
        # Alternative lighter model: 'model_data/yolo_tiny.h5'
        self.anchors_path = 'model_data/yolo_anchors.txt'
        self.classes_path = 'model_data/coco_classes.txt'
        self.score = 0.5
        self.iou = 0.5
        self.class_names = self._get_class()
        self.anchors = self._get_anchors()

        # NOTE(review): in TF 1.x the GPU memory options are generally honored
        # only for the first session that initializes the GPU in this process;
        # later sessions (e.g. a feature-encoder session) reuse that allocator.
        # Build this before any other session/Keras model touches the GPU, and
        # keep the other sessions' gpu_options consistent with these.
        config = tf.ConfigProto()
        config.gpu_options.per_process_gpu_memory_fraction = gpu_memory_fraction
        # Allocate on demand instead of reserving the whole cap up front.
        config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=config)
        # Make Keras build and run the model inside this capped session.
        set_session(self.sess)

        self.model_image_size = (416, 416)  # fixed size or (None, None)
        self.is_fixed_size = self.model_image_size != (None, None)
        self.boxes, self.scores, self.classes = self.generate()

The feature-extractor part (tools/generate_detections.py):
class ImageEncoder(object):
    """Run a frozen TF graph that maps image patches to feature vectors."""

    def __init__(self, checkpoint_filename, input_name="images",
                 output_name="features", gpu_memory_fraction=0.3):
        """Load the frozen graph and open a memory-capped session for it.

        Args:
            checkpoint_filename: path to a serialized ``GraphDef`` file.
            input_name: name of the input tensor inside the imported graph
                (looked up as ``"net/<input_name>:0"``).
            output_name: name of the output tensor inside the imported graph
                (looked up as ``"net/<output_name>:0"``).
            gpu_memory_fraction: cap written to
                ``gpu_options.per_process_gpu_memory_fraction``; pass ``None``
                to leave the session uncapped (growth only).
        """
        config = tf.ConfigProto()
        # Without a cap, this session can claim the whole GPU if it is the
        # first one to initialize the device in the process. The default
        # matches the 0.3 fraction used for the YOLO session in yolo.py.
        if gpu_memory_fraction is not None:
            config.gpu_options.per_process_gpu_memory_fraction = gpu_memory_fraction
        config.gpu_options.allow_growth = True
        self.session = tf.Session(config=config)

        with tf.gfile.GFile(checkpoint_filename, "rb") as file_handle:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(file_handle.read())
        # Imported into the default graph, which this session was created on.
        tf.import_graph_def(graph_def, name="net")
        graph = tf.get_default_graph()
        self.input_var = graph.get_tensor_by_name("net/%s:0" % input_name)
        self.output_var = graph.get_tensor_by_name("net/%s:0" % output_name)

        # Output must be rank 2 (batch, features); input must be rank 4.
        assert len(self.output_var.get_shape()) == 2
        assert len(self.input_var.get_shape()) == 4
        self.feature_dim = self.output_var.get_shape().as_list()[-1]
        self.image_shape = self.input_var.get_shape().as_list()[1:]

    def __call__(self, data_x, batch_size=32):
        """Encode ``data_x`` and return a float32 array (len(data_x), feature_dim).

        Args:
            data_x: batch of image patches fed to the input tensor; presumably
                shaped ``(N,) + self.image_shape`` -- TODO confirm with caller.
            batch_size: number of rows passed to each ``session.run`` call.
        """
        out = np.zeros((len(data_x), self.feature_dim), np.float32)
        # _run_in_batches is defined elsewhere in this module; presumably it
        # slices the feed dict and fills ``out`` in place, batch by batch.
        _run_in_batches(
            lambda x: self.session.run(self.output_var, feed_dict=x),
            {self.input_var: data_x}, out, batch_size)
        return out

But it still doesn't work. Has anyone else run into the same problem?