[entropy.locks] use TLS for reentrancy safety checks, update tests

This commit is contained in:
Fabio Erculiani
2013-12-09 14:19:21 +01:00
parent 7e30b65744
commit bf3a8b79f9
2 changed files with 79 additions and 10 deletions

View File

@@ -93,6 +93,8 @@ class _GenericResourceLock(object):
Generic Entropy Resource Lock abstract class.
"""
_TLS = threading.local()
def __init__(self, lock_map, lock_mutex, output=None):
"""
Object constructor.
@@ -138,23 +140,32 @@ class _GenericResourceLock(object):
with self._lock_mutex:
mapped = self._file_lock_setup(lock_path)
# I asked for an exclusive lock, but
# I am only holding a shared one, don't
# return True.
want_exclusive_when_shared = not shared and mapped['shared']
if mapped['ref'] is not None:
# I asked for an exclusive lock, but
# I am only holding a shared one, don't
# return True.
want_exclusive_when_shared = (
shared != mapped['shared']) and (
not shared and mapped['shared'])
if not want_exclusive_when_shared:
# reentrant lock, already acquired
mapped['count'] += 1
return True
# fall through
else:
mapped['shared'] = shared
# watch for deadlocks using TLS
recursed = getattr(self._TLS, "recursed", False)
if recursed and want_exclusive_when_shared:
# deadlock, raise exception
raise RuntimeError(
"want exclusive lock when shared acquired")
# not the same thread requested an exclusive lock when shared
self._TLS.recursed = True
# fall through, we won't deadlock
path = mapped['path']
acquired, flock_f = self._file_lock_create(
@@ -175,7 +186,15 @@ class _GenericResourceLock(object):
"""
lock_path = self.path()
with self._lock_mutex:
# allow the same thread to acquire the lock again.
self._TLS.recursed = False
mapped = self._file_lock_setup(lock_path)
if mapped['count'] == 0:
raise RuntimeError("releasing a non-acquired lock")
# decrement lock counter
if mapped['count'] > 0:
mapped['count'] -= 1

View File

@@ -90,7 +90,7 @@ class EntropyRepositoryTest(unittest.TestCase):
self.assertEquals(True, erl.try_acquire_shared())
self.assertEquals(6, counter_l[0])
self.assertEquals(False, erl.try_acquire_exclusive())
self.assertRaises(RuntimeError, erl.try_acquire_exclusive)
self.assertEquals(True, erl.try_acquire_shared())
self.assertEquals(7, counter_l[0])
@@ -123,6 +123,56 @@ class EntropyRepositoryTest(unittest.TestCase):
except OSError:
pass
def test_entropy_resources_lock_exception(self):
erl = EntropyResourcesLock()
tmp_fd, tmp_path = None, None
try:
tmp_fd, tmp_path = tempfile.mkstemp(
prefix="test_entropy_resources_lock")
erl.path = lambda: tmp_path
get_count = lambda: erl._file_lock_setup(erl.path())['count']
self.assertEquals(True, erl.try_acquire_shared())
self.assertRaises(RuntimeError, erl.try_acquire_exclusive)
erl.release()
self.assertEquals(True, erl.try_acquire_exclusive())
self.assertEquals(True, erl.try_acquire_shared())
self.assertEquals(True, erl.try_acquire_shared())
self.assertEquals(True, erl.try_acquire_shared())
self.assertEquals(4, get_count())
erl.release()
self.assertEquals(3, get_count())
erl.release()
self.assertEquals(2, get_count())
erl.release()
self.assertEquals(1, get_count())
erl.release()
self.assertEquals(0, get_count())
self.assertRaises(RuntimeError, erl.release)
finally:
if tmp_fd is not None:
os.close(tmp_fd)
if tmp_path is not None:
try:
os.remove(tmp_path)
except OSError:
pass
if __name__ == '__main__':
unittest.main()