diff --git a/stretch_core/nodes/rwlock.py b/stretch_core/nodes/rwlock.py new file mode 100644 index 0000000..43b0164 --- /dev/null +++ b/stretch_core/nodes/rwlock.py @@ -0,0 +1,119 @@ +#! /usr/bin/env python +"""rwlock.py +This write-preferring implementation of the readwrite lock +was made available via the Creative Commons license at: +https://stackoverflow.com/a/22109979 +""" + +import threading + +class RWLock: + """ Non-reentrant write-preferring rwlock. """ + DEBUG = 0 + + def __init__(self): + self.lock = threading.Lock() + + self.active_writer_lock = threading.Lock() + # The total number of writers including the active writer and + # those blocking on active_writer_lock or readers_finished_cond. + self.writer_count = 0 + + # Number of events that are blocking on writers_finished_cond. + self.waiting_reader_count = 0 + + # Number of events currently using the resource. + self.active_reader_count = 0 + + self.readers_finished_cond = threading.Condition(self.lock) + self.writers_finished_cond = threading.Condition(self.lock) + + class _ReadAccess: + def __init__(self, rwlock): + self.rwlock = rwlock + def __enter__(self): + self.rwlock.acquire_read() + return self.rwlock + def __exit__(self, type, value, tb): + self.rwlock.release_read() + # support for the with statement + self.read_access = _ReadAccess(self) + + class _WriteAccess: + def __init__(self, rwlock): + self.rwlock = rwlock + def __enter__(self): + self.rwlock.acquire_write() + return self.rwlock + def __exit__(self, type, value, tb): + self.rwlock.release_write() + # support for the with statement + self.write_access = _WriteAccess(self) + + if self.DEBUG: + self.active_readers = set() + self.active_writer = None + + def acquire_read(self): + with self.lock: + if self.DEBUG: + me = threading.currentThread() + assert me not in self.active_readers, 'This thread has already acquired read access and this lock isn\'t reader-reentrant!' + assert me != self.active_writer, 'This thread already has write access, release that before acquiring read access!' + self.active_readers.add(me) + if self.writer_count: + self.waiting_reader_count += 1 + self.writers_finished_cond.wait() + # Even if the last writer thread notifies us it can happen that a new + # incoming writer thread acquires the lock earlier than this reader + # thread so we test for the writer_count after each wait()... + # We also protect ourselves from spurious wakeups that happen with some POSIX libraries. + while self.writer_count: + self.writers_finished_cond.wait() + self.waiting_reader_count -= 1 + self.active_reader_count += 1 + + def release_read(self): + with self.lock: + if self.DEBUG: + me = threading.currentThread() + assert me in self.active_readers, 'Trying to release read access when it hasn\'t been acquired by this thread!' + self.active_readers.remove(me) + assert self.active_reader_count > 0 + self.active_reader_count -= 1 + if not self.active_reader_count and self.writer_count: + self.readers_finished_cond.notifyAll() + + def acquire_write(self): + with self.lock: + if self.DEBUG: + me = threading.currentThread() + assert me not in self.active_readers, 'This thread already has read access - release that before acquiring write access!' + assert me != self.active_writer, 'This thread already has write access and this lock isn\'t writer-reentrant!' + self.writer_count += 1 + if self.active_reader_count: + self.readers_finished_cond.wait() + while self.active_reader_count: + self.readers_finished_cond.wait() + + self.active_writer_lock.acquire() + if self.DEBUG: + self.active_writer = me + + def release_write(self): + if not self.DEBUG: + self.active_writer_lock.release() + with self.lock: + if self.DEBUG: + me = threading.currentThread() + assert me == self.active_writer, 'Trying to release write access when it hasn\'t been acquired by this thread!' + self.active_writer = None + self.active_writer_lock.release() + assert self.writer_count > 0 + self.writer_count -= 1 + if not self.writer_count and self.waiting_reader_count: + self.writers_finished_cond.notifyAll() + + def get_state(self): + with self.lock: + return (self.writer_count, self.waiting_reader_count, self.active_reader_count) diff --git a/stretch_core/nodes/stretch_driver b/stretch_core/nodes/stretch_driver index 58bcec3..d45630f 100755 --- a/stretch_core/nodes/stretch_driver +++ b/stretch_core/nodes/stretch_driver @@ -4,6 +4,7 @@ from __future__ import print_function import yaml import numpy as np import threading +from rwlock import RWLock import stretch_body.robot as rb from stretch_body.hello_utils import ThreadServiceExit @@ -75,12 +76,8 @@ class StretchBodyNode: self.robot_stop_lock = threading.Lock() self.stop_the_robot = False - self.robot_mode_lock = threading.Lock() - with self.robot_mode_lock: - self.robot_mode = None - self.robot_mode_read_only = 0 - - self.mode_change_polling_rate_hz = 4.0 + self.robot_mode_rwlock = RWLock() + self.robot_mode = None def trajectory_action_server_callback(self, goal): @@ -98,8 +95,7 @@ class StretchBodyNode: # trigger. self.stop_the_robot = False - with self.robot_mode_lock: - self.robot_mode_read_only += 1 + self.robot_mode_rwlock.acquire_read() def invalid_joints_error(error_string): error_string = '{0} action server:'.format(self.node_name) + error_string @@ -107,8 +103,6 @@ class StretchBodyNode: result = FollowJointTrajectoryResult() result.error_code = result.INVALID_JOINTS self.trajectory_action_server.set_aborted(result) - with self.robot_mode_lock: - self.robot_mode_read_only -= 1 def invalid_goal_error(error_string): error_string = '{0} action server:'.format(self.node_name) + error_string @@ -116,8 +110,6 @@ class StretchBodyNode: result = FollowJointTrajectoryResult() result.error_code = result.INVALID_JOINTS self.trajectory_action_server.set_aborted(result) - with self.robot_mode_lock: - self.robot_mode_read_only -= 1 def goal_tolerance_violated(error_string): error_string = '{0} action server:'.format(self.node_name) + error_string @@ -125,8 +117,6 @@ class StretchBodyNode: result = FollowJointTrajectoryResult() result.error_code = result.GOAL_TOLERANCE_VIOLATED self.trajectory_action_server.set_aborted(result) - with self.robot_mode_lock: - self.robot_mode_read_only -= 1 # For now, ignore goal time and configuration tolerances. joint_names = goal.trajectory.joint_names @@ -145,6 +135,7 @@ class StretchBodyNode: # The joint names violated at least one of the command # group's requirements. The command group should have # reported the error. + self.robot_mode_rwlock.release_read() return number_of_valid_joints = sum([c.get_num_valid_commands() for c in command_groups]) @@ -153,11 +144,13 @@ class StretchBodyNode: # Abort if no valid joints were received. error_string = 'received a command without any valid joint names. Received joint names = ' + str(joint_names) invalid_joints_error(error_string) + self.robot_mode_rwlock.release_read() return if len(joint_names) != number_of_valid_joints: error_string = 'received {0} valid joints and {1} total joints. Received joint names = {2}'.format(number_of_valid_joints, len(joint_names), joint_names) invalid_joints_error(error_string) + self.robot_mode_rwlock.release_read() return ################################################### @@ -173,6 +166,7 @@ class StretchBodyNode: # At least one of the goals violated the requirements # of a command group. Any violations should have been # reported as errors by the command groups. + self.robot_mode_rwlock.release_read() return # Attempt to reach the goal. @@ -207,10 +201,8 @@ class StretchBodyNode: if self.trajectory_debug: rospy.loginfo('PREEMPTION REQUESTED, but not stopping current motions to allow smooth interpolation between old and new commands.') self.trajectory_action_server.set_preempted() - with self.robot_mode_lock: - self.robot_mode_read_only -= 1 - self.stop_the_robot = False + self.robot_mode_rwlock.release_read() return if not incremental_commands_executed: @@ -219,6 +211,7 @@ class StretchBodyNode: if translate and rotate: error_string = 'simultaneous translation and rotation of the mobile base requested. This is not allowed.' invalid_goal_error(error_string) + self.robot_mode_rwlock.release_read() return if translate: self.robot.base.translate_by(mobile_base_error_m) @@ -272,6 +265,7 @@ class StretchBodyNode: if (rospy.Time.now() - goal_start_time) > self.default_goal_timeout_duration: error_string = 'time to execute the current goal point = {0} exceeded the default_goal_timeout = {1}'.format(point, self.default_goal_timeout_s) goal_tolerance_violated(error_string) + self.robot_mode_rwlock.release_read() return update_rate.sleep() @@ -284,15 +278,13 @@ class StretchBodyNode: result = FollowJointTrajectoryResult() result.error_code = result.SUCCESSFUL self.trajectory_action_server.set_succeeded(result) - - with self.robot_mode_lock: - self.robot_mode_read_only -= 1 + self.robot_mode_rwlock.release_read() return ###### MOBILE BASE VELOCITY METHODS ####### def set_mobile_base_velocity_callback(self, twist): - # check on thread safety for this callback function + self.robot_mode_rwlock.acquire_read() if self.robot_mode != 'navigation': error_string = '{0} action server must be in navigation mode to receive a twist on cmd_vel. Current mode = {1}.'.format(self.node_name, self.robot_mode) rospy.logerr(error_string) @@ -300,10 +292,10 @@ class StretchBodyNode: self.linear_velocity_mps = twist.linear.x self.angular_velocity_radps = twist.angular.z self.last_twist_time = rospy.get_time() + self.robot_mode_rwlock.release_read() def command_mobile_base_velocity_and_publish_state(self): - with self.robot_mode_lock: - self.robot_mode_read_only += 1 + self.robot_mode_rwlock.acquire_read() if BACKLASH_DEBUG: print('***') @@ -574,24 +566,16 @@ class StretchBodyNode: self.imu_wrist_pub.publish(i) ################################################## - with self.robot_mode_lock: - self.robot_mode_read_only -= 1 - return + self.robot_mode_rwlock.release_read() ######## CHANGE MODES ######### def change_mode(self, new_mode, code_to_run): - polling_rate = rospy.Rate(self.mode_change_polling_rate_hz) - changed = False - while not changed: - with self.robot_mode_lock: - if self.robot_mode_read_only == 0: - self.robot_mode = new_mode - code_to_run() - changed = True - rospy.loginfo('Changed to mode = {0}'.format(self.robot_mode)) - if not changed: - polling_rate.sleep() + self.robot_mode_rwlock.acquire_write() + self.robot_mode = new_mode + code_to_run() + rospy.loginfo('{0}: Changed to mode = {1}'.format(self.node_name, self.robot_mode)) + self.robot_mode_rwlock.release_write() # TODO : add a freewheel mode or something comparable for the mobile base?