Browse Source

Replaced starved condition lock with read_write lock

pull/2/head
hello-binit 4 years ago
parent
commit
5871394001
2 changed files with 140 additions and 37 deletions
  1. +119
    -0
      stretch_core/nodes/rwlock.py
  2. +21
    -37
      stretch_core/nodes/stretch_driver

+ 119
- 0
stretch_core/nodes/rwlock.py View File

@ -0,0 +1,119 @@
#! /usr/bin/env python
"""rwlock.py
This write-preferring implementation of the readwrite lock
was made available via the Creative Commons license at:
https://stackoverflow.com/a/22109979
"""
import threading
class RWLock:
""" Non-reentrant write-preferring rwlock. """
DEBUG = 0
def __init__(self):
self.lock = threading.Lock()
self.active_writer_lock = threading.Lock()
# The total number of writers including the active writer and
# those blocking on active_writer_lock or readers_finished_cond.
self.writer_count = 0
# Number of events that are blocking on writers_finished_cond.
self.waiting_reader_count = 0
# Number of events currently using the resource.
self.active_reader_count = 0
self.readers_finished_cond = threading.Condition(self.lock)
self.writers_finished_cond = threading.Condition(self.lock)
class _ReadAccess:
def __init__(self, rwlock):
self.rwlock = rwlock
def __enter__(self):
self.rwlock.acquire_read()
return self.rwlock
def __exit__(self, type, value, tb):
self.rwlock.release_read()
# support for the with statement
self.read_access = _ReadAccess(self)
class _WriteAccess:
def __init__(self, rwlock):
self.rwlock = rwlock
def __enter__(self):
self.rwlock.acquire_write()
return self.rwlock
def __exit__(self, type, value, tb):
self.rwlock.release_write()
# support for the with statement
self.write_access = _WriteAccess(self)
if self.DEBUG:
self.active_readers = set()
self.active_writer = None
def acquire_read(self):
with self.lock:
if self.DEBUG:
me = threading.currentThread()
assert me not in self.active_readers, 'This thread has already acquired read access and this lock isn\'t reader-reentrant!'
assert me != self.active_writer, 'This thread already has write access, release that before acquiring read access!'
self.active_readers.add(me)
if self.writer_count:
self.waiting_reader_count += 1
self.writers_finished_cond.wait()
# Even if the last writer thread notifies us it can happen that a new
# incoming writer thread acquires the lock earlier than this reader
# thread so we test for the writer_count after each wait()...
# We also protect ourselves from spurious wakeups that happen with some POSIX libraries.
while self.writer_count:
self.writers_finished_cond.wait()
self.waiting_reader_count -= 1
self.active_reader_count += 1
def release_read(self):
with self.lock:
if self.DEBUG:
me = threading.currentThread()
assert me in self.active_readers, 'Trying to release read access when it hasn\'t been acquired by this thread!'
self.active_readers.remove(me)
assert self.active_reader_count > 0
self.active_reader_count -= 1
if not self.active_reader_count and self.writer_count:
self.readers_finished_cond.notifyAll()
def acquire_write(self):
with self.lock:
if self.DEBUG:
me = threading.currentThread()
assert me not in self.active_readers, 'This thread already has read access - release that before acquiring write access!'
assert me != self.active_writer, 'This thread already has write access and this lock isn\'t writer-reentrant!'
self.writer_count += 1
if self.active_reader_count:
self.readers_finished_cond.wait()
while self.active_reader_count:
self.readers_finished_cond.wait()
self.active_writer_lock.acquire()
if self.DEBUG:
self.active_writer = me
def release_write(self):
if not self.DEBUG:
self.active_writer_lock.release()
with self.lock:
if self.DEBUG:
me = threading.currentThread()
assert me == self.active_writer, 'Trying to release write access when it hasn\'t been acquired by this thread!'
self.active_writer = None
self.active_writer_lock.release()
assert self.writer_count > 0
self.writer_count -= 1
if not self.writer_count and self.waiting_reader_count:
self.writers_finished_cond.notifyAll()
def get_state(self):
with self.lock:
return (self.writer_count, self.waiting_reader_count, self.active_reader_count)

+ 21
- 37
stretch_core/nodes/stretch_driver View File

@ -4,6 +4,7 @@ from __future__ import print_function
import yaml import yaml
import numpy as np import numpy as np
import threading import threading
from rwlock import RWLock
import stretch_body.robot as rb import stretch_body.robot as rb
from stretch_body.hello_utils import ThreadServiceExit from stretch_body.hello_utils import ThreadServiceExit
@ -75,12 +76,8 @@ class StretchBodyNode:
self.robot_stop_lock = threading.Lock() self.robot_stop_lock = threading.Lock()
self.stop_the_robot = False self.stop_the_robot = False
self.robot_mode_lock = threading.Lock()
with self.robot_mode_lock:
self.robot_mode = None
self.robot_mode_read_only = 0
self.mode_change_polling_rate_hz = 4.0
self.robot_mode_rwlock = RWLock()
self.robot_mode = None
def trajectory_action_server_callback(self, goal): def trajectory_action_server_callback(self, goal):
@ -98,8 +95,7 @@ class StretchBodyNode:
# trigger. # trigger.
self.stop_the_robot = False self.stop_the_robot = False
with self.robot_mode_lock:
self.robot_mode_read_only += 1
self.robot_mode_rwlock.acquire_read()
def invalid_joints_error(error_string): def invalid_joints_error(error_string):
error_string = '{0} action server:'.format(self.node_name) + error_string error_string = '{0} action server:'.format(self.node_name) + error_string
@ -107,8 +103,6 @@ class StretchBodyNode:
result = FollowJointTrajectoryResult() result = FollowJointTrajectoryResult()
result.error_code = result.INVALID_JOINTS result.error_code = result.INVALID_JOINTS
self.trajectory_action_server.set_aborted(result) self.trajectory_action_server.set_aborted(result)
with self.robot_mode_lock:
self.robot_mode_read_only -= 1
def invalid_goal_error(error_string): def invalid_goal_error(error_string):
error_string = '{0} action server:'.format(self.node_name) + error_string error_string = '{0} action server:'.format(self.node_name) + error_string
@ -116,8 +110,6 @@ class StretchBodyNode:
result = FollowJointTrajectoryResult() result = FollowJointTrajectoryResult()
result.error_code = result.INVALID_JOINTS result.error_code = result.INVALID_JOINTS
self.trajectory_action_server.set_aborted(result) self.trajectory_action_server.set_aborted(result)
with self.robot_mode_lock:
self.robot_mode_read_only -= 1
def goal_tolerance_violated(error_string): def goal_tolerance_violated(error_string):
error_string = '{0} action server:'.format(self.node_name) + error_string error_string = '{0} action server:'.format(self.node_name) + error_string
@ -125,8 +117,6 @@ class StretchBodyNode:
result = FollowJointTrajectoryResult() result = FollowJointTrajectoryResult()
result.error_code = result.GOAL_TOLERANCE_VIOLATED result.error_code = result.GOAL_TOLERANCE_VIOLATED
self.trajectory_action_server.set_aborted(result) self.trajectory_action_server.set_aborted(result)
with self.robot_mode_lock:
self.robot_mode_read_only -= 1
# For now, ignore goal time and configuration tolerances. # For now, ignore goal time and configuration tolerances.
joint_names = goal.trajectory.joint_names joint_names = goal.trajectory.joint_names
@ -145,6 +135,7 @@ class StretchBodyNode:
# The joint names violated at least one of the command # The joint names violated at least one of the command
# group's requirements. The command group should have # group's requirements. The command group should have
# reported the error. # reported the error.
self.robot_mode_rwlock.release_read()
return return
number_of_valid_joints = sum([c.get_num_valid_commands() for c in command_groups]) number_of_valid_joints = sum([c.get_num_valid_commands() for c in command_groups])
@ -153,11 +144,13 @@ class StretchBodyNode:
# Abort if no valid joints were received. # Abort if no valid joints were received.
error_string = 'received a command without any valid joint names. Received joint names = ' + str(joint_names) error_string = 'received a command without any valid joint names. Received joint names = ' + str(joint_names)
invalid_joints_error(error_string) invalid_joints_error(error_string)
self.robot_mode_rwlock.release_read()
return return
if len(joint_names) != number_of_valid_joints: if len(joint_names) != number_of_valid_joints:
error_string = 'received {0} valid joints and {1} total joints. Received joint names = {2}'.format(number_of_valid_joints, len(joint_names), joint_names) error_string = 'received {0} valid joints and {1} total joints. Received joint names = {2}'.format(number_of_valid_joints, len(joint_names), joint_names)
invalid_joints_error(error_string) invalid_joints_error(error_string)
self.robot_mode_rwlock.release_read()
return return
################################################### ###################################################
@ -173,6 +166,7 @@ class StretchBodyNode:
# At least one of the goals violated the requirements # At least one of the goals violated the requirements
# of a command group. Any violations should have been # of a command group. Any violations should have been
# reported as errors by the command groups. # reported as errors by the command groups.
self.robot_mode_rwlock.release_read()
return return
# Attempt to reach the goal. # Attempt to reach the goal.
@ -207,10 +201,8 @@ class StretchBodyNode:
if self.trajectory_debug: if self.trajectory_debug:
rospy.loginfo('PREEMPTION REQUESTED, but not stopping current motions to allow smooth interpolation between old and new commands.') rospy.loginfo('PREEMPTION REQUESTED, but not stopping current motions to allow smooth interpolation between old and new commands.')
self.trajectory_action_server.set_preempted() self.trajectory_action_server.set_preempted()
with self.robot_mode_lock:
self.robot_mode_read_only -= 1
self.stop_the_robot = False self.stop_the_robot = False
self.robot_mode_rwlock.release_read()
return return
if not incremental_commands_executed: if not incremental_commands_executed:
@ -219,6 +211,7 @@ class StretchBodyNode:
if translate and rotate: if translate and rotate:
error_string = 'simultaneous translation and rotation of the mobile base requested. This is not allowed.' error_string = 'simultaneous translation and rotation of the mobile base requested. This is not allowed.'
invalid_goal_error(error_string) invalid_goal_error(error_string)
self.robot_mode_rwlock.release_read()
return return
if translate: if translate:
self.robot.base.translate_by(mobile_base_error_m) self.robot.base.translate_by(mobile_base_error_m)
@ -272,6 +265,7 @@ class StretchBodyNode:
if (rospy.Time.now() - goal_start_time) > self.default_goal_timeout_duration: if (rospy.Time.now() - goal_start_time) > self.default_goal_timeout_duration:
error_string = 'time to execute the current goal point = {0} exceeded the default_goal_timeout = {1}'.format(point, self.default_goal_timeout_s) error_string = 'time to execute the current goal point = {0} exceeded the default_goal_timeout = {1}'.format(point, self.default_goal_timeout_s)
goal_tolerance_violated(error_string) goal_tolerance_violated(error_string)
self.robot_mode_rwlock.release_read()
return return
update_rate.sleep() update_rate.sleep()
@ -284,15 +278,13 @@ class StretchBodyNode:
result = FollowJointTrajectoryResult() result = FollowJointTrajectoryResult()
result.error_code = result.SUCCESSFUL result.error_code = result.SUCCESSFUL
self.trajectory_action_server.set_succeeded(result) self.trajectory_action_server.set_succeeded(result)
with self.robot_mode_lock:
self.robot_mode_read_only -= 1
self.robot_mode_rwlock.release_read()
return return
###### MOBILE BASE VELOCITY METHODS ####### ###### MOBILE BASE VELOCITY METHODS #######
def set_mobile_base_velocity_callback(self, twist): def set_mobile_base_velocity_callback(self, twist):
# check on thread safety for this callback function
self.robot_mode_rwlock.acquire_read()
if self.robot_mode != 'navigation': if self.robot_mode != 'navigation':
error_string = '{0} action server must be in navigation mode to receive a twist on cmd_vel. Current mode = {1}.'.format(self.node_name, self.robot_mode) error_string = '{0} action server must be in navigation mode to receive a twist on cmd_vel. Current mode = {1}.'.format(self.node_name, self.robot_mode)
rospy.logerr(error_string) rospy.logerr(error_string)
@ -300,10 +292,10 @@ class StretchBodyNode:
self.linear_velocity_mps = twist.linear.x self.linear_velocity_mps = twist.linear.x
self.angular_velocity_radps = twist.angular.z self.angular_velocity_radps = twist.angular.z
self.last_twist_time = rospy.get_time() self.last_twist_time = rospy.get_time()
self.robot_mode_rwlock.release_read()
def command_mobile_base_velocity_and_publish_state(self): def command_mobile_base_velocity_and_publish_state(self):
with self.robot_mode_lock:
self.robot_mode_read_only += 1
self.robot_mode_rwlock.acquire_read()
if BACKLASH_DEBUG: if BACKLASH_DEBUG:
print('***') print('***')
@ -574,24 +566,16 @@ class StretchBodyNode:
self.imu_wrist_pub.publish(i) self.imu_wrist_pub.publish(i)
################################################## ##################################################
with self.robot_mode_lock:
self.robot_mode_read_only -= 1
return
self.robot_mode_rwlock.release_read()
######## CHANGE MODES ######### ######## CHANGE MODES #########
def change_mode(self, new_mode, code_to_run): def change_mode(self, new_mode, code_to_run):
polling_rate = rospy.Rate(self.mode_change_polling_rate_hz)
changed = False
while not changed:
with self.robot_mode_lock:
if self.robot_mode_read_only == 0:
self.robot_mode = new_mode
code_to_run()
changed = True
rospy.loginfo('Changed to mode = {0}'.format(self.robot_mode))
if not changed:
polling_rate.sleep()
self.robot_mode_rwlock.acquire_write()
self.robot_mode = new_mode
code_to_run()
rospy.loginfo('{0}: Changed to mode = {1}'.format(self.node_name, self.robot_mode))
self.robot_mode_rwlock.release_write()
# TODO : add a freewheel mode or something comparable for the mobile base? # TODO : add a freewheel mode or something comparable for the mobile base?

Loading…
Cancel
Save