diff --git a/lib/entropy/locks.py b/lib/entropy/locks.py index a94a2bddc..ef1bbfc4e 100644 --- a/lib/entropy/locks.py +++ b/lib/entropy/locks.py @@ -93,6 +93,8 @@ class _GenericResourceLock(object): Generic Entropy Resource Lock abstract class. """ + _TLS = threading.local() + def __init__(self, lock_map, lock_mutex, output=None): """ Object constructor. @@ -138,23 +140,32 @@ class _GenericResourceLock(object): with self._lock_mutex: mapped = self._file_lock_setup(lock_path) + + # I asked for an exclusive lock, but + # I am only holding a shared one, don't + # return True. + want_exclusive_when_shared = not shared and mapped['shared'] + if mapped['ref'] is not None: - - # I asked for an exclusive lock, but - # I am only holding a shared one, don't - # return True. - want_exclusive_when_shared = ( - shared != mapped['shared']) and ( - not shared and mapped['shared']) - if not want_exclusive_when_shared: # reentrant lock, already acquired mapped['count'] += 1 return True - # fall through + else: mapped['shared'] = shared + # watch for deadlocks using TLS + recursed = getattr(self._TLS, "recursed", False) + if recursed and want_exclusive_when_shared: + # deadlock, raise exception + raise RuntimeError( + "want exclusive lock when shared acquired") + + # not the same thread requested an exclusive lock when shared + self._TLS.recursed = True + # fall through, we won't deadlock + path = mapped['path'] acquired, flock_f = self._file_lock_create( @@ -175,7 +186,15 @@ class _GenericResourceLock(object): """ lock_path = self.path() with self._lock_mutex: + + # allow the same thread to acquire the lock again. + self._TLS.recursed = False + mapped = self._file_lock_setup(lock_path) + + if mapped['count'] == 0: + raise RuntimeError("releasing a non-acquired lock") + # decrement lock counter if mapped['count'] > 0: mapped['count'] -= 1 diff --git a/lib/tests/locks.py b/lib/tests/locks.py index 644625293..e513c2943 100644 --- a/lib/tests/locks.py +++ b/lib/tests/locks.py @@ -90,7 +90,7 @@ class EntropyRepositoryTest(unittest.TestCase): self.assertEquals(True, erl.try_acquire_shared()) self.assertEquals(6, counter_l[0]) - self.assertEquals(False, erl.try_acquire_exclusive()) + self.assertRaises(RuntimeError, erl.try_acquire_exclusive) self.assertEquals(True, erl.try_acquire_shared()) self.assertEquals(7, counter_l[0]) @@ -123,6 +123,56 @@ class EntropyRepositoryTest(unittest.TestCase): except OSError: pass + def test_entropy_resources_lock_exception(self): + + erl = EntropyResourcesLock() + + tmp_fd, tmp_path = None, None + try: + tmp_fd, tmp_path = tempfile.mkstemp( + prefix="test_entropy_resources_lock") + + erl.path = lambda: tmp_path + + get_count = lambda: erl._file_lock_setup(erl.path())['count'] + + self.assertEquals(True, erl.try_acquire_shared()) + self.assertRaises(RuntimeError, erl.try_acquire_exclusive) + + erl.release() + + self.assertEquals(True, erl.try_acquire_exclusive()) + + self.assertEquals(True, erl.try_acquire_shared()) + self.assertEquals(True, erl.try_acquire_shared()) + self.assertEquals(True, erl.try_acquire_shared()) + + self.assertEquals(4, get_count()) + erl.release() + + self.assertEquals(3, get_count()) + erl.release() + + self.assertEquals(2, get_count()) + erl.release() + + self.assertEquals(1, get_count()) + erl.release() + + self.assertEquals(0, get_count()) + + self.assertRaises(RuntimeError, erl.release) + + + finally: + if tmp_fd is not None: + os.close(tmp_fd) + if tmp_path is not None: + try: + os.remove(tmp_path) + except OSError: + pass + if __name__ == '__main__': unittest.main()