Skip to content

Commit 974c052

Browse files
committed
C++ backend for Binary Indexed Trees completed
1 parent bc84f6d commit 974c052

File tree

2 files changed

+70
-9
lines changed

2 files changed

+70
-9
lines changed

pydatastructs/trees/_backend/cpp/BinaryIndexedTree.hpp

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,33 @@ static void BinaryIndexedTree_dealloc(BinaryIndexedTree *self) {
2424
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
2525
}
2626

27+
static PyObject* BinaryIndexedTree_update(BinaryIndexedTree* self, PyObject *args) {
28+
long index = PyLong_AsLong(PyObject_GetItem(args, PyZero));
29+
long value = PyLong_AsLong(PyObject_GetItem(args, PyOne));
30+
long _index = index;
31+
long _value = value;
32+
if (PyList_GetItem(self->flag, index) == PyZero) {
33+
PyList_SetItem(self->flag, index, PyOne);
34+
index += 1;
35+
while (index < self->array->_size + 1) {
36+
long curr = PyLong_AsLong(PyList_GetItem(self->tree, index));
37+
PyList_SetItem(self->tree, index, PyLong_FromLong(curr + value));
38+
index = index + (index & (-1*index));
39+
}
40+
}
41+
else {
42+
value = value - PyLong_AsLong(self->array->_data[index]);
43+
index += 1;
44+
while (index < self->array->_size + 1) {
45+
long curr = PyLong_AsLong(PyList_GetItem(self->tree, index));
46+
PyList_SetItem(self->tree, index, PyLong_FromLong(curr + value));
47+
index = index + (index & (-1*index));
48+
}
49+
}
50+
self->array->_data[_index] = PyLong_FromLong(_value);
51+
Py_RETURN_NONE;
52+
}
53+
2754
static PyObject* BinaryIndexedTree___new__(PyTypeObject* type, PyObject *args, PyObject *kwds) {
2855
BinaryIndexedTree *self;
2956
self = reinterpret_cast<BinaryIndexedTree*>(type->tp_alloc(type, 0));
@@ -45,13 +72,42 @@ static PyObject* BinaryIndexedTree___new__(PyTypeObject* type, PyObject *args, P
4572
self->flag = PyList_New(self->array->_size);
4673
for(int i=0;i<self->array->_size;i++){
4774
PyList_SetItem(self->flag, i, PyZero);
75+
BinaryIndexedTree_update(self, Py_BuildValue("(OO)", PyLong_FromLong(i), self->array->_data[i]));
4876
}
4977

5078
return reinterpret_cast<PyObject*>(self);
5179
}
5280

81+
static PyObject* BinaryIndexedTree_get_prefix_sum(BinaryIndexedTree* self, PyObject *args) {
82+
long index = PyLong_AsLong(PyObject_GetItem(args, PyZero));
83+
index += 1;
84+
long sum = 0;
85+
while (index > 0) {
86+
sum += PyLong_AsLong(PyList_GetItem(self->tree, index));
87+
index = index - (index & (-1*index));
88+
}
89+
90+
return PyLong_FromLong(sum);
91+
}
92+
93+
static PyObject* BinaryIndexedTree_get_sum(BinaryIndexedTree* self, PyObject *args) {
94+
long left_index = PyLong_AsLong(PyObject_GetItem(args, PyZero));
95+
long right_index = PyLong_AsLong(PyObject_GetItem(args, PyOne));
96+
if (left_index >= 1) {
97+
long l1 = PyLong_AsLong(BinaryIndexedTree_get_prefix_sum(self, Py_BuildValue("(O)", PyLong_FromLong(right_index))));
98+
long l2 = PyLong_AsLong(BinaryIndexedTree_get_prefix_sum(self, Py_BuildValue("(O)", PyLong_FromLong(left_index - 1))));
99+
return PyLong_FromLong(l1 - l2);
100+
}
101+
else {
102+
return BinaryIndexedTree_get_prefix_sum(self, Py_BuildValue("(O)", PyLong_FromLong(right_index)));
103+
}
104+
}
105+
53106

54107
static struct PyMethodDef BinaryIndexedTree_PyMethodDef[] = {
108+
{"update", (PyCFunction) BinaryIndexedTree_update, METH_VARARGS, NULL},
109+
{"get_prefix_sum", (PyCFunction) BinaryIndexedTree_get_prefix_sum, METH_VARARGS, NULL},
110+
{"get_sum", (PyCFunction) BinaryIndexedTree_get_sum, METH_VARARGS, NULL},
55111
{NULL}
56112
};
57113

pydatastructs/trees/tests/test_binary_trees.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -361,20 +361,25 @@ def test_select_rank(expected_output):
361361
test_select_rank([])
362362

363363

364-
def test_BinaryIndexedTree():
364+
def _test_BinaryIndexedTree(backend):
365365

366366
FT = BinaryIndexedTree
367367

368-
t = FT([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], backend=Backend.CPP)
368+
t = FT([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], backend=backend)
369+
370+
assert t.get_sum(0, 2) == 6
371+
assert t.get_sum(0, 4) == 15
372+
assert t.get_sum(0, 9) == 55
373+
t.update(0, 100)
374+
assert t.get_sum(0, 2) == 105
375+
assert t.get_sum(0, 4) == 114
376+
assert t.get_sum(1, 9) == 54
369377

370-
# assert t.get_sum(0, 2) == 6
371-
# assert t.get_sum(0, 4) == 15
372-
# assert t.get_sum(0, 9) == 55
373-
# t.update(0, 100)
374-
# assert t.get_sum(0, 2) == 105
375-
# assert t.get_sum(0, 4) == 114
376-
# assert t.get_sum(1, 9) == 54
378+
def test_BinaryIndexedTree():
379+
_test_BinaryIndexedTree(Backend.PYTHON)
377380

381+
def test_cpp_BinaryIndexedTree():
382+
_test_BinaryIndexedTree(Backend.CPP)
378383

379384
def test_CartesianTree():
380385
tree = CartesianTree()

0 commit comments

Comments
 (0)