From 79f89e6e5a659846d1068e8b1bd8e491ccdef861 Mon Sep 17 00:00:00 2001 From: Pablo Galindo Date: Thu, 23 Jan 2020 14:07:05 +0000 Subject: [PATCH] bpo-39421: Fix possible crash in heapq with custom comparison operators (GH-18118) * bpo-39421: Fix possible crash in heapq with custom comparison operators * fixup! bpo-39421: Fix possible crash in heapq with custom comparison operators * fixup! fixup! bpo-39421: Fix possible crash in heapq with custom comparison operators --- Lib/test/test_heapq.py | 31 ++++++++++++++++ .../2020-01-22-15-53-37.bpo-39421.O3nG7u.rst | 2 ++ Modules/_heapqmodule.c | 35 ++++++++++++++----- 3 files changed, 59 insertions(+), 9 deletions(-) create mode 100644 Misc/NEWS.d/next/Core and Builtins/2020-01-22-15-53-37.bpo-39421.O3nG7u.rst Backport: * Drop Misc/NEWS.d * test_heapq.py: + Update hunk context + list.clear() -> del list[:] * _heapqmodule.c: Port the patch with significant changes + PyObject_RichCompareBool -> cmp_lt + X[Y] -> PyList_GET_ITEM(X, Y) + 4th hunk (siftdown_max): newitem's reference is already held by the existing code, so only parent is INCREF'd before cmp_lt; that reference is DECREF'd on the error and break paths, and otherwise handed to the list slot by PyList_SET_ITEM instead of being INCREF'd again after the comparison + NOTE: unlike siftdown, the siftdown_max loop in the 4th hunk does not re-check the list size after cmp_lt, and this backport does not add such a check diff --git a/Lib/test/test_heapq.py b/Lib/test/test_heapq.py index 861ba7540d..6902573e8f 100644 --- a/Lib/test/test_heapq.py +++ b/Lib/test/test_heapq.py @@ -432,6 +432,37 @@ def test_heappop_mutating_heap(self): with self.assertRaises((IndexError, RuntimeError)): self.module.heappop(heap) + def test_comparison_operator_modifiying_heap(self): + # See bpo-39421: Strong references need to be taken + # when comparing objects as they can alter the heap + class EvilClass(int): + def __lt__(self, o): + del heap[:] + return NotImplemented + + heap = [] + self.module.heappush(heap, EvilClass(0)) + self.assertRaises(IndexError, self.module.heappushpop, heap, 1) + + def test_comparison_operator_modifiying_heap_two_heaps(self): + + class h(int): + def __lt__(self, o): + del list2[:] + return NotImplemented + + class g(int): + def __lt__(self, o): + del list1[:] + return NotImplemented + + list1, list2 = [], [] +
self.module.heappush(list1, h(0)) + self.module.heappush(list2, g(0)) + + self.assertRaises((IndexError, RuntimeError), self.module.heappush, list1, g(1)) + self.assertRaises((IndexError, RuntimeError), self.module.heappush, list2, h(1)) class TestErrorHandlingPython(TestErrorHandling): module = py_heapq diff --git a/Modules/_heapqmodule.c b/Modules/_heapqmodule.c index a84cade3aa..6bc18b5f82 100644 --- a/Modules/_heapqmodule.c +++ b/Modules/_heapqmodule.c @@ -36,7 +36,11 @@ siftdown(PyListObject *heap, Py_ssize_t startpos, Py_ssize_t pos) while (pos > startpos) { parentpos = (pos - 1) >> 1; parent = PyList_GET_ITEM(heap, parentpos); + Py_INCREF(newitem); + Py_INCREF(parent); cmp = cmp_lt(newitem, parent); + Py_DECREF(parent); + Py_DECREF(newitem); if (cmp == -1) return -1; if (size != PyList_GET_SIZE(heap)) { @@ -78,9 +82,13 @@ siftup(PyListObject *heap, Py_ssize_t pos) childpos = 2*pos + 1; /* leftmost child position */ rightpos = childpos + 1; if (rightpos < endpos) { - cmp = cmp_lt( - PyList_GET_ITEM(heap, childpos), - PyList_GET_ITEM(heap, rightpos)); + PyObject* a = PyList_GET_ITEM(heap, childpos); + PyObject* b = PyList_GET_ITEM(heap, rightpos); + Py_INCREF(a); + Py_INCREF(b); + cmp = cmp_lt(a, b); + Py_DECREF(a); + Py_DECREF(b); if (cmp == -1) return -1; if (cmp == 0) @@ -264,7 +271,10 @@ _heapq_heappushpop_impl(PyObject *module, PyObject *heap, PyObject *item) return item; } - cmp = cmp_lt(PyList_GET_ITEM(heap, 0), item); + PyObject* top = PyList_GET_ITEM(heap, 0); + Py_INCREF(top); + cmp = cmp_lt(top, item); + Py_DECREF(top); if (cmp == -1) return NULL; if (cmp == 0) { @@ -420,14 +430,17 @@ siftdown_max(PyListObject *heap, Py_ssize_t startpos, Py_ssize_t pos) while (pos > startpos){ parentpos = (pos - 1) >> 1; parent = PyList_GET_ITEM(heap, parentpos); + Py_INCREF(parent); cmp = cmp_lt(parent, newitem); if (cmp == -1) { + Py_DECREF(parent); Py_DECREF(newitem); return -1; } - if (cmp == 0) + if (cmp == 0) { + Py_DECREF(parent); break; + } - 
Py_INCREF(parent); Py_DECREF(PyList_GET_ITEM(heap, pos)); PyList_SET_ITEM(heap, pos, parent); pos = parentpos; @@ -462,9 +476,13 @@ siftup_max(PyListObject *heap, Py_ssize_t pos) childpos = 2*pos + 1; /* leftmost child position */ rightpos = childpos + 1; if (rightpos < endpos) { - cmp = cmp_lt( - PyList_GET_ITEM(heap, rightpos), - PyList_GET_ITEM(heap, childpos)); + PyObject* a = PyList_GET_ITEM(heap, rightpos); + PyObject* b = PyList_GET_ITEM(heap, childpos); + Py_INCREF(a); + Py_INCREF(b); + cmp = cmp_lt(a, b); + Py_DECREF(a); + Py_DECREF(b); if (cmp == -1) { Py_DECREF(newitem); return -1; -- 2.40.1