Skip to content

Commit 726e037

Browse files
committed
Add basic code for binding partition revalidation
This adds the binding partition revalidation code from #54654. This is the last piece of that PR that hasn't been merged yet - however the TODO in that PR still stands for future work. This PR itself adds a callback that gets triggered by deleting a binding. It will then walk all code in the system and invalidate code instances of Methods whose lowered source referenced the given global. This walk is quite slow. Future work will add backedges and optimizations to make this faster, but the basic functionality should be in place with this PR.
1 parent e624440 commit 726e037

File tree

5 files changed

+206
-2
lines changed

5 files changed

+206
-2
lines changed

base/Base_compiler.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ include("ordering.jl")
253253
using .Order
254254

255255
include("coreir.jl")
256+
include("invalidation.jl")
256257

257258
# For OS specific stuff
258259
# We need to strcat things here, before strings are really defined

base/invalidation.jl

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
# This file is a part of Julia. License is MIT: https://julialang.org/license
2+
3+
struct GlobalRefIterator
4+
mod::Module
5+
end
6+
IteratorSize(::Type{GlobalRefIterator}) = SizeUnknown()
7+
globalrefs(mod::Module) = GlobalRefIterator(mod)
8+
9+
function iterate(gri::GlobalRefIterator, i = 1)
10+
m = gri.mod
11+
table = ccall(:jl_module_get_bindings, Ref{SimpleVector}, (Any,), m)
12+
i == length(table) && return nothing
13+
b = table[i]
14+
b === nothing && return iterate(gri, i+1)
15+
return ((b::Core.Binding).globalref, i+1)
16+
end
17+
18+
const TYPE_TYPE_MT = Type.body.name.mt
19+
const NONFUNCTION_MT = Core.MethodTable.name.mt
20+
function foreach_module_mtable(visit, m::Module, world::UInt)
21+
for gb in globalrefs(m)
22+
binding = gb.binding
23+
bpart = lookup_binding_partition(world, binding)
24+
if is_some_const_binding(binding_kind(bpart))
25+
isdefined(bpart, :restriction) || continue
26+
v = partition_restriction(bpart)
27+
uw = unwrap_unionall(v)
28+
name = gb.name
29+
if isa(uw, DataType)
30+
tn = uw.name
31+
if tn.module === m && tn.name === name && tn.wrapper === v && isdefined(tn, :mt)
32+
# this is the original/primary binding for the type (name/wrapper)
33+
mt = tn.mt
34+
if mt !== nothing && mt !== TYPE_TYPE_MT && mt !== NONFUNCTION_MT
35+
@assert mt.module === m
36+
visit(mt) || return false
37+
end
38+
end
39+
elseif isa(v, Module) && v !== m && parentmodule(v) === m && _nameof(v) === name
40+
# this is the original/primary binding for the submodule
41+
foreach_module_mtable(visit, v, world) || return false
42+
elseif isa(v, Core.MethodTable) && v.module === m && v.name === name
43+
# this is probably an external method table here, so let's
44+
# assume so as there is no way to precisely distinguish them
45+
visit(v) || return false
46+
end
47+
end
48+
end
49+
return true
50+
end
51+
52+
function foreach_reachable_mtable(visit, world::UInt)
53+
visit(TYPE_TYPE_MT) || return
54+
visit(NONFUNCTION_MT) || return
55+
for mod in loaded_modules_array()
56+
foreach_module_mtable(visit, mod, world)
57+
end
58+
end
59+
60+
function invalidate_code_for_globalref!(gr::GlobalRef, src::CodeInfo)
61+
found_any = false
62+
labelchangemap = nothing
63+
stmts = src.code
64+
function get_labelchangemap()
65+
if labelchangemap === nothing
66+
labelchangemap = fill(0, length(stmts))
67+
end
68+
labelchangemap
69+
end
70+
isgr(g::GlobalRef) = gr.mod == g.mod && gr.name === g.name
71+
isgr(g) = false
72+
for i = 1:length(stmts)
73+
stmt = stmts[i]
74+
if isgr(stmt)
75+
found_any = true
76+
continue
77+
end
78+
found_arg = false
79+
ngrs = 0
80+
for ur in Compiler.userefs(stmt)
81+
arg = ur[]
82+
# If any of the GlobalRefs in this stmt match the one that
83+
# we are about, we need to move out all GlobalRefs to preserve
84+
# effect order, in case we later invalidate a different GR
85+
if isa(arg, GlobalRef)
86+
ngrs += 1
87+
if isgr(arg)
88+
@assert !isa(stmt, PhiNode)
89+
found_arg = found_any = true
90+
break
91+
end
92+
end
93+
end
94+
if found_arg
95+
get_labelchangemap()[i] += ngrs
96+
end
97+
end
98+
next_empty_idx = 1
99+
if labelchangemap !== nothing
100+
Compiler.cumsum_ssamap!(labelchangemap)
101+
new_stmts = Vector(undef, length(stmts)+labelchangemap[end])
102+
new_ssaflags = Vector{UInt32}(undef, length(new_stmts))
103+
new_debuginfo = Compiler.DebugInfoStream(nothing, src.debuginfo, length(new_stmts))
104+
new_debuginfo.def = src.debuginfo.def
105+
for i = 1:length(stmts)
106+
stmt = stmts[i]
107+
urs = Compiler.userefs(stmt)
108+
new_stmt_idx = i+labelchangemap[i]
109+
for ur in urs
110+
arg = ur[]
111+
if isa(arg, SSAValue)
112+
ur[] = SSAValue(arg.id + labelchangemap[arg.id])
113+
elseif next_empty_idx != new_stmt_idx && isa(arg, GlobalRef)
114+
new_debuginfo.codelocs[3next_empty_idx - 2] = i
115+
new_stmts[next_empty_idx] = arg
116+
new_ssaflags[next_empty_idx] = UInt32(0)
117+
ur[] = SSAValue(next_empty_idx)
118+
next_empty_idx += 1
119+
end
120+
end
121+
@assert new_stmt_idx == next_empty_idx
122+
new_stmts[new_stmt_idx] = urs[]
123+
new_debuginfo.codelocs[3new_stmt_idx - 2] = i
124+
new_ssaflags[new_stmt_idx] = src.ssaflags[i]
125+
next_empty_idx = new_stmt_idx+1
126+
end
127+
src.code = new_stmts
128+
src.ssavaluetypes = length(new_stmts)
129+
src.ssaflags = new_ssaflags
130+
src.debuginfo = Core.DebugInfo(new_debuginfo, length(new_stmts))
131+
end
132+
return found_any
133+
end
134+
135+
function invalidate_code_for_globalref!(gr::GlobalRef, new_max_world::UInt)
136+
valid_in_valuepos = false
137+
foreach_reachable_mtable(new_max_world) do mt::Core.MethodTable
138+
for method in MethodList(mt)
139+
if isdefined(method, :source)
140+
src = _uncompressed_ir(method)
141+
old_stmts = src.code
142+
if invalidate_code_for_globalref!(gr, src)
143+
if src.code !== old_stmts
144+
method.debuginfo = src.debuginfo
145+
method.source = src
146+
method.source = ccall(:jl_compress_ir, Ref{String}, (Any, Ptr{Cvoid}), method, C_NULL)
147+
end
148+
149+
for mi in specializations(method)
150+
ci = mi.cache
151+
while true
152+
if ci.max_world > new_max_world
153+
ccall(:jl_invalidate_code_instance, Cvoid, (Any, UInt), ci, new_max_world)
154+
end
155+
isdefined(ci, :next) || break
156+
ci = ci.next
157+
end
158+
end
159+
end
160+
end
161+
end
162+
return true
163+
end
164+
end

src/gf.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1784,6 +1784,11 @@ static void invalidate_code_instance(jl_code_instance_t *replaced, size_t max_wo
17841784
JL_UNLOCK(&replaced->def->def.method->writelock);
17851785
}
17861786

1787+
JL_DLLEXPORT void jl_invalidate_code_instance(jl_code_instance_t *replaced, size_t max_world)
1788+
{
1789+
invalidate_code_instance(replaced, max_world, 1);
1790+
}
1791+
17871792
static void _invalidate_backedges(jl_method_instance_t *replaced_mi, size_t max_world, int depth) {
17881793
jl_array_t *backedges = replaced_mi->backedges;
17891794
if (backedges) {

src/module.c

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,6 +1025,21 @@ JL_DLLEXPORT void jl_set_const(jl_module_t *m JL_ROOTING_ARGUMENT, jl_sym_t *var
10251025
jl_gc_wb(bpart, val);
10261026
}
10271027

1028+
void jl_invalidate_binding_refs(jl_globalref_t *ref, size_t new_world)
1029+
{
1030+
static jl_value_t *invalidate_code_for_globalref = NULL;
1031+
if (invalidate_code_for_globalref == NULL && jl_base_module != NULL)
1032+
invalidate_code_for_globalref = jl_get_global(jl_base_module, jl_symbol("invalidate_code_for_globalref!"));
1033+
if (!invalidate_code_for_globalref)
1034+
jl_error("Binding invalidation is not permitted during bootstrap.");
1035+
if (jl_generating_output())
1036+
jl_error("Binding invalidation is not permitted during image generation.");
1037+
jl_value_t *boxed_world = jl_box_ulong(new_world);
1038+
JL_GC_PUSH1(&boxed_world);
1039+
jl_call2((jl_function_t*)invalidate_code_for_globalref, (jl_value_t*)ref, boxed_world);
1040+
JL_GC_POP();
1041+
}
1042+
10281043
extern jl_mutex_t world_counter_lock;
10291044
JL_DLLEXPORT void jl_disable_binding(jl_globalref_t *gr)
10301045
{
@@ -1039,9 +1054,16 @@ JL_DLLEXPORT void jl_disable_binding(jl_globalref_t *gr)
10391054

10401055
JL_LOCK(&world_counter_lock);
10411056
jl_task_t *ct = jl_current_task;
1057+
size_t last_world = ct->world_age;
10421058
size_t new_max_world = jl_atomic_load_acquire(&jl_world_counter);
1043-
// TODO: Trigger invalidation here
1044-
(void)ct;
1059+
JL_TRY {
1060+
ct->world_age = jl_typeinf_world;
1061+
jl_invalidate_binding_refs(gr, new_max_world);
1062+
} JL_CATCH {
1063+
JL_UNLOCK(&world_counter_lock);
1064+
jl_rethrow();
1065+
}
1066+
ct->world_age = last_world;
10451067
jl_atomic_store_release(&bpart->max_world, new_max_world);
10461068
jl_atomic_store_release(&jl_world_counter, new_max_world + 1);
10471069
JL_UNLOCK(&world_counter_lock);
@@ -1327,6 +1349,11 @@ JL_DLLEXPORT void jl_add_to_module_init_list(jl_value_t *mod)
13271349
jl_array_ptr_1d_push(jl_module_init_order, mod);
13281350
}
13291351

1352+
JL_DLLEXPORT jl_svec_t *jl_module_get_bindings(jl_module_t *m)
1353+
{
1354+
return jl_atomic_load_relaxed(&m->bindings);
1355+
}
1356+
13301357
JL_DLLEXPORT void jl_init_restored_module(jl_value_t *mod)
13311358
{
13321359
if (!jl_generating_output() || jl_options.incremental) {

test/rebinding.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,11 @@ module Rebinding
3333
@test Base.@world(Foo, defined_world_age) == typeof(x)
3434
@test Base.@world(Rebinding.Foo, defined_world_age) == typeof(x)
3535
@test Base.@world((@__MODULE__).Foo, defined_world_age) == typeof(x)
36+
37+
# Test invalidation (const -> undefined)
38+
const delete_me = 1
39+
f_return_delete_me() = delete_me
40+
@test f_return_delete_me() == 1
41+
Base.delete_binding(@__MODULE__, :delete_me)
42+
@test_throws UndefVarError f_return_delete_me()
3643
end

0 commit comments

Comments
 (0)