Skip to content

Commit 0c2bece

Browse files
committed
Add basic code for binding partition revalidation
This adds the binding partition revalidation code from #54654. This is the last piece of that PR that hasn't been merged yet - however the TODO in that PR still stands for future work. This PR itself adds a callback that gets triggered by deleting a binding. It will then walk all code in the system and invalidate code instances of Methods whose lowered source referenced the given global. This walk is quite slow. Future work will add backedges and optimizations to make this faster, but the basic functionality should be in place with this PR.
1 parent ec2b509 commit 0c2bece

File tree

5 files changed

+148
-2
lines changed

5 files changed

+148
-2
lines changed

base/Base_compiler.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@ include("ordering.jl")
255255
using .Order
256256

257257
include("coreir.jl")
258+
include("invalidation.jl")
258259

259260
# For OS specific stuff
260261
# We need to strcat things here, before strings are really defined

base/invalidation.jl

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# This file is a part of Julia. License is MIT: https://julialang.org/license
2+
3+
struct GlobalRefIterator
4+
mod::Module
5+
end
6+
IteratorSize(::Type{GlobalRefIterator}) = SizeUnknown()
7+
globalrefs(mod::Module) = GlobalRefIterator(mod)
8+
9+
function iterate(gri::GlobalRefIterator, i = 1)
10+
m = gri.mod
11+
table = ccall(:jl_module_get_bindings, Ref{SimpleVector}, (Any,), m)
12+
i == length(table) && return nothing
13+
b = table[i]
14+
b === nothing && return iterate(gri, i+1)
15+
return ((b::Core.Binding).globalref, i+1)
16+
end
17+
18+
const TYPE_TYPE_MT = Type.body.name.mt
19+
const NONFUNCTION_MT = Core.MethodTable.name.mt
20+
function foreach_module_mtable(visit, m::Module, world::UInt)
21+
for gb in globalrefs(m)
22+
binding = gb.binding
23+
bpart = lookup_binding_partition(world, binding)
24+
if is_defined_const_binding(binding_kind(bpart))
25+
v = partition_restriction(bpart)
26+
uw = unwrap_unionall(v)
27+
name = gb.name
28+
if isa(uw, DataType)
29+
tn = uw.name
30+
if tn.module === m && tn.name === name && tn.wrapper === v && isdefined(tn, :mt)
31+
# this is the original/primary binding for the type (name/wrapper)
32+
mt = tn.mt
33+
if mt !== nothing && mt !== TYPE_TYPE_MT && mt !== NONFUNCTION_MT
34+
@assert mt.module === m
35+
visit(mt) || return false
36+
end
37+
end
38+
elseif isa(v, Module) && v !== m && parentmodule(v) === m && _nameof(v) === name
39+
# this is the original/primary binding for the submodule
40+
foreach_module_mtable(visit, v, world) || return false
41+
elseif isa(v, Core.MethodTable) && v.module === m && v.name === name
42+
# this is probably an external method table here, so let's
43+
# assume so as there is no way to precisely distinguish them
44+
visit(v) || return false
45+
end
46+
end
47+
end
48+
return true
49+
end
50+
51+
function foreach_reachable_mtable(visit, world::UInt)
52+
visit(TYPE_TYPE_MT) || return
53+
visit(NONFUNCTION_MT) || return
54+
for mod in loaded_modules_array()
55+
foreach_module_mtable(visit, mod, world)
56+
end
57+
end
58+
59+
function should_invalidate_code_for_globalref(gr::GlobalRef, src::CodeInfo)
60+
found_any = false
61+
labelchangemap = nothing
62+
stmts = src.code
63+
isgr(g::GlobalRef) = gr.mod == g.mod && gr.name === g.name
64+
isgr(g) = false
65+
for i = 1:length(stmts)
66+
stmt = stmts[i]
67+
if isgr(stmt)
68+
found_any = true
69+
continue
70+
end
71+
for ur in Compiler.userefs(stmt)
72+
arg = ur[]
73+
# If any of the GlobalRefs in this stmt match the one that
74+
# we are about, we need to move out all GlobalRefs to preserve
75+
# effect order, in case we later invalidate a different GR
76+
if isa(arg, GlobalRef)
77+
if isgr(arg)
78+
@assert !isa(stmt, PhiNode)
79+
found_any = true
80+
break
81+
end
82+
end
83+
end
84+
end
85+
return found_any
86+
end
87+
88+
function invalidate_code_for_globalref!(gr::GlobalRef, new_max_world::UInt)
89+
valid_in_valuepos = false
90+
foreach_reachable_mtable(new_max_world) do mt::Core.MethodTable
91+
for method in MethodList(mt)
92+
if isdefined(method, :source)
93+
src = _uncompressed_ir(method)
94+
old_stmts = src.code
95+
if should_invalidate_code_for_globalref(gr, src)
96+
for mi in specializations(method)
97+
ci = mi.cache
98+
while true
99+
if ci.max_world > new_max_world
100+
ccall(:jl_invalidate_code_instance, Cvoid, (Any, UInt), ci, new_max_world)
101+
end
102+
isdefined(ci, :next) || break
103+
ci = ci.next
104+
end
105+
end
106+
end
107+
end
108+
end
109+
return true
110+
end
111+
end

src/gf.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1870,6 +1870,11 @@ static void invalidate_code_instance(jl_code_instance_t *replaced, size_t max_wo
18701870
JL_UNLOCK(&replaced_mi->def.method->writelock);
18711871
}
18721872

1873+
JL_DLLEXPORT void jl_invalidate_code_instance(jl_code_instance_t *replaced, size_t max_world)
1874+
{
1875+
invalidate_code_instance(replaced, max_world, 1);
1876+
}
1877+
18731878
static void _invalidate_backedges(jl_method_instance_t *replaced_mi, size_t max_world, int depth) {
18741879
jl_array_t *backedges = replaced_mi->backedges;
18751880
if (backedges) {

src/module.c

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,6 +1032,21 @@ JL_DLLEXPORT void jl_set_const(jl_module_t *m JL_ROOTING_ARGUMENT, jl_sym_t *var
10321032
jl_gc_wb(bpart, val);
10331033
}
10341034

1035+
void jl_invalidate_binding_refs(jl_globalref_t *ref, size_t new_world)
1036+
{
1037+
static jl_value_t *invalidate_code_for_globalref = NULL;
1038+
if (invalidate_code_for_globalref == NULL && jl_base_module != NULL)
1039+
invalidate_code_for_globalref = jl_get_global(jl_base_module, jl_symbol("invalidate_code_for_globalref!"));
1040+
if (!invalidate_code_for_globalref)
1041+
jl_error("Binding invalidation is not permitted during bootstrap.");
1042+
if (jl_generating_output())
1043+
jl_error("Binding invalidation is not permitted during image generation.");
1044+
jl_value_t *boxed_world = jl_box_ulong(new_world);
1045+
JL_GC_PUSH1(&boxed_world);
1046+
jl_call2((jl_function_t*)invalidate_code_for_globalref, (jl_value_t*)ref, boxed_world);
1047+
JL_GC_POP();
1048+
}
1049+
10351050
extern jl_mutex_t world_counter_lock;
10361051
JL_DLLEXPORT void jl_disable_binding(jl_globalref_t *gr)
10371052
{
@@ -1046,9 +1061,11 @@ JL_DLLEXPORT void jl_disable_binding(jl_globalref_t *gr)
10461061

10471062
JL_LOCK(&world_counter_lock);
10481063
jl_task_t *ct = jl_current_task;
1064+
size_t last_world = ct->world_age;
10491065
size_t new_max_world = jl_atomic_load_acquire(&jl_world_counter);
1050-
// TODO: Trigger invalidation here
1051-
(void)ct;
1066+
ct->world_age = jl_typeinf_world;
1067+
jl_invalidate_binding_refs(gr, new_max_world);
1068+
ct->world_age = last_world;
10521069
jl_atomic_store_release(&bpart->max_world, new_max_world);
10531070
jl_atomic_store_release(&jl_world_counter, new_max_world + 1);
10541071
JL_UNLOCK(&world_counter_lock);
@@ -1334,6 +1351,11 @@ JL_DLLEXPORT void jl_add_to_module_init_list(jl_value_t *mod)
13341351
jl_array_ptr_1d_push(jl_module_init_order, mod);
13351352
}
13361353

1354+
JL_DLLEXPORT jl_svec_t *jl_module_get_bindings(jl_module_t *m)
1355+
{
1356+
return jl_atomic_load_relaxed(&m->bindings);
1357+
}
1358+
13371359
JL_DLLEXPORT void jl_init_restored_module(jl_value_t *mod)
13381360
{
13391361
if (!jl_generating_output() || jl_options.incremental) {

test/rebinding.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,11 @@ module Rebinding
3333
@test Base.@world(Foo, defined_world_age) == typeof(x)
3434
@test Base.@world(Rebinding.Foo, defined_world_age) == typeof(x)
3535
@test Base.@world((@__MODULE__).Foo, defined_world_age) == typeof(x)
36+
37+
# Test invalidation (const -> undefined)
38+
const delete_me = 1
39+
f_return_delete_me() = delete_me
40+
@test f_return_delete_me() == 1
41+
Base.delete_binding(@__MODULE__, :delete_me)
42+
@test_throws UndefVarError f_return_delete_me()
3643
end

0 commit comments

Comments
 (0)