diff --git a/gccjit/gccjit.pyx b/gccjit/gccjit.pyx
index 7371a28..ea5b485 100644
--- a/gccjit/gccjit.pyx
+++ b/gccjit/gccjit.pyx
@@ -17,6 +17,7 @@
# .
from libc.stdlib cimport malloc, free
+import weakref
cimport gccjit as c_api
class Error(Exception):
@@ -25,6 +26,7 @@ class Error(Exception):
cdef class Context:
cdef c_api.gcc_jit_context* _c_ctxt
+ cdef object __weakref__
def __cinit__(self, acquire=True):
if acquire:
@@ -49,12 +51,12 @@ cdef class Context:
def get_type(self, type_enum):
"""get_type(self, type_enum:TypeKind) -> Type"""
- return Type_from_c(self._c_ctxt,
+ return Type_from_c(self,
c_api.gcc_jit_context_get_type(self._c_ctxt, type_enum))
def get_int_type(self, num_bytes, is_signed):
"""get_int_type(self, num_bytes:int, is_signed:bool) -> Type"""
- return Type_from_c(self._c_ctxt,
+ return Type_from_c(self,
c_api.gcc_jit_context_get_int_type(self._c_ctxt, num_bytes, is_signed))
def compile(self):
@@ -63,7 +65,7 @@ cdef class Context:
c_result = c_api.gcc_jit_context_compile(self._c_ctxt)
if c_result == NULL:
raise Error(self.get_first_error())
- r = Result()
+ r = Result(self)
r._set_c_ptr(c_result)
return r
@@ -102,7 +104,7 @@ cdef class Context:
"""new_location(self, filename:str, line:int, column:int) -> Location"""
cdef c_api.gcc_jit_location *c_loc
c_loc = c_api.gcc_jit_context_new_location(self._c_ctxt, filename, line, column)
- loc = Location()
+ loc = Location(self)
loc._set_c_location(c_loc)
return loc
@@ -113,7 +115,7 @@ cdef class Context:
kind,
type_._get_c_type(),
name)
- return LValue_from_c(self._c_ctxt, c_lvalue)
+ return LValue_from_c(self, c_lvalue)
def new_array_type(self, Type element_type, int num_elements, Location loc=None):
"""new_array_type(self, element_type:Type, num_elements:int, loc:Location=None) -> Type"""
@@ -121,7 +123,7 @@ cdef class Context:
get_c_location(loc),
element_type._get_c_type(),
num_elements)
- return Type_from_c(self._c_ctxt,
+ return Type_from_c(self,
c_type)
def new_field(self, Type type_, name, Location loc=None):
@@ -130,7 +132,7 @@ cdef class Context:
get_c_location(loc),
type_._get_c_type(),
name)
- field = Field()
+ field = Field(self)
field._set_c_field(c_field)
return field
@@ -163,7 +165,7 @@ cdef class Context:
name,
num_fields,
c_fields)
- py_struct = Struct()
+ py_struct = Struct(self)
py_struct._set_c_struct(c_struct)
free(c_fields)
return py_struct
@@ -192,7 +194,7 @@ cdef class Context:
name,
num_fields,
c_fields)
- py_type = Type()
+ py_type = Type(self)
py_type._set_c_type(c_type)
free(c_fields)
return py_type
@@ -222,7 +224,7 @@ cdef class Context:
num_params,
c_param_types,
is_variadic)
- py_type = Type()
+ py_type = Type(self)
py_type._set_c_type(c_fn_ptr_type)
free(c_param_types)
return py_type
@@ -233,7 +235,7 @@ cdef class Context:
get_c_location(loc),
type_._get_c_type(),
name)
- return Param_from_c(self._c_ctxt, c_result)
+ return Param_from_c(self, c_result)
def new_function(self, kind, Type return_type, name, params,
Location loc=None,
@@ -258,56 +260,56 @@ cdef class Context:
c_params,
is_variadic)
free(c_params)
- return Function_from_c(self._c_ctxt, c_function)
+ return Function_from_c(self, c_function)
def get_builtin_function(self, name):
"""get_builtin_function(self, name:str) -> Function"""
c_function = c_api.gcc_jit_context_get_builtin_function (self._c_ctxt, name)
- return Function_from_c(self._c_ctxt, c_function)
+ return Function_from_c(self, c_function)
def zero(self, Type type_):
"""zero(self, type_:Type) -> RValue"""
c_rvalue = c_api.gcc_jit_context_zero(self._c_ctxt,
type_._get_c_type())
- return RValue_from_c(self._c_ctxt, c_rvalue)
+ return RValue_from_c(self, c_rvalue)
def one(self, Type type_):
"""one(self, type_:Type) -> RValue"""
c_rvalue = c_api.gcc_jit_context_one(self._c_ctxt,
type_._get_c_type())
- return RValue_from_c(self._c_ctxt, c_rvalue)
+ return RValue_from_c(self, c_rvalue)
def new_rvalue_from_double(self, Type numeric_type, double value):
"""new_rvalue_from_double(self, numeric_type:Type, value:float) -> RValue"""
c_rvalue = c_api.gcc_jit_context_new_rvalue_from_double(self._c_ctxt,
numeric_type._get_c_type(),
value)
- return RValue_from_c(self._c_ctxt, c_rvalue)
+ return RValue_from_c(self, c_rvalue)
def new_rvalue_from_int(self, Type type_, int value):
"""new_rvalue_from_int(self, type_:Type, value:int) -> RValue"""
c_rvalue = c_api.gcc_jit_context_new_rvalue_from_int(self._c_ctxt,
type_._get_c_type(),
value)
- return RValue_from_c(self._c_ctxt, c_rvalue)
+ return RValue_from_c(self, c_rvalue)
def new_rvalue_from_ptr(self, Type pointer_type, long long value):
c_rvalue = c_api.gcc_jit_context_new_rvalue_from_ptr(self._c_ctxt,
pointer_type._get_c_type(),
value)
- return RValue_from_c(self._c_ctxt, c_rvalue)
+ return RValue_from_c(self, c_rvalue)
def null(self, Type pointer_type):
"""null(self, pointer_type:Type) -> RValue"""
c_rvalue = c_api.gcc_jit_context_null(self._c_ctxt,
pointer_type._get_c_type())
- return RValue_from_c(self._c_ctxt, c_rvalue)
+ return RValue_from_c(self, c_rvalue)
def new_string_literal(self, char *value):
"""new_string_literal(self, value:str) -> RValue"""
c_rvalue = c_api.gcc_jit_context_new_string_literal(self._c_ctxt,
value)
- return RValue_from_c(self._c_ctxt, c_rvalue)
+ return RValue_from_c(self, c_rvalue)
def new_unary_op(self, op, Type result_type, RValue rvalue, Location loc=None):
"""new_unary_op(self, op:UnaryOp, result_type:Type, rvalue:RValue, loc:Location=None) -> RValue"""
@@ -316,7 +318,7 @@ cdef class Context:
op,
result_type._get_c_type(),
rvalue._get_c_rvalue())
- return RValue_from_c(self._c_ctxt, c_rvalue)
+ return RValue_from_c(self, c_rvalue)
def new_binary_op(self, op, Type result_type, RValue a, RValue b, Location loc=None):
"""new_binary_op(self, op:BinaryOp, result_type:Type, a:RValue, b:RValue, loc:Location=None) -> RValue"""
@@ -326,7 +328,7 @@ cdef class Context:
result_type._get_c_type(),
a._get_c_rvalue(),
b._get_c_rvalue())
- return RValue_from_c(self._c_ctxt, c_rvalue)
+ return RValue_from_c(self, c_rvalue)
def new_comparison(self, op, RValue a, RValue b, Location loc=None):
"""new_comparison(self, op:Comparison, a:RValue, b:RValue, loc:Location=None) -> RValue"""
@@ -336,7 +338,7 @@ cdef class Context:
a._get_c_rvalue(),
b._get_c_rvalue())
- return RValue_from_c(self._c_ctxt, c_rvalue)
+ return RValue_from_c(self, c_rvalue)
def new_child_context(self):
"""new_child_context(self) -> Context"""
@@ -354,7 +356,7 @@ cdef class Context:
get_c_location(loc),
rvalue._get_c_rvalue(),
type_._get_c_type())
- return RValue_from_c(self._c_ctxt, c_rvalue)
+ return RValue_from_c(self, c_rvalue)
def new_array_access(self, RValue ptr, RValue index, Location loc=None):
"""new_array_access(self, ptr:RValue, index:RValue, loc:Location=None) -> LValue"""
@@ -362,7 +364,7 @@ cdef class Context:
get_c_location(loc),
ptr._get_c_rvalue(),
index._get_c_rvalue())
- return LValue_from_c(self._c_ctxt, c_lvalue)
+ return LValue_from_c(self, c_lvalue)
def new_call(self, Function func, args, Location loc=None):
"""new_call(self, func:Function, args:list of RValue, loc:Location=None) -> RValue"""
@@ -385,7 +387,7 @@ cdef class Context:
c_args)
free(c_args)
- return RValue_from_c(self._c_ctxt, c_rvalue)
+ return RValue_from_c(self, c_rvalue)
def new_call_through_ptr(self, RValue fn_ptr, args, Location loc=None):
"""new_call(self, fn_ptr:RValue, args:list of RValue, loc:Location=None) -> RValue"""
@@ -408,7 +410,7 @@ cdef class Context:
c_args)
free(c_args)
- return RValue_from_c(self._c_ctxt, c_rvalue)
+ return RValue_from_c(self, c_rvalue)
cdef class Result:
cdef c_api.gcc_jit_result* _c_result
@@ -428,8 +430,13 @@ cdef class Result:
cdef class Object:
cdef c_api.gcc_jit_object *_c_object
+ cdef object _py_context
- def __cinit__(self):
+ def __cinit__(self, py_context):
+ self._c_object = NULL
+ self._py_context = weakref.ref(py_context, lambda unused: self._invalidate_ptr())
+
+ cdef _invalidate_ptr(self):
self._c_object = NULL
def __str__(self):
@@ -448,6 +455,12 @@ cdef class Object:
cdef c_api.gcc_jit_context* _get_c_context(self):
return c_api.gcc_jit_object_get_context(self._c_object)
+ cdef _get_py_context(self):
+ ctxt = self._py_context()
+ if ctxt is None:
+ raise Error(b"parent context was destroyed")
+ return ctxt
+
cdef class Type(Object):
cdef c_api.gcc_jit_type* _get_c_type(self):
return self._c_object
@@ -457,24 +470,24 @@ cdef class Type(Object):
def get_pointer(self):
"""get_pointer(self) -> Type"""
- return Type_from_c(self._get_c_context(),
+ return Type_from_c(self._get_py_context(),
c_api.gcc_jit_type_get_pointer(self._get_c_type()))
def get_const(self):
"""get_const(self) -> Type"""
- return Type_from_c(self._get_c_context(),
+ return Type_from_c(self._get_py_context(),
c_api.gcc_jit_type_get_const(self._get_c_type()))
def get_volatile(self):
"""get_volatile(self) -> Type"""
- return Type_from_c(self._get_c_context(),
+ return Type_from_c(self._get_py_context(),
c_api.gcc_jit_type_get_volatile(self._get_c_type()))
-cdef Type_from_c(c_api.gcc_jit_context *c_ctxt,
+cdef Type_from_c(Context ctxt,
c_api.gcc_jit_type *c_type):
if c_type == NULL:
- raise Error(c_api.gcc_jit_context_get_last_error(c_ctxt))
- t = Type()
+ raise Error(ctxt.get_last_error())
+ t = Type(ctxt)
t._set_c_type(c_type)
return t
@@ -540,27 +553,27 @@ cdef class RValue(Object):
def dereference_field(self, Field field, Location loc=None):
"""dereference_field(self, field:Field, loc:Location=None) -> LValue"""
- return LValue_from_c(self._get_c_context(),
+ return LValue_from_c(self._get_py_context(),
c_api.gcc_jit_rvalue_dereference_field (self._get_c_rvalue(),
get_c_location(loc),
field._get_c_field()))
def dereference(self, loc=None):
"""dereference(self, loc:Location=None) -> LValue"""
- return LValue_from_c(self._get_c_context(),
+ return LValue_from_c(self._get_py_context(),
c_api.gcc_jit_rvalue_dereference (self._get_c_rvalue(),
get_c_location(loc)))
def get_type(self):
- return Type_from_c(self._get_c_context(),
+ return Type_from_c(self._get_py_context(),
c_api.gcc_jit_rvalue_get_type (self._get_c_rvalue()))
-cdef RValue RValue_from_c(c_api.gcc_jit_context *c_ctxt,
+cdef RValue RValue_from_c(Context ctxt,
c_api.gcc_jit_rvalue *c_rvalue):
if c_rvalue == NULL:
- raise Error(c_api.gcc_jit_context_get_last_error(c_ctxt))
+ raise Error(ctxt.get_last_error())
- py_rvalue = RValue()
+ py_rvalue = RValue(ctxt)
py_rvalue._set_c_rvalue(c_rvalue)
return py_rvalue
@@ -574,16 +587,16 @@ cdef class LValue(RValue):
def get_address(self, Location loc=None):
"""get_address(self, loc:Location=None) -> RValue"""
- return RValue_from_c(self._get_c_context(),
+ return RValue_from_c(self._get_py_context(),
c_api.gcc_jit_lvalue_get_address(self._get_c_lvalue(),
get_c_location(loc)))
-cdef LValue LValue_from_c(c_api.gcc_jit_context *c_ctxt,
+cdef LValue LValue_from_c(Context ctxt,
c_api.gcc_jit_lvalue *c_lvalue):
if c_lvalue == NULL:
- raise Error(c_api.gcc_jit_context_get_last_error(c_ctxt))
+ raise Error(ctxt.get_last_error())
- py_lvalue = LValue()
+ py_lvalue = LValue(ctxt)
py_lvalue._set_c_lvalue(c_lvalue)
return py_lvalue
@@ -595,12 +608,12 @@ cdef class Param(LValue):
cdef _set_c_param(self, c_api.gcc_jit_param* c_param):
self._c_object = c_param
-cdef Param Param_from_c(c_api.gcc_jit_context *c_ctxt,
+cdef Param Param_from_c(Context ctxt,
c_api.gcc_jit_param *c_param):
if c_param == NULL:
- raise Error(c_api.gcc_jit_context_get_last_error(c_ctxt))
+ raise Error(ctxt.get_last_error())
- p = Param()
+ p = Param(ctxt)
p._set_c_param(c_param)
return p
@@ -618,7 +631,7 @@ cdef class Function(Object):
get_c_location(loc),
type_._get_c_type(),
name)
- return LValue_from_c(self._get_c_context(),
+ return LValue_from_c(self._get_py_context(),
c_lvalue)
def new_block(self, name=None):
@@ -632,14 +645,14 @@ cdef class Function(Object):
c_name)
if c_block == NULL:
raise Error(c_api.gcc_jit_context_get_last_error(self._get_c_context()))
- block = Block()
+ block = Block(self._get_py_context())
block._set_c_block(c_block)
return block
def get_param(self, index):
"""get_param(self, index:int) -> Param"""
c_param = c_api.gcc_jit_function_get_param (self._get_c_function(), index)
- return Param_from_c(self._get_c_context(),
+ return Param_from_c(self._get_py_context(),
c_param)
def dump_to_dot(self, char *path):
@@ -647,11 +660,11 @@ cdef class Function(Object):
c_api.gcc_jit_function_dump_to_dot (self._get_c_function(),
path)
-cdef Function Function_from_c(c_api.gcc_jit_context *c_ctxt,
+cdef Function Function_from_c(Context ctxt,
c_api.gcc_jit_function *c_function):
if c_function == NULL:
- raise Error(c_api.gcc_jit_context_get_last_error(c_ctxt))
- f = Function()
+ raise Error(ctxt.get_last_error())
+ f = Function(ctxt)
f._set_c_function(c_function)
return f
@@ -721,7 +734,7 @@ cdef class Block(Object):
def get_function(self):
"""get_function(self) -> Function"""
c_function = c_api.gcc_jit_block_get_function (self._get_c_block())
- return Function_from_c(self._get_c_context(),
+ return Function_from_c(self._get_py_context(),
c_function)
diff --git a/tests/test.py b/tests/test.py
index 8db2166..a89d944 100644
--- a/tests/test.py
+++ b/tests/test.py
@@ -19,6 +19,7 @@
import os
import tempfile
import unittest
+import gc
import gccjit
@@ -218,5 +219,17 @@ def test_new_block_error(self):
(b'gcc_jit_function_new_block:'
b' cannot add block to an imported function'))
+ # Verify that objects become invalid after Context is released.
+ def test_object_invalidation(self):
+ ctxt = gccjit.Context()
+ int_type = ctxt.get_type(gccjit.TypeKind.INT)
+ self.assertEqual(str(int_type), 'int')
+ del ctxt
+ gc.collect()
+ self.assertEqual(str(int_type), 'NULL')
+ with self.assertRaises(gccjit.Error) as cm:
+ int_type.get_const()
+ self.assertEqual(cm.exception.msg, b'parent context was destroyed')
+
if __name__ == '__main__':
unittest.main()