Skip to content

Commit 98988d8

Browse files
authored
REPLCompletions: replace get_type by the proper inference (#49206)
This PR generalizes the idea from #49199 and uses inference to analyze the types of REPL expressions. This approach offers several advantages over the current `get_[value|type]`-based implementation: - The need for various special cases is eliminated, as lowering normalizes expressions, and inference handles all language features. - Constant propagation allows us to obtain accurate completions for complex expressions safely (see #36437). Analysis on arbitrary REPL expressions can be done by the following steps: - Lower a given expression - Form a top-level `MethodInstance` from the lowered expression - Run inference on the top-level `MethodInstance` This PR implements `REPLInterpreter`, a custom `AbstractInterpreter` that: - aggressively resolves global bindings to enable reasonable completions for lines like `Mod.a.|` (where `|` is the cursor position) - aggressively concrete evaluates `:inconsistent` calls to provide reasonable completions for cases like `Ref(Some(42))[].|` - does not optimize the inferred code, as `REPLInterpreter` is only used to obtain the type or constant information of given expressions Aggressive binding resolution presents challenges for `REPLInterpreter`'s cache validation (since #40399 hasn't been resolved yet). To avoid cache validation issues, `REPLInterpreter` only allows aggressive binding resolution for the top-level frame representing REPL input code (`repl_frame`) and for child `getproperty` frames that are constant propagated from the `repl_frame`. This works, since 1.) these frames are never cached, and 2.) their results are only observed by the non-cached `repl_frame`. `REPLInterpreter` also aggressively concrete evaluates `:inconsistent` calls within `repl_frame`, allowing it to get accurate type information about complex expressions that otherwise cannot be constant folded, in a safe way, i.e. it still doesn't evaluate effectful expressions like `pop!(xs)`. 
Similarly to the aggressive binding resolution, aggressive concrete evaluation doesn't present any cache validation issues because `repl_frame` is never cached. Also note that the code cache for `REPLInterpreter` is separated from the native code cache, ensuring that code caches produced by `REPLInterpreter`, where bindings are aggressively resolved and the code is not really optimized, do not affect the native code execution. A hack has also been added to avoid serializing `CodeInstance`s produced by `REPLInterpreter` during precompilation to work around #48453. closes #36437 replaces #49199
1 parent 46813d3 commit 98988d8

File tree

2 files changed

+243
-139
lines changed

2 files changed

+243
-139
lines changed

stdlib/REPL/src/REPLCompletions.jl

Lines changed: 174 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ module REPLCompletions
44

55
export completions, shell_completions, bslash_completions, completion_text
66

7+
using Core: CodeInfo, MethodInstance, CodeInstance, Const
8+
const CC = Core.Compiler
79
using Base.Meta
810
using Base: propertynames, something
911

@@ -151,21 +153,21 @@ function complete_symbol(sym::String, @nospecialize(ffunc), context_module::Modu
151153

152154
ex = Meta.parse(lookup_name, raise=false, depwarn=false)
153155

154-
b, found = get_value(ex, context_module)
155-
if found
156-
val = b
157-
if isa(b, Module)
158-
mod = b
156+
res = repl_eval_ex(ex, context_module)
157+
res === nothing && return Completion[]
158+
if res isa Const
159+
val = res.val
160+
if isa(val, Module)
161+
mod = val
159162
lookup_module = true
160163
else
161164
lookup_module = false
162-
t = typeof(b)
165+
t = typeof(val)
163166
end
164-
else # If the value is not found using get_value, the expression contain an advanced expression
167+
else
165168
lookup_module = false
166-
t, found = get_type(ex, context_module)
169+
t = CC.widenconst(res)
167170
end
168-
found || return Completion[]
169171
end
170172

171173
suggestions = Completion[]
@@ -404,133 +406,182 @@ function find_start_brace(s::AbstractString; c_start='(', c_end=')')
404406
return (startind:lastindex(s), method_name_end)
405407
end
406408

407-
# Returns the value in a expression if sym is defined in current namespace fn.
408-
# This method is used to iterate to the value of a expression like:
409-
# :(REPL.REPLCompletions.whitespace_chars) a `dump` of this expression
410-
# will show it consist of Expr, QuoteNode's and Symbol's which all needs to
411-
# be handled differently to iterate down to get the value of whitespace_chars.
412-
function get_value(sym::Expr, fn)
413-
if sym.head === :quote || sym.head === :inert
414-
return sym.args[1], true
415-
end
416-
sym.head !== :. && return (nothing, false)
417-
for ex in sym.args
418-
ex, found = get_value(ex, fn)::Tuple{Any, Bool}
419-
!found && return (nothing, false)
420-
fn, found = get_value(ex, fn)::Tuple{Any, Bool}
421-
!found && return (nothing, false)
422-
end
423-
return (fn, true)
409+
struct REPLInterpreterCache
410+
dict::IdDict{MethodInstance,CodeInstance}
424411
end
425-
get_value(sym::Symbol, fn) = isdefined(fn, sym) ? (getfield(fn, sym), true) : (nothing, false)
426-
get_value(sym::QuoteNode, fn) = (sym.value, true)
427-
get_value(sym::GlobalRef, fn) = get_value(sym.name, sym.mod)
428-
get_value(sym, fn) = (sym, true)
429-
430-
# Return the type of a getfield call expression
431-
function get_type_getfield(ex::Expr, fn::Module)
432-
length(ex.args) == 3 || return Any, false # should never happen, but just for safety
433-
fld, found = get_value(ex.args[3], fn)
434-
fld isa Symbol || return Any, false
435-
obj = ex.args[2]
436-
objt, found = get_type(obj, fn)
437-
found || return Any, false
438-
objt isa DataType || return Any, false
439-
hasfield(objt, fld) || return Any, false
440-
return fieldtype(objt, fld), true
412+
REPLInterpreterCache() = REPLInterpreterCache(IdDict{MethodInstance,CodeInstance}())
413+
const REPL_INTERPRETER_CACHE = REPLInterpreterCache()
414+
415+
function get_code_cache()
416+
# XXX Avoid storing analysis results into the cache that persists across precompilation,
417+
# as [sys|pkg]image currently doesn't support serializing externally created `CodeInstance`.
418+
# Otherwise, `CodeInstance`s created by `REPLInterpreter`, that are much less optimized
419+
# than those produced by `NativeInterpreter`, will leak into the native code cache,
420+
# potentially causing runtime slowdown.
421+
# (see https://github.com/JuliaLang/julia/issues/48453).
422+
if (@ccall jl_generating_output()::Cint) == 1
423+
return REPLInterpreterCache()
424+
else
425+
return REPL_INTERPRETER_CACHE
426+
end
441427
end
442428

443-
# Determines the return type with the Compiler of a function call using the type information of the arguments.
444-
function get_type_call(expr::Expr, fn::Module)
445-
f_name = expr.args[1]
446-
f, found = get_type(f_name, fn)
447-
found || return (Any, false) # If the function f is not found return Any.
448-
args = Any[]
449-
for i in 2:length(expr.args) # Find the type of the function arguments
450-
typ, found = get_type(expr.args[i], fn)
451-
found ? push!(args, typ) : push!(args, Any)
429+
struct REPLInterpreter <: CC.AbstractInterpreter
430+
repl_frame::CC.InferenceResult
431+
world::UInt
432+
inf_params::CC.InferenceParams
433+
opt_params::CC.OptimizationParams
434+
inf_cache::Vector{CC.InferenceResult}
435+
code_cache::REPLInterpreterCache
436+
function REPLInterpreter(repl_frame::CC.InferenceResult;
437+
world::UInt = Base.get_world_counter(),
438+
inf_params::CC.InferenceParams = CC.InferenceParams(),
439+
opt_params::CC.OptimizationParams = CC.OptimizationParams(),
440+
inf_cache::Vector{CC.InferenceResult} = CC.InferenceResult[],
441+
code_cache::REPLInterpreterCache = get_code_cache())
442+
return new(repl_frame, world, inf_params, opt_params, inf_cache, code_cache)
452443
end
453-
world = Base.get_world_counter()
454-
return_type = Core.Compiler.return_type(Tuple{f, args...}, world)
455-
return (return_type, true)
456444
end
457-
458-
# Returns the return type. example: get_type(:(Base.strip("", ' ')), Main) returns (SubString{String}, true)
459-
function try_get_type(sym::Expr, fn::Module)
460-
val, found = get_value(sym, fn)
461-
found && return Core.Typeof(val), found
462-
if sym.head === :call
463-
# getfield call is special cased as the evaluation of getfield provides good type information,
464-
# is inexpensive and it is also performed in the complete_symbol function.
465-
a1 = sym.args[1]
466-
if a1 === :getfield || a1 === GlobalRef(Core, :getfield)
467-
return get_type_getfield(sym, fn)
445+
CC.InferenceParams(interp::REPLInterpreter) = interp.inf_params
446+
CC.OptimizationParams(interp::REPLInterpreter) = interp.opt_params
447+
CC.get_world_counter(interp::REPLInterpreter) = interp.world
448+
CC.get_inference_cache(interp::REPLInterpreter) = interp.inf_cache
449+
CC.code_cache(interp::REPLInterpreter) = CC.WorldView(interp.code_cache, CC.WorldRange(interp.world))
450+
CC.get(wvc::CC.WorldView{REPLInterpreterCache}, mi::MethodInstance, default) = get(wvc.cache.dict, mi, default)
451+
CC.getindex(wvc::CC.WorldView{REPLInterpreterCache}, mi::MethodInstance) = getindex(wvc.cache.dict, mi)
452+
CC.haskey(wvc::CC.WorldView{REPLInterpreterCache}, mi::MethodInstance) = haskey(wvc.cache.dict, mi)
453+
CC.setindex!(wvc::CC.WorldView{REPLInterpreterCache}, ci::CodeInstance, mi::MethodInstance) = setindex!(wvc.cache.dict, ci, mi)
454+
455+
# REPLInterpreter is only used for type analysis, so it should disable optimization entirely
456+
CC.may_optimize(::REPLInterpreter) = false
457+
458+
# REPLInterpreter analyzes a top-level frame, so better to not bail out from it
459+
CC.bail_out_toplevel_call(::REPLInterpreter, ::CC.InferenceLoopState, ::CC.InferenceState) = false
460+
461+
# `REPLInterpreter` aggressively resolves global bindings to enable reasonable completions
462+
# for lines like `Mod.a.|` (where `|` is the cursor position).
463+
# Aggressive binding resolution poses challenges for the inference cache validation
464+
# (until https://github.com/JuliaLang/julia/issues/40399 is implemented).
465+
# To avoid the cache validation issues, `REPLInterpreter` only allows aggressive binding
466+
# resolution for the top-level frame representing REPL input code (`repl_frame`) and for child
467+
# `getproperty` frames that are constant propagated from the `repl_frame`. This works, since
468+
# a.) these frames are never cached, and
469+
# b.) their results are only observed by the non-cached `repl_frame`.
470+
#
471+
# `REPLInterpreter` also aggressively concrete evaluates `:inconsistent` calls within
472+
# `repl_frame` to provide reasonable completions for lines like `Ref(Some(42))[].|`.
473+
# Aggressive concrete evaluation allows us to get accurate type information about complex
474+
# expressions that otherwise can not be constant folded, in a safe way, i.e. it still
475+
# doesn't evaluate effectful expressions like `pop!(xs)`.
476+
# Similarly to the aggressive binding resolution, aggressive concrete evaluation doesn't
477+
# present any cache validation issues because `repl_frame` is never cached.
478+
479+
is_repl_frame(interp::REPLInterpreter, sv::CC.InferenceState) = interp.repl_frame === sv.result
480+
481+
# aggressive global binding resolution within `repl_frame`
482+
function CC.abstract_eval_globalref(interp::REPLInterpreter, g::GlobalRef,
483+
sv::CC.InferenceState)
484+
if is_repl_frame(interp, sv)
485+
if CC.isdefined_globalref(g)
486+
return Const(ccall(:jl_get_globalref_value, Any, (Any,), g))
468487
end
469-
return get_type_call(sym, fn)
470-
elseif sym.head === :thunk
471-
thk = sym.args[1]
472-
rt = ccall(:jl_infer_thunk, Any, (Any, Any), thk::Core.CodeInfo, fn)
473-
rt !== Any && return (rt, true)
474-
elseif sym.head === :ref
475-
# some simple cases of `expand`
476-
return try_get_type(Expr(:call, GlobalRef(Base, :getindex), sym.args...), fn)
477-
elseif sym.head === :. && sym.args[2] isa QuoteNode # second check catches broadcasting
478-
return try_get_type(Expr(:call, GlobalRef(Core, :getfield), sym.args...), fn)
479-
elseif sym.head === :toplevel || sym.head === :block
480-
isempty(sym.args) && return (nothing, true)
481-
return try_get_type(sym.args[end], fn)
482-
elseif sym.head === :escape || sym.head === :var"hygienic-scope"
483-
return try_get_type(sym.args[1], fn)
488+
return Union{}
484489
end
485-
return (Any, false)
490+
return @invoke CC.abstract_eval_globalref(interp::CC.AbstractInterpreter, g::GlobalRef,
491+
sv::CC.InferenceState)
486492
end
487493

488-
try_get_type(other, fn::Module) = get_type(other, fn)
494+
function is_repl_frame_getproperty(interp::REPLInterpreter, sv::CC.InferenceState)
495+
def = sv.linfo.def
496+
def isa Method || return false
497+
def.name === :getproperty || return false
498+
sv.cached && return false
499+
return is_repl_frame(interp, sv.parent)
500+
end
489501

490-
function get_type(sym::Expr, fn::Module)
491-
# try to analyze nests of calls. if this fails, try using the expanded form.
492-
val, found = try_get_type(sym, fn)
493-
found && return val, found
494-
# https://github.com/JuliaLang/julia/issues/27184
495-
if isexpr(sym, :macrocall)
496-
_, found = get_type(first(sym.args), fn)
497-
found || return Any, false
498-
end
499-
newsym = try
500-
macroexpand(fn, sym; recursive=false)
501-
catch e
502-
# user code failed in macroexpand (ignore it)
503-
return Any, false
504-
end
505-
val, found = try_get_type(newsym, fn)
506-
if !found
507-
newsym = try
508-
Meta.lower(fn, sym)
509-
catch e
510-
# user code failed in lowering (ignore it)
511-
return Any, false
502+
# aggressive global binding resolution for `getproperty(::Module, ::Symbol)` calls within `repl_frame`
503+
function CC.builtin_tfunction(interp::REPLInterpreter, @nospecialize(f),
504+
argtypes::Vector{Any}, sv::CC.InferenceState)
505+
if f === Core.getglobal && is_repl_frame_getproperty(interp, sv)
506+
if length(argtypes) == 2
507+
a1, a2 = argtypes
508+
if isa(a1, Const) && isa(a2, Const)
509+
a1val, a2val = a1.val, a2.val
510+
if isa(a1val, Module) && isa(a2val, Symbol)
511+
g = GlobalRef(a1val, a2val)
512+
if CC.isdefined_globalref(g)
513+
return Const(ccall(:jl_get_globalref_value, Any, (Any,), g))
514+
end
515+
return Union{}
516+
end
517+
end
512518
end
513-
val, found = try_get_type(newsym, fn)
514519
end
515-
return val, found
520+
return @invoke CC.builtin_tfunction(interp::CC.AbstractInterpreter, f::Any,
521+
argtypes::Vector{Any}, sv::CC.InferenceState)
516522
end
517523

518-
function get_type(sym, fn::Module)
519-
val, found = get_value(sym, fn)
520-
return found ? Core.Typeof(val) : Any, found
524+
# aggressive concrete evaluation for `:inconsistent` frames within `repl_frame`
525+
function CC.concrete_eval_eligible(interp::REPLInterpreter, @nospecialize(f),
526+
result::CC.MethodCallResult, arginfo::CC.ArgInfo,
527+
sv::CC.InferenceState)
528+
if is_repl_frame(interp, sv)
529+
neweffects = CC.Effects(result.effects; consistent=CC.ALWAYS_TRUE)
530+
result = CC.MethodCallResult(result.rt, result.edgecycle, result.edgelimited,
531+
result.edge, neweffects)
532+
end
533+
return @invoke CC.concrete_eval_eligible(interp::CC.AbstractInterpreter, f::Any,
534+
result::CC.MethodCallResult, arginfo::CC.ArgInfo,
535+
sv::CC.InferenceState)
536+
end
537+
538+
function resolve_toplevel_symbols!(mod::Module, src::Core.CodeInfo)
539+
newsrc = copy(src)
540+
@ccall jl_resolve_globals_in_ir(
541+
#=jl_array_t *stmts=# newsrc.code::Any,
542+
#=jl_module_t *m=# mod::Any,
543+
#=jl_svec_t *sparam_vals=# Core.svec()::Any,
544+
#=int binding_effects=# 0::Int)::Cvoid
545+
return newsrc
521546
end
522547

523-
function get_type(T, found::Bool, default_any::Bool)
524-
return found ? T :
525-
default_any ? Any : throw(ArgumentError("argument not found"))
548+
# lower `ex` and run type inference on the resulting top-level expression
549+
function repl_eval_ex(@nospecialize(ex), context_module::Module)
550+
lwr = try
551+
Meta.lower(context_module, ex)
552+
catch # macro expansion failed, etc.
553+
return nothing
554+
end
555+
if lwr isa Symbol
556+
return isdefined(context_module, lwr) ? Const(getfield(context_module, lwr)) : nothing
557+
end
558+
lwr isa Expr || return Const(lwr) # `ex` is literal
559+
isexpr(lwr, :thunk) || return nothing # lowered to `Expr(:error, ...)` or similar
560+
src = lwr.args[1]::Core.CodeInfo
561+
562+
# construct top-level `MethodInstance`
563+
mi = ccall(:jl_new_method_instance_uninit, Ref{Core.MethodInstance}, ());
564+
mi.specTypes = Tuple{}
565+
566+
mi.def = context_module
567+
src = resolve_toplevel_symbols!(context_module, src)
568+
@atomic mi.uninferred = src
569+
570+
result = CC.InferenceResult(mi)
571+
interp = REPLInterpreter(result)
572+
frame = CC.InferenceState(result, src, #=cache=#:no, interp)::CC.InferenceState
573+
574+
CC.typeinf(interp, frame)
575+
576+
return frame.result.result
526577
end
527578

528579
# Method completion on function call expression that look like :(max(1))
529580
MAX_METHOD_COMPLETIONS::Int = 40
530581
function _complete_methods(ex_org::Expr, context_module::Module, shift::Bool)
531-
funct, found = get_type(ex_org.args[1], context_module)::Tuple{Any,Bool}
532-
!found && return 2, funct, [], Set{Symbol}()
533-
582+
funct = repl_eval_ex(ex_org.args[1], context_module)
583+
funct === nothing && return 2, nothing, [], Set{Symbol}()
584+
funct = CC.widenconst(funct)
534585
args_ex, kwargs_ex, kwargs_flag = complete_methods_args(ex_org, context_module, true, true)
535586
return kwargs_flag, funct, args_ex, kwargs_ex
536587
end
@@ -635,7 +686,14 @@ function detect_args_kwargs(funargs::Vector{Any}, context_module::Module, defaul
635686
# argument types
636687
push!(args_ex, Any)
637688
else
638-
push!(args_ex, get_type(get_type(ex, context_module)..., default_any))
689+
argt = repl_eval_ex(ex, context_module)
690+
if argt !== nothing
691+
push!(args_ex, CC.widenconst(argt))
692+
elseif default_any
693+
push!(args_ex, Any)
694+
else
695+
throw(ArgumentError("argument not found"))
696+
end
639697
end
640698
end
641699
end
@@ -709,7 +767,6 @@ function close_path_completion(str, startpos, r, paths, pos)
709767
return lastindex(str) <= pos || str[nextind(str, pos)] != '"'
710768
end
711769

712-
713770
function bslash_completions(string::String, pos::Int)
714771
slashpos = something(findprev(isequal('\\'), string, pos), 0)
715772
if (something(findprev(in(bslash_separators), string, pos), 0) < slashpos &&

0 commit comments

Comments
 (0)