Browse Source

Changes object detection node to using pytorch

feature/pytorch_perception
hello-chintan 2 years ago
parent
commit
ccd64e64c0
3 changed files with 120 additions and 25 deletions
  1. +5
    -25
      stretch_deep_perception/nodes/detect_objects.py
  2. +87
    -0
      stretch_deep_perception/nodes/object_detect_pytorch.py
  3. +28
    -0
      stretch_deep_perception/nodes/webcam_publisher.py

+ 5
- 25
stretch_deep_perception/nodes/detect_objects.py View File

@ -3,34 +3,14 @@
import cv2
import sys
import rospy
import object_detector as od
import object_detect_pytorch as od
import detection_node as dn
import deep_learning_model_options as do
if __name__ == '__main__':
print('cv2.__version__ =', cv2.__version__)
print('Python version (must be > 3.0):', sys.version)
assert(int(sys.version[0]) >= 3)
models_directory = do.get_directory()
print('Using the following directory for deep learning models:', models_directory)
use_neural_compute_stick = do.use_neural_compute_stick()
if use_neural_compute_stick:
print('Attempt to use an Intel Neural Compute Stick 2.')
else:
print('Not attempting to use an Intel Neural Compute Stick 2.')
use_tiny = True
if use_tiny:
confidence_threshold = 0.0
else:
confidence_threshold = 0.5
detector = od.ObjectDetector(models_directory,
use_tiny_yolo3=use_tiny,
confidence_threshold=confidence_threshold,
use_neural_compute_stick=use_neural_compute_stick)
if __name__ == '__main__':
confidence_threshold = 0.0
detector = od.ObjectDetector(confidence_threshold=confidence_threshold)
default_marker_name = 'object'
node_name = 'DetectObjectsNode'
topic_base_name = 'objects'

+ 87
- 0
stretch_deep_perception/nodes/object_detect_pytorch.py View File

@ -0,0 +1,87 @@
#!/usr/bin/env python3
import cv2
import numpy as np
import torch
import pandas
import ros_numpy
import deep_models_shared_python3 as dm
class ObjectDetector:
def __init__(self, confidence_threshold=0.2):
# Load the models
self.model = torch.hub.load('ultralytics/yolov5', 'yolov5s') # or yolov5m, yolov5l, yolov5x, custom
self.confidence_threshold = confidence_threshold
def get_landmark_names(self):
return None
def get_landmark_colors(self):
return None
def get_landmark_color_dict(self):
return None
def apply_to_image(self, rgb_image, draw_output=False):
results = self.model(rgb_image)
object_detections = results.pandas().xyxy[0]
results = []
for index, detection in object_detections.iterrows():
confidence = detection['confidence']
if confidence > self.confidence_threshold:
class_label = detection['name']
object_class_id = detection['class']
x_min = detection['xmin']
x_max = detection['xmax']
y_min = detection['ymin']
y_max = detection['ymax']
box = (x_min, y_min, x_max, y_max)
print(class_label, ' detected')
results.append({'class_id': object_class_id,
'label': class_label,
'confidence': confidence,
'box': box})
output_image = None
if draw_output:
output_image = rgb_image.copy()
for detection_dict in results:
self.draw_detection(output_image, detection_dict)
return results, output_image
def draw_detection(self, image, detection_dict):
font_scale = 0.75
line_color = [0, 0, 0]
line_width = 1
font = cv2.FONT_HERSHEY_PLAIN
class_label = detection_dict['label']
confidence = detection_dict['confidence']
box = detection_dict['box']
x_min, y_min, x_max, y_max = box
output_string = '{0}, {1:.2f}'.format(class_label, confidence)
color = (0, 0, 255)
rectangle_line_thickness = 2 #1
cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color, rectangle_line_thickness)
# see the following page for a helpful reference
# https://stackoverflow.com/questions/51285616/opencvs-gettextsize-and-puttext-return-wrong-size-and-chop-letters-with-low
label_background_border = 2
(label_width, label_height), baseline = cv2.getTextSize(output_string, font, font_scale, line_width)
label_x_min = x_min
label_y_min = y_min
label_x_max = x_min + (label_width + (2 * label_background_border))
label_y_max = y_min + (label_height + baseline + (2 * label_background_border))
text_x = label_x_min + label_background_border
text_y = (label_y_min + label_height) + label_background_border
cv2.rectangle(image, (label_x_min, label_y_min), (label_x_max, label_y_max), (255, 255, 255), cv2.FILLED)
cv2.putText(image, output_string, (text_x, text_y), font, font_scale, line_color, line_width, cv2.LINE_AA)

+ 28
- 0
stretch_deep_perception/nodes/webcam_publisher.py View File

@ -0,0 +1,28 @@
#!/usr/bin/env python3
import cv2
import numpy as np
import rospy
from cv_bridge import CvBridge # Check if ros_numpy can be used instead
from sensor_msgs.msg import Image
if __name__ == '__main__':
rospy.init_node("webcam_node", anonymous=True)
rgb_topic_name = '/camera/color/image_raw' #'/camera/infra1/image_rect_raw'
webcam_pub = rospy.Publisher(rgb_topic_name, Image, queue_size=1)
vid = cv2.VideoCapture(0)
bridge = CvBridge()
while(True):
ret, frame = vid.read()
cv2.imshow('frame', frame)
image_message = bridge.cv2_to_imgmsg(frame, encoding="passthrough")
webcam_pub.publish(image_message)
# press 'q' to quit
if cv2.waitKey(1) & 0xFF == ord('q'):
break
# After the loop release the cap object
vid.release()
# Destroy all the windows
cv2.destroyAllWindows()

Loading…
Cancel
Save