diff --git a/gccjit/gccjit.pyx b/gccjit/gccjit.pyx index 7371a28..ea5b485 100644 --- a/gccjit/gccjit.pyx +++ b/gccjit/gccjit.pyx @@ -17,6 +17,7 @@ # . from libc.stdlib cimport malloc, free +import weakref cimport gccjit as c_api class Error(Exception): @@ -25,6 +26,7 @@ class Error(Exception): cdef class Context: cdef c_api.gcc_jit_context* _c_ctxt + cdef object __weakref__ def __cinit__(self, acquire=True): if acquire: @@ -49,12 +51,12 @@ cdef class Context: def get_type(self, type_enum): """get_type(self, type_enum:TypeKind) -> Type""" - return Type_from_c(self._c_ctxt, + return Type_from_c(self, c_api.gcc_jit_context_get_type(self._c_ctxt, type_enum)) def get_int_type(self, num_bytes, is_signed): """get_int_type(self, num_bytes:int, is_signed:bool) -> Type""" - return Type_from_c(self._c_ctxt, + return Type_from_c(self, c_api.gcc_jit_context_get_int_type(self._c_ctxt, num_bytes, is_signed)) def compile(self): @@ -63,7 +65,7 @@ cdef class Context: c_result = c_api.gcc_jit_context_compile(self._c_ctxt) if c_result == NULL: raise Error(self.get_first_error()) - r = Result() + r = Result(self) r._set_c_ptr(c_result) return r @@ -102,7 +104,7 @@ cdef class Context: """new_location(self, filename:str, line:int, column:int) -> Location""" cdef c_api.gcc_jit_location *c_loc c_loc = c_api.gcc_jit_context_new_location(self._c_ctxt, filename, line, column) - loc = Location() + loc = Location(self) loc._set_c_location(c_loc) return loc @@ -113,7 +115,7 @@ cdef class Context: kind, type_._get_c_type(), name) - return LValue_from_c(self._c_ctxt, c_lvalue) + return LValue_from_c(self, c_lvalue) def new_array_type(self, Type element_type, int num_elements, Location loc=None): """new_array_type(self, element_type:Type, num_elements:int, loc:Location=None) -> Type""" @@ -121,7 +123,7 @@ cdef class Context: get_c_location(loc), element_type._get_c_type(), num_elements) - return Type_from_c(self._c_ctxt, + return Type_from_c(self, c_type) def new_field(self, Type type_, name, Location loc=None): @@ -130,7 +132,7 @@ cdef class Context: get_c_location(loc), type_._get_c_type(), name) - field = Field() + field = Field(self) field._set_c_field(c_field) return field @@ -163,7 +165,7 @@ cdef class Context: name, num_fields, c_fields) - py_struct = Struct() + py_struct = Struct(self) py_struct._set_c_struct(c_struct) free(c_fields) return py_struct @@ -192,7 +194,7 @@ cdef class Context: name, num_fields, c_fields) - py_type = Type() + py_type = Type(self) py_type._set_c_type(c_type) free(c_fields) return py_type @@ -222,7 +224,7 @@ cdef class Context: num_params, c_param_types, is_variadic) - py_type = Type() + py_type = Type(self) py_type._set_c_type(c_fn_ptr_type) free(c_param_types) return py_type @@ -233,7 +235,7 @@ cdef class Context: get_c_location(loc), type_._get_c_type(), name) - return Param_from_c(self._c_ctxt, c_result) + return Param_from_c(self, c_result) def new_function(self, kind, Type return_type, name, params, Location loc=None, @@ -258,56 +260,56 @@ cdef class Context: c_params, is_variadic) free(c_params) - return Function_from_c(self._c_ctxt, c_function) + return Function_from_c(self, c_function) def get_builtin_function(self, name): """get_builtin_function(self, name:str) -> Function""" c_function = c_api.gcc_jit_context_get_builtin_function (self._c_ctxt, name) - return Function_from_c(self._c_ctxt, c_function) + return Function_from_c(self, c_function) def zero(self, Type type_): """zero(self, type_:Type) -> RValue""" c_rvalue = c_api.gcc_jit_context_zero(self._c_ctxt, type_._get_c_type()) - return RValue_from_c(self._c_ctxt, c_rvalue) + return RValue_from_c(self, c_rvalue) def one(self, Type type_): """one(self, type_:Type) -> RValue""" c_rvalue = c_api.gcc_jit_context_one(self._c_ctxt, type_._get_c_type()) - return RValue_from_c(self._c_ctxt, c_rvalue) + return RValue_from_c(self, c_rvalue) def new_rvalue_from_double(self, Type numeric_type, double value): """new_rvalue_from_double(self, numeric_type:Type, value:float) -> RValue""" c_rvalue = c_api.gcc_jit_context_new_rvalue_from_double(self._c_ctxt, numeric_type._get_c_type(), value) - return RValue_from_c(self._c_ctxt, c_rvalue) + return RValue_from_c(self, c_rvalue) def new_rvalue_from_int(self, Type type_, int value): """new_rvalue_from_int(self, type_:Type, value:int) -> RValue""" c_rvalue = c_api.gcc_jit_context_new_rvalue_from_int(self._c_ctxt, type_._get_c_type(), value) - return RValue_from_c(self._c_ctxt, c_rvalue) + return RValue_from_c(self, c_rvalue) def new_rvalue_from_ptr(self, Type pointer_type, long long value): c_rvalue = c_api.gcc_jit_context_new_rvalue_from_ptr(self._c_ctxt, pointer_type._get_c_type(), value) - return RValue_from_c(self._c_ctxt, c_rvalue) + return RValue_from_c(self, c_rvalue) def null(self, Type pointer_type): """null(self, pointer_type:Type) -> RValue""" c_rvalue = c_api.gcc_jit_context_null(self._c_ctxt, pointer_type._get_c_type()) - return RValue_from_c(self._c_ctxt, c_rvalue) + return RValue_from_c(self, c_rvalue) def new_string_literal(self, char *value): """new_string_literal(self, value:str) -> RValue""" c_rvalue = c_api.gcc_jit_context_new_string_literal(self._c_ctxt, value) - return RValue_from_c(self._c_ctxt, c_rvalue) + return RValue_from_c(self, c_rvalue) def new_unary_op(self, op, Type result_type, RValue rvalue, Location loc=None): """new_unary_op(self, op:UnaryOp, result_type:Type, rvalue:RValue, loc:Location=None) -> RValue""" @@ -316,7 +318,7 @@ cdef class Context: op, result_type._get_c_type(), rvalue._get_c_rvalue()) - return RValue_from_c(self._c_ctxt, c_rvalue) + return RValue_from_c(self, c_rvalue) def new_binary_op(self, op, Type result_type, RValue a, RValue b, Location loc=None): """new_binary_op(self, op:BinaryOp, result_type:Type, a:RValue, b:RValue, loc:Location=None) -> RValue""" @@ -326,7 +328,7 @@ cdef class Context: result_type._get_c_type(), a._get_c_rvalue(), b._get_c_rvalue()) - return RValue_from_c(self._c_ctxt, c_rvalue) + return RValue_from_c(self, c_rvalue) def new_comparison(self, op, RValue a, RValue b, Location loc=None): """new_comparison(self, op:Comparison, a:RValue, b:RValue, loc:Location=None) -> RValue""" @@ -336,7 +338,7 @@ cdef class Context: a._get_c_rvalue(), b._get_c_rvalue()) - return RValue_from_c(self._c_ctxt, c_rvalue) + return RValue_from_c(self, c_rvalue) def new_child_context(self): """new_child_context(self) -> Context""" @@ -354,7 +356,7 @@ cdef class Context: get_c_location(loc), rvalue._get_c_rvalue(), type_._get_c_type()) - return RValue_from_c(self._c_ctxt, c_rvalue) + return RValue_from_c(self, c_rvalue) def new_array_access(self, RValue ptr, RValue index, Location loc=None): """new_array_access(self, ptr:RValue, index:RValue, loc:Location=None) -> LValue""" @@ -362,7 +364,7 @@ cdef class Context: get_c_location(loc), ptr._get_c_rvalue(), index._get_c_rvalue()) - return LValue_from_c(self._c_ctxt, c_lvalue) + return LValue_from_c(self, c_lvalue) def new_call(self, Function func, args, Location loc=None): """new_call(self, func:Function, args:list of RValue, loc:Location=None) -> RValue""" @@ -385,7 +387,7 @@ cdef class Context: c_args) free(c_args) - return RValue_from_c(self._c_ctxt, c_rvalue) + return RValue_from_c(self, c_rvalue) def new_call_through_ptr(self, RValue fn_ptr, args, Location loc=None): """new_call(self, fn_ptr:RValue, args:list of RValue, loc:Location=None) -> RValue""" @@ -408,7 +410,7 @@ cdef class Context: c_args) free(c_args) - return RValue_from_c(self._c_ctxt, c_rvalue) + return RValue_from_c(self, c_rvalue) cdef class Result: cdef c_api.gcc_jit_result* _c_result @@ -428,8 +430,13 @@ cdef class Result: cdef class Object: cdef c_api.gcc_jit_object *_c_object + cdef object _py_context - def __cinit__(self): + def __cinit__(self, py_context): + self._c_object = NULL + self._py_context = weakref.ref(py_context, lambda unused: self._invalidate_ptr()) + + cdef _invalidate_ptr(self): self._c_object = NULL def __str__(self): @@ -448,6 +455,12 @@ cdef class Object: cdef c_api.gcc_jit_context* _get_c_context(self): return c_api.gcc_jit_object_get_context(self._c_object) + cdef _get_py_context(self): + ctxt = self._py_context() + if ctxt is None: + raise Error(b"parent context was destroyed") + return ctxt + cdef class Type(Object): cdef c_api.gcc_jit_type* _get_c_type(self): return self._c_object @@ -457,24 +470,24 @@ cdef class Type(Object): def get_pointer(self): """get_pointer(self) -> Type""" - return Type_from_c(self._get_c_context(), + return Type_from_c(self._get_py_context(), c_api.gcc_jit_type_get_pointer(self._get_c_type())) def get_const(self): """get_const(self) -> Type""" - return Type_from_c(self._get_c_context(), + return Type_from_c(self._get_py_context(), c_api.gcc_jit_type_get_const(self._get_c_type())) def get_volatile(self): """get_volatile(self) -> Type""" - return Type_from_c(self._get_c_context(), + return Type_from_c(self._get_py_context(), c_api.gcc_jit_type_get_volatile(self._get_c_type())) -cdef Type_from_c(c_api.gcc_jit_context *c_ctxt, +cdef Type_from_c(Context ctxt, c_api.gcc_jit_type *c_type): if c_type == NULL: - raise Error(c_api.gcc_jit_context_get_last_error(c_ctxt)) - t = Type() + raise Error(ctxt.get_last_error()) + t = Type(ctxt) t._set_c_type(c_type) return t @@ -540,27 +553,27 @@ cdef class RValue(Object): def dereference_field(self, Field field, Location loc=None): """dereference_field(self, field:Field, loc:Location=None) -> LValue""" - return LValue_from_c(self._get_c_context(), + return LValue_from_c(self._get_py_context(), c_api.gcc_jit_rvalue_dereference_field (self._get_c_rvalue(), get_c_location(loc), field._get_c_field())) def dereference(self, loc=None): """dereference(self, loc:Location=None) -> LValue""" - return LValue_from_c(self._get_c_context(), + return LValue_from_c(self._get_py_context(), c_api.gcc_jit_rvalue_dereference (self._get_c_rvalue(), get_c_location(loc))) def get_type(self): - return Type_from_c(self._get_c_context(), + return Type_from_c(self._get_py_context(), c_api.gcc_jit_rvalue_get_type (self._get_c_rvalue())) -cdef RValue RValue_from_c(c_api.gcc_jit_context *c_ctxt, +cdef RValue RValue_from_c(Context ctxt, c_api.gcc_jit_rvalue *c_rvalue): if c_rvalue == NULL: - raise Error(c_api.gcc_jit_context_get_last_error(c_ctxt)) + raise Error(ctxt.get_last_error()) - py_rvalue = RValue() + py_rvalue = RValue(ctxt) py_rvalue._set_c_rvalue(c_rvalue) return py_rvalue @@ -574,16 +587,16 @@ cdef class LValue(RValue): def get_address(self, Location loc=None): """get_address(self, loc:Location=None) -> RValue""" - return RValue_from_c(self._get_c_context(), + return RValue_from_c(self._get_py_context(), c_api.gcc_jit_lvalue_get_address(self._get_c_lvalue(), get_c_location(loc))) -cdef LValue LValue_from_c(c_api.gcc_jit_context *c_ctxt, +cdef LValue LValue_from_c(Context ctxt, c_api.gcc_jit_lvalue *c_lvalue): if c_lvalue == NULL: - raise Error(c_api.gcc_jit_context_get_last_error(c_ctxt)) + raise Error(ctxt.get_last_error()) - py_lvalue = LValue() + py_lvalue = LValue(ctxt) py_lvalue._set_c_lvalue(c_lvalue) return py_lvalue @@ -595,12 +608,12 @@ cdef class Param(LValue): cdef _set_c_param(self, c_api.gcc_jit_param* c_param): self._c_object = c_param -cdef Param Param_from_c(c_api.gcc_jit_context *c_ctxt, +cdef Param Param_from_c(Context ctxt, c_api.gcc_jit_param *c_param): if c_param == NULL: - raise Error(c_api.gcc_jit_context_get_last_error(c_ctxt)) + raise Error(ctxt.get_last_error()) - p = Param() + p = Param(ctxt) p._set_c_param(c_param) return p @@ -618,7 +631,7 @@ cdef class Function(Object): get_c_location(loc), type_._get_c_type(), name) - return LValue_from_c(self._get_c_context(), + return LValue_from_c(self._get_py_context(), c_lvalue) def new_block(self, name=None): @@ -632,14 +645,14 @@ cdef class Function(Object): c_name) if c_block == NULL: raise Error(c_api.gcc_jit_context_get_last_error(self._get_c_context())) - block = Block() + block = Block(self._get_py_context()) block._set_c_block(c_block) return block def get_param(self, index): """get_param(self, index:int) -> Param""" c_param = c_api.gcc_jit_function_get_param (self._get_c_function(), index) - return Param_from_c(self._get_c_context(), + return Param_from_c(self._get_py_context(), c_param) def dump_to_dot(self, char *path): @@ -647,11 +660,11 @@ cdef class Function(Object): c_api.gcc_jit_function_dump_to_dot (self._get_c_function(), path) -cdef Function Function_from_c(c_api.gcc_jit_context *c_ctxt, +cdef Function Function_from_c(Context ctxt, c_api.gcc_jit_function *c_function): if c_function == NULL: - raise Error(c_api.gcc_jit_context_get_last_error(c_ctxt)) - f = Function() + raise Error(ctxt.get_last_error()) + f = Function(ctxt) f._set_c_function(c_function) return f @@ -721,7 +734,7 @@ cdef class Block(Object): def get_function(self): """get_function(self) -> Function""" c_function = c_api.gcc_jit_block_get_function (self._get_c_block()) - return Function_from_c(self._get_c_context(), + return Function_from_c(self._get_py_context(), c_function) diff --git a/tests/test.py b/tests/test.py index 8db2166..a89d944 100644 --- a/tests/test.py +++ b/tests/test.py @@ -19,6 +19,7 @@ import os import tempfile import unittest +import gc import gccjit @@ -218,5 +219,17 @@ def test_new_block_error(self): (b'gcc_jit_function_new_block:' b' cannot add block to an imported function')) + # Verify that objects become invalid after Context is released. + def test_object_invalidation(self): + ctxt = gccjit.Context() + int_type = ctxt.get_type(gccjit.TypeKind.INT) + self.assertEqual(str(int_type), 'int') + del ctxt + gc.collect() + self.assertEqual(str(int_type), 'NULL') + with self.assertRaises(gccjit.Error) as cm: + int_type.get_const() + self.assertEqual(cm.exception.msg, b'parent context was destroyed') + if __name__ == '__main__': unittest.main()