diff --git a/stretch_deep_perception/nodes/detect_objects.py b/stretch_deep_perception/nodes/detect_objects.py
index b53b797..63956cb 100755
--- a/stretch_deep_perception/nodes/detect_objects.py
+++ b/stretch_deep_perception/nodes/detect_objects.py
@@ -3,34 +3,14 @@
 import cv2
 import sys
 import rospy
-import object_detector as od
+
+import object_detect_pytorch as od
 import detection_node as dn
 import deep_learning_model_options as do
 
-if __name__ == '__main__':
-    print('cv2.__version__ =', cv2.__version__)
-    print('Python version (must be > 3.0):', sys.version)
-    assert(int(sys.version[0]) >= 3)
-
-
-    models_directory = do.get_directory()
-    print('Using the following directory for deep learning models:', models_directory)
-    use_neural_compute_stick = do.use_neural_compute_stick()
-    if use_neural_compute_stick:
-        print('Attempt to use an Intel Neural Compute Stick 2.')
-    else:
-        print('Not attempting to use an Intel Neural Compute Stick 2.')
-
-    use_tiny = True
-    if use_tiny:
-        confidence_threshold = 0.0
-    else:
-        confidence_threshold = 0.5
-
-    detector = od.ObjectDetector(models_directory,
-                                 use_tiny_yolo3=use_tiny,
-                                 confidence_threshold=confidence_threshold,
-                                 use_neural_compute_stick=use_neural_compute_stick)
+if __name__ == '__main__':
+    confidence_threshold = 0.0
+    detector = od.ObjectDetector(confidence_threshold=confidence_threshold)
     default_marker_name = 'object'
     node_name = 'DetectObjectsNode'
     topic_base_name = 'objects'
diff --git a/stretch_deep_perception/nodes/object_detect_pytorch.py b/stretch_deep_perception/nodes/object_detect_pytorch.py
new file mode 100644
index 0000000..4bcc593
--- /dev/null
+++ b/stretch_deep_perception/nodes/object_detect_pytorch.py
@@ -0,0 +1,88 @@
+#!/usr/bin/env python3
+
+import cv2
+import numpy as np
+import torch
+import pandas
+import ros_numpy
+
+import deep_models_shared_python3 as dm
+
+
+class ObjectDetector:
+    def __init__(self, confidence_threshold=0.2):
+        # Load the YOLOv5s model from Torch Hub
+        self.model = torch.hub.load('ultralytics/yolov5', 'yolov5s')  # or yolov5m, yolov5l, yolov5x, custom
+        self.confidence_threshold = confidence_threshold
+
+    def get_landmark_names(self):
+        return None
+
+    def get_landmark_colors(self):
+        return None
+
+    def get_landmark_color_dict(self):
+        return None
+
+    def apply_to_image(self, rgb_image, draw_output=False):
+        results = self.model(rgb_image)
+        object_detections = results.pandas().xyxy[0]
+
+        detections = []
+        for index, detection in object_detections.iterrows():
+            confidence = detection['confidence']
+            if confidence > self.confidence_threshold:
+                class_label = detection['name']
+                object_class_id = detection['class']
+                x_min = detection['xmin']
+                x_max = detection['xmax']
+                y_min = detection['ymin']
+                y_max = detection['ymax']
+                box = (x_min, y_min, x_max, y_max)
+
+                print(class_label, 'detected')
+
+                detections.append({'class_id': object_class_id,
+                                   'label': class_label,
+                                   'confidence': confidence,
+                                   'box': box})
+
+        output_image = None
+        if draw_output:
+            output_image = rgb_image.copy()
+            for detection_dict in detections:
+                self.draw_detection(output_image, detection_dict)
+
+        return detections, output_image
+
+
+    def draw_detection(self, image, detection_dict):
+        font_scale = 0.75
+        line_color = [0, 0, 0]
+        line_width = 1
+        font = cv2.FONT_HERSHEY_PLAIN
+        class_label = detection_dict['label']
+        confidence = detection_dict['confidence']
+        box = detection_dict['box']
+        x_min, y_min, x_max, y_max = [int(v) for v in box]  # OpenCV drawing functions require integer pixel coordinates
+        output_string = '{0}, {1:.2f}'.format(class_label, confidence)
+        color = (0, 0, 255)
+        rectangle_line_thickness = 2
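+        # Draw the detection's bounding box on the image.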
+        cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color, rectangle_line_thickness)
+
+        # see the following page for a helpful reference
+        # https://stackoverflow.com/questions/51285616/opencvs-gettextsize-and-puttext-return-wrong-size-and-chop-letters-with-low
+
+        label_background_border = 2
+        (label_width, label_height), baseline = cv2.getTextSize(output_string, font, font_scale, line_width)
+        label_x_min = x_min
+        label_y_min = y_min
+        label_x_max = x_min + (label_width + (2 * label_background_border))
+        label_y_max = y_min + (label_height + baseline + (2 * label_background_border))
+
+        text_x = label_x_min + label_background_border
+        text_y = (label_y_min + label_height) + label_background_border
+
+        cv2.rectangle(image, (label_x_min, label_y_min), (label_x_max, label_y_max), (255, 255, 255), cv2.FILLED)
+        cv2.putText(image, output_string, (text_x, text_y), font, font_scale, line_color, line_width, cv2.LINE_AA)
diff --git a/stretch_deep_perception/nodes/webcam_publisher.py b/stretch_deep_perception/nodes/webcam_publisher.py
new file mode 100755
index 0000000..f7d8972
--- /dev/null
+++ b/stretch_deep_perception/nodes/webcam_publisher.py
@@ -0,0 +1,30 @@
+#!/usr/bin/env python3
+
+import cv2
+import numpy as np
+import rospy
+from cv_bridge import CvBridge  # Check if ros_numpy can be used instead
+from sensor_msgs.msg import Image
+
+if __name__ == '__main__':
+    rospy.init_node("webcam_node", anonymous=True)
+    rgb_topic_name = '/camera/color/image_raw'  # '/camera/infra1/image_rect_raw'
+    webcam_pub = rospy.Publisher(rgb_topic_name, Image, queue_size=1)
+    vid = cv2.VideoCapture(0)
+    bridge = CvBridge()
+
+    while True:
+        ret, frame = vid.read()
+        if not ret:
+            continue  # skip frames the camera failed to deliver
+        cv2.imshow('frame', frame)
+        image_message = bridge.cv2_to_imgmsg(frame, encoding="passthrough")
+        webcam_pub.publish(image_message)
+        # press 'q' to quit
+        if cv2.waitKey(1) & 0xFF == ord('q'):
+            break
+
+    # After the loop, release the capture object
+    vid.release()
+    # Destroy all the windows
+    cv2.destroyAllWindows()
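
As a quick sanity check of the new detector outside of ROS, something like the
following sketch should work. It is illustrative only: the image path
'test.jpg' and output name 'annotated.jpg' are hypothetical, and the confidence
threshold is arbitrary rather than the 0.0 used by the node above.

    # Standalone sketch: run ObjectDetector on a single image, no ROS required.
    import cv2
    import object_detect_pytorch as od

    detector = od.ObjectDetector(confidence_threshold=0.5)
    bgr = cv2.imread('test.jpg')                 # hypothetical test image
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)   # YOLOv5 hub models expect RGB arrays
    detections, annotated = detector.apply_to_image(rgb, draw_output=True)
    for d in detections:
        print(d['label'], d['confidence'], d['box'])
    cv2.imwrite('annotated.jpg', cv2.cvtColor(annotated, cv2.COLOR_RGB2BGR))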