From c05b2f710b416dcf3491a13c91b21afdc5686d78 Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Sat, 29 Sep 2018 16:46:21 +0200
Subject: [PATCH 1/3] * Attempt to prevent race condition when the camera has
 no image yet.

Related: https://github.com/JdeRobot/dl-objectdetector/issues/38
---
 Camera/local_camera.py | 2 +-
 GUI/gui.py             | 7 ++++++-
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/Camera/local_camera.py b/Camera/local_camera.py
index 80d5a8d..37c2b23 100644
--- a/Camera/local_camera.py
+++ b/Camera/local_camera.py
@@ -27,7 +27,7 @@ def __init__ (self, device_idx):
 
         self.im_width = self.cam.get(3)
         self.im_height = self.cam.get(4)
-
+        self.update()
 
     def getImage(self):
         ''' Gets the image from the webcam and returns it. '''
diff --git a/GUI/gui.py b/GUI/gui.py
index 32830ec..772bd11 100644
--- a/GUI/gui.py
+++ b/GUI/gui.py
@@ -118,7 +118,12 @@ def setNetwork(self, network, t_network):
     def update(self):
         ''' Updates the GUI for every time the thread change '''
         # We get the original image and display it.
-        self.im_prev = self.cam.getImage()
+        try:
+            self.im_prev = self.cam.getImage()
+        except:
+            print("no image yet")
+            return
+
         im = QtGui.QImage(self.im_prev.data, self.im_prev.shape[1], self.im_prev.shape[0],
                           QtGui.QImage.Format_RGB888)
         self.im_scaled = im.scaled(self.im_label.size())

From 9f71dc4843bc18d1af2f75e623814809218bb342 Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Sat, 29 Sep 2018 17:16:12 +0200
Subject: [PATCH 2/3] * Proof-of-concept Darknet backend. Not intended for
 general use.

Please keep in mind the following important known issues:

- Detections are still extremely slow. The main reason is that the image
  supplied to Darknet is created by iterating through the Python numpy array.
- Labels are assumed to be COCO.
- Labels are read from data files copied from the Darknet distribution, and
  then supplied to the darknet library. However, the GUI still uses the same
  data files used throughout the project. We should reuse those files instead
  of keeping separate copies.
- Some paths and values are hardcoded.

To test YOLOv3-tiny, please follow these steps:

0. Download and compile Darknet to obtain libdarknet.so.
1. Copy yolov3-tiny.cfg from the Darknet distribution into Net/Darknet.
2. Download yolov3-tiny.weights and place it into Net/Darknet.
3. Update your DYLD_LIBRARY_PATH to include the directory where libdarknet.so
   resides.
4. Run the object detector as usual, using the yml configuration file in this
   revision.

Other models are possible. The corresponding darknet weights and configuration
files must be placed in the Net/Darknet directory. They must have the same
name (except for the extension), which must in turn match the Model name
defined in the YAML configuration file.

I plan to work on speed next. As mentioned above, libdarknet requires an image
structure that is manually built from a Python numpy array. I tried to supply
a pointer to the underlying Python array instead, but I didn't know how to
make it work using ctypes. I'm concerned that the byte ordering will differ
anyway. It would be trivial to create a public helper function in libdarknet,
but I don't want to depend on patched versions of Darknet.

Credits

- All data files have been taken from the Darknet distribution.
- The darknet.py file was copied from the Darknet distribution, but it
  includes modifications to trigger detection from a numpy image instead of
  a file.

Part of https://github.com/JdeRobot/dl-objectdetector/issues/38
---
 Net/Darknet/__init__.py        |   1 +
 Net/Darknet/coco.data          |   4 +
 Net/Darknet/darknet/darknet.py | 203 +++++++++++++++++++++++++++++++++
 Net/Darknet/data/coco.names    |  80 +++++++++++++
 Net/Darknet/network.py         |  83 ++++++++++++++
 objectdetector.py              |   6 +-
 objectdetector.yml             |   8 +-
 od_darknet.sh                  |   5 +
 8 files changed, 387 insertions(+), 3 deletions(-)
 create mode 100644 Net/Darknet/__init__.py
 create mode 100644 Net/Darknet/coco.data
 create mode 100644 Net/Darknet/darknet/darknet.py
 create mode 100644 Net/Darknet/data/coco.names
 create mode 100644 Net/Darknet/network.py
 create mode 100755 od_darknet.sh

diff --git a/Net/Darknet/__init__.py b/Net/Darknet/__init__.py
new file mode 100644
index 0000000..8b13789
--- /dev/null
+++ b/Net/Darknet/__init__.py
@@ -0,0 +1 @@
+
diff --git a/Net/Darknet/coco.data b/Net/Darknet/coco.data
new file mode 100644
index 0000000..1f9268b
--- /dev/null
+++ b/Net/Darknet/coco.data
@@ -0,0 +1,4 @@
+classes = 80
+names = Net/Darknet/data/coco.names
+eval = coco
+
diff --git a/Net/Darknet/darknet/darknet.py b/Net/Darknet/darknet/darknet.py
new file mode 100644
index 0000000..1a0eaad
--- /dev/null
+++ b/Net/Darknet/darknet/darknet.py
@@ -0,0 +1,203 @@
+from ctypes import *
+import math
+import random
+import numpy as np
+
+def sample(probs):
+    s = sum(probs)
+    probs = [a/s for a in probs]
+    r = random.uniform(0, 1)
+    for i in range(len(probs)):
+        r = r - probs[i]
+        if r <= 0:
+            return i
+    return len(probs)-1
+
+def c_array(ctype, values):
+    arr = (ctype*len(values))()
+    arr[:] = values
+    return arr
+
+class BOX(Structure):
+    _fields_ = [("x", c_float),
+                ("y", c_float),
+                ("w", c_float),
+                ("h", c_float)]
+
+class DETECTION(Structure):
+    _fields_ = [("bbox", BOX),
+                ("classes", c_int),
+                ("prob", POINTER(c_float)),
+                ("mask", POINTER(c_float)),
+                ("objectness", c_float),
+                ("sort_class", c_int)]
+
+
+class IMAGE(Structure):
+    _fields_ = [("w", c_int),
+                ("h", c_int),
+                ("c", c_int),
+                ("data", POINTER(c_float))]
+
+class METADATA(Structure):
+    _fields_ = [("classes", c_int),
+                ("names", POINTER(c_char_p))]
+
+
+# Use DYLD_LIBRARY_PATH to locate the library
+lib = CDLL("libdarknet.so", RTLD_GLOBAL)
+lib.network_width.argtypes = [c_void_p]
+lib.network_width.restype = c_int
+lib.network_height.argtypes = [c_void_p]
+lib.network_height.restype = c_int
+
+predict = lib.network_predict
+predict.argtypes = [c_void_p, POINTER(c_float)]
+predict.restype = POINTER(c_float)
+
+set_gpu = lib.cuda_set_device
+set_gpu.argtypes = [c_int]
+
+make_image = lib.make_image
+make_image.argtypes = [c_int, c_int, c_int]
+make_image.restype = IMAGE
+
+get_network_boxes = lib.get_network_boxes
+get_network_boxes.argtypes = [c_void_p, c_int, c_int, c_float, c_float, POINTER(c_int), c_int, POINTER(c_int)]
+get_network_boxes.restype = POINTER(DETECTION)
+
+make_network_boxes = lib.make_network_boxes
+make_network_boxes.argtypes = [c_void_p]
+make_network_boxes.restype = POINTER(DETECTION)
+
+free_detections = lib.free_detections
+free_detections.argtypes = [POINTER(DETECTION), c_int]
+
+free_ptrs = lib.free_ptrs
+free_ptrs.argtypes = [POINTER(c_void_p), c_int]
+
+network_predict = lib.network_predict
+network_predict.argtypes = [c_void_p, POINTER(c_float)]
+
+reset_rnn = lib.reset_rnn
+reset_rnn.argtypes = [c_void_p]
+
+load_net = lib.load_network
+load_net.argtypes = [c_char_p, c_char_p, c_int]
+load_net.restype = c_void_p
+
+do_nms_obj = lib.do_nms_obj
+do_nms_obj.argtypes = [POINTER(DETECTION), c_int, c_int, c_float]
+
+do_nms_sort = lib.do_nms_sort
+do_nms_sort.argtypes = [POINTER(DETECTION), c_int, c_int, c_float]
+
+free_image = lib.free_image
+free_image.argtypes = [IMAGE]
+
+letterbox_image = lib.letterbox_image
+letterbox_image.argtypes = [IMAGE, c_int, c_int]
+letterbox_image.restype = IMAGE
+
+load_meta = lib.get_metadata
+lib.get_metadata.argtypes = [c_char_p]
+lib.get_metadata.restype = METADATA
+
+load_image = lib.load_image_color
+load_image.argtypes = [c_char_p, c_int, c_int]
+load_image.restype = IMAGE
+
+rgbgr_image = lib.rgbgr_image
+rgbgr_image.argtypes = [IMAGE]
+
+predict_image = lib.network_predict_image
+predict_image.argtypes = [c_void_p, IMAGE]
+predict_image.restype = POINTER(c_float)
+
+# set_pixel: symbol not found
+#set_pixel = lib.set_pixel
+#set_pixel.argtypes = [IMAGE, c_int, c_int, c_int, c_float]
+
+def classify(net, meta, im):
+    out = predict_image(net, im)
+    res = []
+    for i in range(meta.classes):
+        res.append((meta.names[i], out[i]))
+    res = sorted(res, key=lambda x: -x[1])
+    return res
+
+# Detect from file
+def detect(net, meta, image, thresh=.5, hier_thresh=.5, nms=.45):
+    im = load_image(image, 0, 0)
+    num = c_int(0)
+    pnum = pointer(num)
+    predict_image(net, im)
+    dets = get_network_boxes(net, im.w, im.h, thresh, hier_thresh, None, 0, pnum)
+    num = pnum[0]
+    if (nms): do_nms_obj(dets, num, meta.classes, nms);
+
+    res = []
+    for j in range(num):
+        for i in range(meta.classes):
+            if dets[j].prob[i] > 0:
+                b = dets[j].bbox
+                res.append((meta.names[i], dets[j].prob[i], (b.x, b.y, b.w, b.h)))
+    res = sorted(res, key=lambda x: -x[1])
+    free_image(im)
+    free_detections(dets, num)
+    return res
+
+# Detect from numpy image
+def detect_from_image(net, meta, image, thresh=.5, hier_thresh=.5, nms=.45):
+    # Make an empty image and copy the values
+    # I tried to pass a pointer to the floats but didn't know how
+    # If we really need to do this it should be in C
+    h = image.shape[0]
+    w = image.shape[1]
+    c = image.shape[2]
+    im = make_image(w, h, c)
+    for x in np.arange(w):
+        for y in np.arange(h):
+            for l in np.arange(c):
+                pixel = float(image[y, x, l]) / 255.
+                #set_pixel(im, x, y, l, pixel)
+                im.data[ l * im.h * im.w + y * im.w + x ] = pixel
+
+    # im = IMAGE()
+    # im.h = image.shape[0]
+    # im.w = image.shape[1]
+    # im.c = image.shape[2]
+    # f_image = image.astype(float) / 255.
+    # #im.data = LP_c_float(f_image.ctypes.data)
+    # #im.data = byref(f_image.ctypes)
+    # LP_c_float = POINTER(c_float)
+    # im.data = byref(c_float(f_image[0, 0, 0]))
+
+    num = c_int(0)
+    pnum = pointer(num)
+    predict_image(net, im)
+    dets = get_network_boxes(net, im.w, im.h, thresh, hier_thresh, None, 0, pnum)
+    num = pnum[0]
+    if (nms): do_nms_obj(dets, num, meta.classes, nms);
+
+    res = []
+    for j in range(num):
+        for i in range(meta.classes):
+            if dets[j].prob[i] > 0:
+                b = dets[j].bbox
+                res.append((meta.names[i], dets[j].prob[i], (b.x, b.y, b.w, b.h)))
+    res = sorted(res, key=lambda x: -x[1])
+    free_image(im)
+    free_detections(dets, num)
+    return res
+
+if __name__ == "__main__":
+    #net = load_net("cfg/densenet201.cfg", "/home/pjreddie/trained/densenet201.weights", 0)
+    #im = load_image("data/wolf.jpg", 0, 0)
+    #meta = load_meta("cfg/imagenet1k.data")
+    #r = classify(net, meta, im)
+    #print r[:10]
+    net = load_net("cfg/tiny-yolo.cfg", "tiny-yolo.weights", 0)
+    meta = load_meta("cfg/coco.data")
+    r = detect(net, meta, "data/dog.jpg")
+    print r
diff --git a/Net/Darknet/data/coco.names b/Net/Darknet/data/coco.names
new file mode 100644
index 0000000..ca76c80
--- /dev/null
+++ b/Net/Darknet/data/coco.names
@@ -0,0 +1,80 @@
+person
+bicycle
+car
+motorbike
+aeroplane
+bus
+train
+truck
+boat
+traffic light
+fire hydrant
+stop sign
+parking meter
+bench
+bird
+cat
+dog
+horse
+sheep
+cow
+elephant
+bear
+zebra
+giraffe
+backpack
+umbrella
+handbag
+tie
+suitcase
+frisbee
+skis
+snowboard
+sports ball
+kite
+baseball bat
+baseball glove
+skateboard
+surfboard
+tennis racket
+bottle
+wine glass
+cup
+fork
+knife
+spoon
+bowl
+banana
+apple
+sandwich
+orange
+broccoli
+carrot
+hot dog
+pizza
+donut
+cake
+chair
+sofa
+pottedplant
+bed
+diningtable
+toilet
+tvmonitor
+laptop
+mouse
+remote
+keyboard
+cell phone
+microwave
+oven
+toaster
+sink
+refrigerator
+book
+clock
+vase
+scissors
+teddy bear
+hair drier
+toothbrush
diff --git a/Net/Darknet/network.py b/Net/Darknet/network.py
new file mode 100644
index 0000000..9ef500e
--- /dev/null
+++ b/Net/Darknet/network.py
@@ -0,0 +1,83 @@
+import darknet
+import numpy as np
+from PIL import Image
+
+from Net.utils import label_map_util
+
+LABELS_DICT = {'voc': 'Net/labels/pascal_label_map.pbtxt',
+               'coco': 'Net/labels/mscoco_label_map.pbtxt',
+               'kitti': 'Net/labels/kitti_label_map.txt',
+               'oid': 'Net/labels/oid_bboc_trainable_label_map.pbtxt',
+               'pet': 'Net/labels/pet_label_map.pbtxt'}
+
+class DetectionNetwork():
+    def __init__(self, net_model):
+        self.framework = "Darknet"
+
+        # Parse the dataset to get which labels to yield
+        # TODO: we should hand them over to Darknet to avoid duplication
+        # TODO: (or read them from the Darknet data file)
+        labels_file = LABELS_DICT[net_model['Dataset'].lower()]
+        label_map = label_map_util.load_labelmap(labels_file) # loads the labels map.
+        categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=100000)
+        category_index = label_map_util.create_category_index(categories)
+        self.classes = {}
+        # We build it as a dict because of gaps in the label definitions
+        for cat in category_index:
+            self.classes[cat] = str(category_index[cat]['name'])
+
+        WEIGHTS_FILE = 'Net/Darknet/' + net_model['Model'] + '.weights'
+        CONFIG_FILE = 'Net/Darknet/' + net_model['Model'] + '.cfg'
+        LABELS_FILE = 'Net/Darknet/coco.data' # Hardcoded for now!
+
+        self.model = darknet.load_net(CONFIG_FILE, WEIGHTS_FILE, 0)
+        self.meta = darknet.load_meta(LABELS_FILE)
+
+        # Output preallocation
+        self.predictions = np.asarray([])
+        self.boxes = np.asarray([])
+        self.scores = np.asarray([])
+
+        print("Network ready!")
+
+
+    def setCamera(self, cam):
+        self.cam = cam
+
+        self.original_height = cam.im_height
+        self.original_width = cam.im_width
+
+        # Factors to rescale the output bounding boxes
+        # self.height_factor = np.true_divide(self.original_height, self.img_height)
+        # self.width_factor = np.true_divide(self.original_width, self.img_width)
+
+        # No scaling, for now
+        self.height_factor = 1
+        self.width_factor = 1
+
+    def predict(self):
+        input_image = self.cam.getImage()
+
+        predictions = darknet.detect_from_image(self.model, self.meta, input_image)
+        print(predictions)
+
+        self.predictions = []
+        self.scores = []
+        self.boxes = []
+
+        # iterate over predictions
+        for prediction in predictions:
+            self.predictions.append(prediction[0])
+            self.scores.append(prediction[1])
+
+            # No scaling for now
+            box = prediction[2]
+            box_x = box[0]
+            box_y = box[1]
+            box_w = box[2]
+            box_h = box[3]
+            xmin = int((box_x - box_w / 2) * self.width_factor)
+            ymin = int((box_y - box_h / 2) * self.height_factor)
+            xmax = xmin + int(box_w * self.width_factor)
+            ymax = ymin + int(box_h * self.height_factor)
+            self.boxes.append([xmin, ymin, xmax, ymax])
diff --git a/objectdetector.py b/objectdetector.py
index bd6549d..351c0cd 100644
--- a/objectdetector.py
+++ b/objectdetector.py
@@ -68,13 +68,17 @@ def selectNetwork(cfg):
     """
     net_prop = cfg['ObjectDetector']['Network']
    framework = net_prop['Framework']
+    # TODO: import network from the filesystem instead of using a hardcoded set
     if framework.lower() == 'tensorflow':
         from Net.TensorFlow.network import DetectionNetwork
     elif framework.lower() == 'keras':
         sys.path.append('Net/Keras')
         from Net.Keras.network import DetectionNetwork
+    elif framework.lower() == 'darknet':
+        sys.path.append('Net/Darknet/darknet')
+        from Net.Darknet.network import DetectionNetwork
     else:
-        raise SystemExit(('%s not supported! Supported frameworks: Keras, TensorFlow') % (framework))
+        raise SystemExit(('%s not supported! Supported frameworks: Keras, TensorFlow, Darknet') % (framework))
     return net_prop, DetectionNetwork
 
 def readConfig():
diff --git a/objectdetector.yml b/objectdetector.yml
index 1b8793a..b93a1b0 100755
--- a/objectdetector.yml
+++ b/objectdetector.yml
@@ -15,8 +15,12 @@ ObjectDetector:
       Name: cameraA
 
     Network:
-        Framework: TensorFlow # Currently supported: "Keras" or "TensorFlow"
+        #Framework: Keras # Currently supported: "Keras" or "TensorFlow"
         #Model: VGG_512_512_coco.h5
-        Model: ssdlite_mobilenet_v2_coco_2018_05_29.pb
+        #Model: full_model.h5
+
+        Framework: Darknet
+        Model: yolov3-tiny
+        Dataset: COCO # available: VOC, COCO, KITTI, OID, PET
 
     NodeName: dl-digitclassifier
diff --git a/od_darknet.sh b/od_darknet.sh
new file mode 100755
index 0000000..1e754be
--- /dev/null
+++ b/od_darknet.sh
@@ -0,0 +1,5 @@
+#!/bin/bash
+
+export DYLD_LIBRARY_PATH=$HOME/code/ml/darknet:$DYLD_LIBRARY_PATH
+python objectdetector.py objectdetector.yml
+

From f006b8ee221ffa19d4123c07f0b6693c6cb01dbc Mon Sep 17 00:00:00 2001
From: Pedro Cuenca
Date: Mon, 1 Oct 2018 10:51:26 +0200
Subject: [PATCH 3/3] * Prepare image data by converting to float32 and
 transposing instead of iterating. Detection is much faster, as expected.

Part of https://github.com/JdeRobot/dl-objectdetector/issues/38
---
 Net/Darknet/darknet/darknet.py | 32 ++++++++------------------------
 1 file changed, 8 insertions(+), 24 deletions(-)

diff --git a/Net/Darknet/darknet/darknet.py b/Net/Darknet/darknet/darknet.py
index 1a0eaad..dfc3035 100644
--- a/Net/Darknet/darknet/darknet.py
+++ b/Net/Darknet/darknet/darknet.py
@@ -149,29 +149,14 @@ def detect(net, meta, image, thresh=.5, hier_thresh=.5, nms=.45):
 
 # Detect from numpy image
 def detect_from_image(net, meta, image, thresh=.5, hier_thresh=.5, nms=.45):
-    # Make an empty image and copy the values
-    # I tried to pass a pointer to the floats but didn't know how
-    # If we really need to do this it should be in C
-    h = image.shape[0]
-    w = image.shape[1]
-    c = image.shape[2]
-    im = make_image(w, h, c)
-    for x in np.arange(w):
-        for y in np.arange(h):
-            for l in np.arange(c):
-                pixel = float(image[y, x, l]) / 255.
-                #set_pixel(im, x, y, l, pixel)
-                im.data[ l * im.h * im.w + y * im.w + x ] = pixel
-
-    # im = IMAGE()
-    # im.h = image.shape[0]
-    # im.w = image.shape[1]
-    # im.c = image.shape[2]
-    # f_image = image.astype(float) / 255.
-    # #im.data = LP_c_float(f_image.ctypes.data)
-    # #im.data = byref(f_image.ctypes)
-    # LP_c_float = POINTER(c_float)
-    # im.data = byref(c_float(f_image[0, 0, 0]))
+    im = IMAGE()
+    im.h = image.shape[0]
+    im.w = image.shape[1]
+    im.c = image.shape[2]
+
+    f_image = image.astype(np.float32) / 255.
+    f_image = f_image.transpose(2, 0, 1).flatten()
+    im.data = f_image.ctypes.data_as(POINTER(c_float))
 
     num = c_int(0)
     pnum = pointer(num)
@@ -187,7 +172,6 @@ def detect_from_image(net, meta, image, thresh=.5, hier_thresh=.5, nms=.45):
                 b = dets[j].bbox
                 res.append((meta.names[i], dets[j].prob[i], (b.x, b.y, b.w, b.h)))
     res = sorted(res, key=lambda x: -x[1])
-    free_image(im)
     free_detections(dets, num)
     return res
 
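For reference, a minimal usage sketch of the new Darknet backend driven directly from Python, outside the GUI. It assumes libdarknet.so is reachable through DYLD_LIBRARY_PATH (or LD_LIBRARY_PATH on Linux) and that yolov3-tiny.cfg, yolov3-tiny.weights and coco.data are in Net/Darknet as described in the second commit message; the OpenCV webcam capture is only illustrative and is not part of the patches.

    import sys
    import cv2

    # The ctypes bindings added in the second patch live in Net/Darknet/darknet.
    sys.path.append('Net/Darknet/darknet')
    import darknet

    # Assumed file layout, matching the commit message instructions.
    net = darknet.load_net("Net/Darknet/yolov3-tiny.cfg", "Net/Darknet/yolov3-tiny.weights", 0)
    meta = darknet.load_meta("Net/Darknet/coco.data")

    cam = cv2.VideoCapture(0)   # illustrative frame source; the project normally uses its Camera classes
    ok, frame = cam.read()      # HxWx3 uint8 numpy array
    if ok:
        # detect_from_image() hands the numpy buffer to libdarknet after
        # scaling it to [0, 1] float32 and transposing to channel-first layout (third patch).
        detections = darknet.detect_from_image(net, meta, frame, thresh=0.5)
        for label, score, (x, y, w, h) in detections:
            print("%s %.2f center=(%.0f, %.0f) size=(%.0f, %.0f)" % (label, score, x, y, w, h))
    cam.release()

Note that detect_from_image keeps f_image referenced for the duration of the call, so the float buffer handed to libdarknet remains valid while the network runs.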