Skip to content

Commit 00a6c7c

Browse files
committed
ENH: Use TypeTag
1 parent e05269a commit 00a6c7c

File tree

9 files changed

+77
-15
lines changed

9 files changed

+77
-15
lines changed

pydatastructs/graphs/_backend/cpp/AdjacencyList.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <algorithm>
1010
#include "AdjacencyListGraphNode.hpp"
1111
#include "GraphEdge.hpp"
12+
#include "../../../utils/_backend/cpp/utils.hpp"
1213

1314
extern PyTypeObject AdjacencyListGraphType;
1415

@@ -62,13 +63,12 @@ static PyObject* AdjacencyListGraph_new(PyTypeObject* type, PyObject* args, PyOb
6263
Py_ssize_t num_args = PyTuple_Size(args);
6364
for (Py_ssize_t i = 0; i < num_args; ++i) {
6465
PyObject* node_obj = PyTuple_GetItem(args, i);
65-
if (!PyObject_IsInstance(node_obj, (PyObject*)&AdjacencyListGraphNodeType)) {
66+
AdjacencyListGraphNode* node = reinterpret_cast<AdjacencyListGraphNode*>(node_obj);
67+
if (get_type_tag(node_obj) != NodeType_::AdjacencyListGraphNode) {
6668
PyErr_SetString(PyExc_TypeError, "All arguments must be AdjacencyListGraphNode instances");
6769
return NULL;
6870
}
6971

70-
AdjacencyListGraphNode* node = reinterpret_cast<AdjacencyListGraphNode*>(node_obj);
71-
7272
if (self->node_map.find(node->name) != self->node_map.end()) {
7373
PyErr_Format(PyExc_ValueError, "Duplicate node with name '%s'", node->name.c_str());
7474
return NULL;
@@ -107,7 +107,7 @@ static PyObject* AdjacencyListGraph_add_vertex(AdjacencyListGraph* self, PyObjec
107107
return NULL;
108108
}
109109

110-
if (!PyObject_IsInstance(node_obj, (PyObject*)&AdjacencyListGraphNodeType)) {
110+
if (get_type_tag(node_obj) != NodeType_::AdjacencyListGraphNode) {
111111
PyErr_SetString(PyExc_TypeError, "Object is not an AdjacencyListGraphNode");
112112
return NULL;
113113
}

pydatastructs/graphs/_backend/cpp/AdjacencyMatrix.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ static PyObject* AdjacencyMatrixGraph_new(PyTypeObject* type, PyObject* args, Py
6666
Py_ssize_t len = PyTuple_Size(vertices);
6767
for (Py_ssize_t i = 0; i < len; ++i) {
6868
PyObject* item = PyTuple_GetItem(vertices, i);
69-
if (!PyObject_TypeCheck(item, &AdjacencyMatrixGraphNodeType)) {
69+
if (get_type_tag(item) != NodeType_::AdjacencyMatrixGraphNode) {
7070
PyErr_SetString(PyExc_TypeError, "All elements must be AdjacencyMatrixGraphNode instances");
7171
Py_DECREF(self);
7272
return NULL;

pydatastructs/graphs/_backend/cpp/Algorithms.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ static PyObject* breadth_first_search_adjacency_list(PyObject* self, PyObject* a
4040

4141
for (const auto& [adj_name, adj_obj] : node->adjacent) {
4242
if (visited.count(adj_name)) continue;
43-
if (!PyObject_IsInstance(adj_obj, (PyObject*)&AdjacencyListGraphNodeType)) continue;
43+
if (get_type_tag(adj_obj) != NodeType_::AdjacencyListGraphNode) continue;
4444

4545
AdjacencyListGraphNode* adj_node = reinterpret_cast<AdjacencyListGraphNode*>(adj_obj);
4646

pydatastructs/utils/_backend/cpp/AdjacencyListGraphNode.hpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ extern PyTypeObject AdjacencyListGraphNodeType;
1212

1313
typedef struct {
1414
PyObject_HEAD
15+
NodeType_ type_tag;
1516
std::string name;
1617
int internal_id;
1718
std::variant<std::monostate, int64_t, double, std::string, PyObject *> data;
@@ -34,6 +35,7 @@ static void AdjacencyListGraphNode_dealloc(AdjacencyListGraphNode* self) {
3435
static PyObject* AdjacencyListGraphNode_new(PyTypeObject* type, PyObject* args, PyObject* kwds) {
3536
AdjacencyListGraphNode* self = PyObject_New(AdjacencyListGraphNode, &AdjacencyListGraphNodeType);
3637
if (!self) return NULL;
38+
self->type_tag = NodeType_::AdjacencyListGraphNode;
3739
new (&self->adjacent) std::unordered_map<std::string, PyObject*>();
3840
new (&self->name) std::string();
3941
new (&self->data) std::variant<std::monostate, int64_t, double, std::string, PyObject*>();
@@ -234,6 +236,11 @@ static int AdjacencyListGraphNode_set_adjacent(AdjacencyListGraphNode* self, PyO
234236
return 0;
235237
}
236238

239+
static struct PyMemberDef AdjacencyListGraphNode_PyMemberDef[] = {
240+
{"type_tag", T_INT, offsetof(AdjacencyListGraphNode, type_tag), 0, "AdjacencyListGraphNode type_tag"},
241+
{NULL},
242+
};
243+
237244
static PyGetSetDef AdjacencyListGraphNode_getsetters[] = {
238245
{"name", (getter)AdjacencyListGraphNode_get_name, (setter)AdjacencyListGraphNode_set_name, "Get or set node name", NULL},
239246
{"data", (getter)AdjacencyListGraphNode_get_data, (setter)AdjacencyListGraphNode_set_data, "Get or set node data", NULL},
@@ -275,7 +282,7 @@ inline PyTypeObject AdjacencyListGraphNodeType = {
275282
/* tp_iter */ 0,
276283
/* tp_iternext */ 0,
277284
/* tp_methods */ AdjacencyListGraphNode_methods,
278-
/* tp_members */ 0,
285+
/* tp_members */ AdjacencyListGraphNode_PyMemberDef,
279286
/* tp_getset */ AdjacencyListGraphNode_getsetters,
280287
/* tp_base */ &GraphNodeType,
281288
/* tp_dict */ 0,

pydatastructs/utils/_backend/cpp/AdjacencyMatrixGraphNode.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ static PyObject* AdjacencyMatrixGraphNode_new(PyTypeObject* type, PyObject* args
2323
}
2424

2525
AdjacencyMatrixGraphNode* self = reinterpret_cast<AdjacencyMatrixGraphNode*>(base_obj);
26+
self->super.type_tag = NodeType_::AdjacencyMatrixGraphNode;
2627

2728
return reinterpret_cast<PyObject*>(self);
2829
}

pydatastructs/utils/_backend/cpp/GraphNode.hpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <Python.h>
66
#include <string>
77
#include <variant>
8+
#include "Node.hpp"
89

910
enum class DataType {
1011
None,
@@ -16,6 +17,7 @@ enum class DataType {
1617

1718
typedef struct {
1819
PyObject_HEAD
20+
NodeType_ type_tag;
1921
std::string name;
2022
int internal_id;
2123
std::variant<std::monostate, int64_t, double, std::string, PyObject *> data;
@@ -32,6 +34,7 @@ static void GraphNode_dealloc(GraphNode* self) {
3234
static PyObject* GraphNode_new(PyTypeObject* type, PyObject* args, PyObject* kwds) {
3335
GraphNode* self = reinterpret_cast<GraphNode*>(type->tp_alloc(type, 0));
3436
if (!self) return NULL;
37+
self->type_tag = NodeType_::GraphNode;
3538

3639
new (&self->name) std::string();
3740
new (&self->data) std::variant<std::monostate, int64_t, double, std::string, PyObject*>();
@@ -195,6 +198,12 @@ static PyGetSetDef GraphNode_getsetters[] = {
195198
{nullptr}
196199
};
197200

201+
static struct PyMemberDef GraphNode_PyMemberDef[] = {
202+
{"type_tag", T_INT, offsetof(GraphNode, type_tag), 0, "GraphNode type_tag"},
203+
{NULL},
204+
};
205+
206+
198207
static PyTypeObject GraphNodeType = {
199208
/* tp_name */ PyVarObject_HEAD_INIT(NULL, 0) "GraphNode",
200209
/* tp_basicsize */ sizeof(GraphNode),
@@ -223,7 +232,7 @@ static PyTypeObject GraphNodeType = {
223232
/* tp_iter */ 0,
224233
/* tp_iternext */ 0,
225234
/* tp_methods */ 0,
226-
/* tp_members */ 0,
235+
/* tp_members */ GraphNode_PyMemberDef,
227236
/* tp_getset */ GraphNode_getsetters,
228237
/* tp_base */ &PyBaseObject_Type,
229238
/* tp_dict */ 0,

pydatastructs/utils/_backend/cpp/Node.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,19 @@
88

99
typedef struct {
1010
PyObject_HEAD
11+
NodeType_ type_tag;
1112
} Node;
1213
// Node is an abstract class representing a Node
1314

1415
static void Node_dealloc(Node *self) {
1516
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
1617
}
1718

19+
static struct PyMemberDef Node_PyMemberDef[] = {
20+
{"type_tag", T_INT, offsetof(Node, type_tag), 0, "Node type_tag"},
21+
{NULL},
22+
};
23+
1824

1925
static PyTypeObject NodeType = {
2026
/* tp_name */ PyVarObject_HEAD_INIT(NULL, 0) "Node",
@@ -44,7 +50,7 @@ static PyTypeObject NodeType = {
4450
/* tp_iter */ 0,
4551
/* tp_iternext */ 0,
4652
/* tp_methods */ 0,
47-
/* tp_members */ 0,
53+
/* tp_members */ Node_PyMemberDef,
4854
/* tp_getset */ 0,
4955
/* tp_base */ &PyBaseObject_Type,
5056
/* tp_dict */ 0,

pydatastructs/utils/_backend/cpp/TreeNode.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
typedef struct {
1111
PyObject_HEAD
12+
NodeType_ type_tag;
1213
PyObject* key;
1314
PyObject* data; // can store None or a number
1415
PyObject* left; // can store None or a number
@@ -29,6 +30,7 @@ static void TreeNode_dealloc(TreeNode *self) {
2930
static PyObject* TreeNode___new__(PyTypeObject* type, PyObject *args, PyObject *kwds) {
3031
TreeNode *self;
3132
self = reinterpret_cast<TreeNode*>(type->tp_alloc(type, 0));
33+
self->type_tag = NodeType_::TreeNode;
3234

3335
// Assume that arguments are in the order below. Python code is such that this is true.
3436
self->key = PyObject_GetItem(args, PyZero);
@@ -56,6 +58,7 @@ static PyObject* TreeNode___str__(TreeNode *self) {
5658
}
5759

5860
static struct PyMemberDef TreeNode_PyMemberDef[] = {
61+
{"type_tag", T_INT, offsetof(TreeNode, type_tag), 0, "TreeNode type_tag"},
5962
{"key", T_OBJECT, offsetof(TreeNode, key), 0, "TreeNode key"},
6063
{"data", T_OBJECT, offsetof(TreeNode, data), 0, "TreeNode data"},
6164
{"height", T_LONG, offsetof(TreeNode, height), 0, "TreeNode height"},

pydatastructs/utils/_backend/cpp/utils.hpp

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
#include <cstring>
77
#include <string>
88

9-
PyObject *PyZero = PyLong_FromLong(0);
10-
PyObject *PyOne = PyLong_FromLong(1);
11-
PyObject *PyTwo = PyLong_FromLong(2);
12-
PyObject *PyThree = PyLong_FromLong(3);
13-
const char* _encoding = "utf-8";
14-
const char* _invalid_char = "<invalid-character>";
9+
static PyObject *PyZero = PyLong_FromLong(0);
10+
static PyObject *PyOne = PyLong_FromLong(1);
11+
static PyObject *PyTwo = PyLong_FromLong(2);
12+
static PyObject *PyThree = PyLong_FromLong(3);
13+
static const char* _encoding = "utf-8";
14+
static const char* _invalid_char = "<invalid-character>";
1515

1616
static char* PyObject_AsString(PyObject* obj) {
1717
return PyBytes_AS_STRING(PyUnicode_AsEncodedString(obj, _encoding, _invalid_char));
@@ -107,4 +107,40 @@ static int _comp(PyObject* u, PyObject* v, PyObject* tcomp) {
107107
return result;
108108
}
109109

110+
enum class NodeType_ {
111+
InvalidType,
112+
Node,
113+
TreeNode,
114+
GraphNode,
115+
AdjacencyListGraphNode,
116+
AdjacencyMatrixGraphNode,
117+
GraphEdge
118+
};
119+
120+
static NodeType_ get_type_tag(PyObject *node_obj) {
121+
if (!PyObject_HasAttrString(node_obj, "type_tag")) {
122+
return NodeType_::InvalidType; // attribute missing
123+
}
124+
125+
PyObject *attr = PyObject_GetAttrString(node_obj, "type_tag");
126+
if (!attr) {
127+
return NodeType_::InvalidType; // getattr failed
128+
}
129+
130+
if (!PyLong_Check(attr)) {
131+
Py_DECREF(attr);
132+
return NodeType_::InvalidType; // not an int
133+
}
134+
135+
int tag = (int)PyLong_AsLong(attr);
136+
Py_DECREF(attr);
137+
138+
if (PyErr_Occurred()) {
139+
return NodeType_::InvalidType; // overflow or error in cast
140+
}
141+
142+
return static_cast<NodeType_>(tag);
143+
}
144+
145+
110146
#endif

0 commit comments

Comments
 (0)