|
| 1 | +# This file is a part of Julia. License is MIT: https://julialang.org/license |
| 2 | + |
| 3 | +struct GlobalRefIterator |
| 4 | + mod::Module |
| 5 | +end |
| 6 | +IteratorSize(::Type{GlobalRefIterator}) = SizeUnknown() |
| 7 | +globalrefs(mod::Module) = GlobalRefIterator(mod) |
| 8 | + |
| 9 | +function iterate(gri::GlobalRefIterator, i = 1) |
| 10 | + m = gri.mod |
| 11 | + table = ccall(:jl_module_get_bindings, Ref{SimpleVector}, (Any,), m) |
| 12 | + i == length(table) && return nothing |
| 13 | + b = table[i] |
| 14 | + b === nothing && return iterate(gri, i+1) |
| 15 | + return ((b::Core.Binding).globalref, i+1) |
| 16 | +end |
| 17 | + |
| 18 | +const TYPE_TYPE_MT = Type.body.name.mt |
| 19 | +const NONFUNCTION_MT = Core.MethodTable.name.mt |
| 20 | +function foreach_module_mtable(visit, m::Module, world::UInt) |
| 21 | + for gb in globalrefs(m) |
| 22 | + binding = gb.binding |
| 23 | + bpart = lookup_binding_partition(world, binding) |
| 24 | + if is_some_const_binding(binding_kind(bpart)) |
| 25 | + isdefined(bpart, :restriction) || continue |
| 26 | + v = partition_restriction(bpart) |
| 27 | + uw = unwrap_unionall(v) |
| 28 | + name = gb.name |
| 29 | + if isa(uw, DataType) |
| 30 | + tn = uw.name |
| 31 | + if tn.module === m && tn.name === name && tn.wrapper === v && isdefined(tn, :mt) |
| 32 | + # this is the original/primary binding for the type (name/wrapper) |
| 33 | + mt = tn.mt |
| 34 | + if mt !== nothing && mt !== TYPE_TYPE_MT && mt !== NONFUNCTION_MT |
| 35 | + @assert mt.module === m |
| 36 | + visit(mt) || return false |
| 37 | + end |
| 38 | + end |
| 39 | + elseif isa(v, Module) && v !== m && parentmodule(v) === m && _nameof(v) === name |
| 40 | + # this is the original/primary binding for the submodule |
| 41 | + foreach_module_mtable(visit, v, world) || return false |
| 42 | + elseif isa(v, Core.MethodTable) && v.module === m && v.name === name |
| 43 | + # this is probably an external method table here, so let's |
| 44 | + # assume so as there is no way to precisely distinguish them |
| 45 | + visit(v) || return false |
| 46 | + end |
| 47 | + end |
| 48 | + end |
| 49 | + return true |
| 50 | +end |
| 51 | + |
| 52 | +function foreach_reachable_mtable(visit, world::UInt) |
| 53 | + visit(TYPE_TYPE_MT) || return |
| 54 | + visit(NONFUNCTION_MT) || return |
| 55 | + for mod in loaded_modules_array() |
| 56 | + foreach_module_mtable(visit, mod, world) |
| 57 | + end |
| 58 | +end |
| 59 | + |
| 60 | +function invalidate_code_for_globalref!(gr::GlobalRef, src::CodeInfo) |
| 61 | + found_any = false |
| 62 | + labelchangemap = nothing |
| 63 | + stmts = src.code |
| 64 | + function get_labelchangemap() |
| 65 | + if labelchangemap === nothing |
| 66 | + labelchangemap = fill(0, length(stmts)) |
| 67 | + end |
| 68 | + labelchangemap |
| 69 | + end |
| 70 | + isgr(g::GlobalRef) = gr.mod == g.mod && gr.name === g.name |
| 71 | + isgr(g) = false |
| 72 | + for i = 1:length(stmts) |
| 73 | + stmt = stmts[i] |
| 74 | + if isgr(stmt) |
| 75 | + found_any = true |
| 76 | + continue |
| 77 | + end |
| 78 | + found_arg = false |
| 79 | + ngrs = 0 |
| 80 | + for ur in Compiler.userefs(stmt) |
| 81 | + arg = ur[] |
| 82 | + # If any of the GlobalRefs in this stmt match the one that |
| 83 | + # we are about, we need to move out all GlobalRefs to preserve |
| 84 | + # effect order, in case we later invalidate a different GR |
| 85 | + if isa(arg, GlobalRef) |
| 86 | + ngrs += 1 |
| 87 | + if isgr(arg) |
| 88 | + @assert !isa(stmt, PhiNode) |
| 89 | + found_arg = found_any = true |
| 90 | + break |
| 91 | + end |
| 92 | + end |
| 93 | + end |
| 94 | + if found_arg |
| 95 | + get_labelchangemap()[i] += ngrs |
| 96 | + end |
| 97 | + end |
| 98 | + next_empty_idx = 1 |
| 99 | + if labelchangemap !== nothing |
| 100 | + Compiler.cumsum_ssamap!(labelchangemap) |
| 101 | + new_stmts = Vector(undef, length(stmts)+labelchangemap[end]) |
| 102 | + new_ssaflags = Vector{UInt32}(undef, length(new_stmts)) |
| 103 | + new_debuginfo = Compiler.DebugInfoStream(nothing, src.debuginfo, length(new_stmts)) |
| 104 | + new_debuginfo.def = src.debuginfo.def |
| 105 | + for i = 1:length(stmts) |
| 106 | + stmt = stmts[i] |
| 107 | + urs = Compiler.userefs(stmt) |
| 108 | + new_stmt_idx = i+labelchangemap[i] |
| 109 | + for ur in urs |
| 110 | + arg = ur[] |
| 111 | + if isa(arg, SSAValue) |
| 112 | + ur[] = SSAValue(arg.id + labelchangemap[arg.id]) |
| 113 | + elseif next_empty_idx != new_stmt_idx && isa(arg, GlobalRef) |
| 114 | + new_debuginfo.codelocs[3next_empty_idx - 2] = i |
| 115 | + new_stmts[next_empty_idx] = arg |
| 116 | + new_ssaflags[next_empty_idx] = UInt32(0) |
| 117 | + ur[] = SSAValue(next_empty_idx) |
| 118 | + next_empty_idx += 1 |
| 119 | + end |
| 120 | + end |
| 121 | + @assert new_stmt_idx == next_empty_idx |
| 122 | + new_stmts[new_stmt_idx] = urs[] |
| 123 | + new_debuginfo.codelocs[3new_stmt_idx - 2] = i |
| 124 | + new_ssaflags[new_stmt_idx] = src.ssaflags[i] |
| 125 | + next_empty_idx = new_stmt_idx+1 |
| 126 | + end |
| 127 | + src.code = new_stmts |
| 128 | + src.ssavaluetypes = length(new_stmts) |
| 129 | + src.ssaflags = new_ssaflags |
| 130 | + src.debuginfo = Core.DebugInfo(new_debuginfo, length(new_stmts)) |
| 131 | + end |
| 132 | + return found_any |
| 133 | +end |
| 134 | + |
| 135 | +function invalidate_code_for_globalref!(gr::GlobalRef, new_max_world::UInt) |
| 136 | + valid_in_valuepos = false |
| 137 | + foreach_reachable_mtable(new_max_world) do mt::Core.MethodTable |
| 138 | + for method in MethodList(mt) |
| 139 | + if isdefined(method, :source) |
| 140 | + src = _uncompressed_ir(method) |
| 141 | + old_stmts = src.code |
| 142 | + if invalidate_code_for_globalref!(gr, src) |
| 143 | + if src.code !== old_stmts |
| 144 | + method.debuginfo = src.debuginfo |
| 145 | + method.source = src |
| 146 | + method.source = ccall(:jl_compress_ir, Ref{String}, (Any, Ptr{Cvoid}), method, C_NULL) |
| 147 | + end |
| 148 | + |
| 149 | + for mi in specializations(method) |
| 150 | + ci = mi.cache |
| 151 | + while true |
| 152 | + if ci.max_world > new_max_world |
| 153 | + ccall(:jl_invalidate_code_instance, Cvoid, (Any, UInt), ci, new_max_world) |
| 154 | + end |
| 155 | + isdefined(ci, :next) || break |
| 156 | + ci = ci.next |
| 157 | + end |
| 158 | + end |
| 159 | + end |
| 160 | + end |
| 161 | + end |
| 162 | + return true |
| 163 | + end |
| 164 | +end |
0 commit comments