Skip to content

Commit 2b91d9c

Browse files
vtjnash authored and KristofferC committed
cfunction: reimplement, as originally planned, for reliable performance (#57226)
(cherry picked from commit ca7cf30)
1 parent 75f3690 commit 2b91d9c

13 files changed

+926
-325
lines changed

Compiler/src/typeinfer.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1266,6 +1266,7 @@ function typeinf_ext_toplevel(methods::Vector{Any}, worlds::Vector{UInt}, trim::
12661266
tocompile = Vector{CodeInstance}()
12671267
codeinfos = []
12681268
# first compute the ABIs of everything
1269+
latest = true # whether this_world == world_counter()
12691270
for this_world in reverse(sort!(worlds))
12701271
interp = NativeInterpreter(this_world)
12711272
for i = 1:length(methods)
@@ -1278,18 +1279,18 @@ function typeinf_ext_toplevel(methods::Vector{Any}, worlds::Vector{UInt}, trim::
12781279
# then we want to compile and emit this
12791280
if item.def.primary_world <= this_world <= item.def.deleted_world
12801281
ci = typeinf_ext(interp, item, SOURCE_MODE_NOT_REQUIRED)
1281-
ci isa CodeInstance && !use_const_api(ci) && push!(tocompile, ci)
1282+
ci isa CodeInstance && push!(tocompile, ci)
12821283
end
1283-
elseif item isa SimpleVector
1284+
elseif item isa SimpleVector && latest
12841285
(rt::Type, sig::Type) = item
12851286
# make a best-effort attempt to enqueue the relevant code for the ccallable
12861287
ptr = ccall(:jl_get_specialization1,
12871288
#= MethodInstance =# Ptr{Cvoid}, (Any, Csize_t, Cint),
12881289
sig, this_world, #= mt_cache =# 0)
12891290
if ptr !== C_NULL
1290-
mi = unsafe_pointer_to_objref(ptr)
1291+
mi = unsafe_pointer_to_objref(ptr)::MethodInstance
12911292
ci = typeinf_ext(interp, mi, SOURCE_MODE_NOT_REQUIRED)
1292-
ci isa CodeInstance && !use_const_api(ci) && push!(tocompile, ci)
1293+
ci isa CodeInstance && push!(tocompile, ci)
12931294
end
12941295
# additionally enqueue the ccallable entrypoint / adapter, which implicitly
12951296
# invokes the above ci
@@ -1305,7 +1306,7 @@ function typeinf_ext_toplevel(methods::Vector{Any}, worlds::Vector{UInt}, trim::
13051306
mi = get_ci_mi(callee)
13061307
def = mi.def
13071308
if use_const_api(callee)
1308-
src = codeinfo_for_const(interp, mi, code.rettype_const)
1309+
src = codeinfo_for_const(interp, mi, callee.rettype_const)
13091310
elseif haskey(interp.codegen, callee)
13101311
src = interp.codegen[callee]
13111312
elseif isa(def, Method) && ccall(:jl_get_module_infer, Cint, (Any,), def.module) == 0 && !trim
@@ -1327,6 +1328,7 @@ function typeinf_ext_toplevel(methods::Vector{Any}, worlds::Vector{UInt}, trim::
13271328
println("warning: failed to get code for ", mi)
13281329
end
13291330
end
1331+
latest = false
13301332
end
13311333
return codeinfos
13321334
end

src/aotcompile.cpp

Lines changed: 174 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,7 @@ static void resolve_workqueue(jl_codegen_params_t &params, egal_set &method_root
423423
if (decls.functionObject == "jl_fptr_args") {
424424
preal_decl = decls.specFunctionObject;
425425
}
426-
else if (decls.functionObject != "jl_fptr_sparam" && decls.functionObject != "jl_f_opaque_closure_call") {
426+
else if (decls.functionObject != "jl_fptr_sparam" && decls.functionObject != "jl_f_opaque_closure_call" && decls.functionObject != "jl_fptr_const_return") {
427427
preal_decl = decls.specFunctionObject;
428428
preal_specsig = true;
429429
}
@@ -439,6 +439,13 @@ static void resolve_workqueue(jl_codegen_params_t &params, egal_set &method_root
439439
Module *mod = proto.decl->getParent();
440440
assert(proto.decl->isDeclaration());
441441
Function *pinvoke = nullptr;
442+
if (preal_decl.empty() && jl_atomic_load_relaxed(&codeinst->invoke) == jl_fptr_const_return_addr) {
443+
std::string gf_thunk_name = emit_abi_constreturn(mod, params, proto.specsig, codeinst);
444+
preal_specsig = proto.specsig;
445+
if (invokeName.empty())
446+
invokeName = "jl_fptr_const_return";
447+
preal_decl = mod->getNamedValue(gf_thunk_name)->getName();
448+
}
442449
if (preal_decl.empty()) {
443450
if (invokeName.empty() && params.params->trim) {
444451
jl_safe_printf("warning: bailed out to invoke when compiling: ");
@@ -483,6 +490,7 @@ static void resolve_workqueue(jl_codegen_params_t &params, egal_set &method_root
483490
ocinvokeDecl = pinvoke->getName();
484491
assert(!ocinvokeDecl.empty());
485492
assert(ocinvokeDecl != "jl_fptr_args");
493+
assert(ocinvokeDecl != "jl_fptr_const_return");
486494
assert(ocinvokeDecl != "jl_fptr_sparam");
487495
// merge and/or rename this prototype to the real function
488496
if (Value *specfun = mod->getNamedValue(ocinvokeDecl)) {
@@ -499,6 +507,134 @@ static void resolve_workqueue(jl_codegen_params_t &params, egal_set &method_root
499507
JL_GC_POP();
500508
}
501509

510+
/// Link the function in the source module into the destination module if
511+
/// needed, setting up mapping information.
512+
/// Similar to orc::cloneFunctionDecl, but more complete for greater correctness
513+
Function *IRLinker_copyFunctionProto(Module *DstM, Function *SF) {
514+
// If there is no linkage to be performed or we are linking from the source,
515+
// bring SF over, if we haven't already.
516+
if (SF->getParent() == DstM)
517+
return SF;
518+
if (auto *F = DstM->getNamedValue(SF->getName()))
519+
return cast<Function>(F);
520+
auto *F = Function::Create(SF->getFunctionType(), SF->getLinkage(),
521+
SF->getAddressSpace(), SF->getName(), DstM);
522+
F->copyAttributesFrom(SF);
523+
F->IsNewDbgInfoFormat = SF->IsNewDbgInfoFormat;
524+
525+
// Remove these copied constants since they point to the source module.
526+
F->setPersonalityFn(nullptr);
527+
F->setPrefixData(nullptr);
528+
F->setPrologueData(nullptr);
529+
return F;
530+
}
531+
532+
// Emit an ABI adapter thunk into `M` and return it.
//
// When `specfunc` is non-empty, its prototype is copied from `defM` into `M`
// and handed to emit_abi_converter (together with `target_specsig`).
// Otherwise emit_abi_dispatcher is used, targeting the prototype of `func`
// copied from `defM`, or no target at all (nullptr) when `func` is empty too.
// The emitted thunk is looked up in `M` by the name the emitter returned.
static Function *aot_abi_converter(jl_codegen_params_t &params, Module *M, jl_value_t *declrt, jl_value_t *sigt, size_t nargs, bool specsig, jl_code_instance_t *codeinst, Module *defM, StringRef func, StringRef specfunc, bool target_specsig)
{
    Value *target = nullptr;
    std::string thunk_name;
    if (specfunc.empty()) {
        // `defM` is only dereferenced when there is a named function to copy
        if (!func.empty())
            target = IRLinker_copyFunctionProto(M, defM->getFunction(func));
        thunk_name = emit_abi_dispatcher(M, params, declrt, sigt, nargs, specsig, codeinst, target);
    }
    else {
        target = IRLinker_copyFunctionProto(M, defM->getFunction(specfunc));
        thunk_name = emit_abi_converter(M, params, declrt, sigt, nargs, specsig, codeinst, target, target_specsig);
    }
    Function *thunk = M->getFunction(thunk_name);
    assert(thunk);
    return thunk;
}
547+
548+
// Fill in the initializers of every cfunction recorded in `params.cfuncs`.
//
// For each entry this selects the function to store in `cfunc.theFptr`:
//   - a const-return thunk when the target is "jl_fptr_const_return",
//   - the target's specFunctionObject prototype directly when the ABI already
//     matches,
//   - an ABI converter / dispatcher thunk otherwise,
//   - the unspecialized dispatcher (`unspec`) when no compiled code instance
//     was found for the signature in the latest world.
// It also updates `cfunc.cfuncdata`: slot 2 always receives `unspec`, and
// slot 0 receives the code instance when one was found.
//
// `compiled_functions` maps code instances to their (module, declarations).
static void generate_cfunc_thunks(jl_codegen_params_t &params, jl_compiled_functions_t &compiled_functions)
{
    // Index the compiled code instances by method instance, keeping only those
    // with no owner, an unbounded max_world, and whose `def` is the method
    // instance itself.
    DenseMap<jl_method_instance_t*, jl_code_instance_t*> compiled_mi;
    for (auto &def : compiled_functions) {
        jl_code_instance_t *this_code = def.first;
        jl_method_instance_t *mi = jl_get_ci_mi(this_code);
        if (this_code->owner == jl_nothing && jl_atomic_load_relaxed(&this_code->max_world) == ~(size_t)0 && this_code->def == (jl_value_t*)mi)
            compiled_mi[mi] = this_code;
    }
    size_t latestworld = jl_atomic_load_acquire(&jl_world_counter);
    for (cfunc_decl_t &cfunc : params.cfuncs) {
        Module *M = cfunc.theFptr->getParent();
        jl_value_t *sigt = cfunc.sigt;
        JL_GC_PROMISE_ROOTED(sigt);
        jl_value_t *declrt = cfunc.declrt;
        JL_GC_PROMISE_ROOTED(declrt);
        // Dispatcher with no code instance and no target; stored in slot 2 of
        // the cfunction data for every entry.
        Function *unspec = aot_abi_converter(params, M, declrt, sigt, cfunc.nargs, cfunc.specsig, nullptr, nullptr, "", "", false);
        jl_code_instance_t *codeinst = nullptr;
        // Store `f` as the cfunction pointer and rebuild the 6-element data
        // table. `codeinst` is captured by reference, so the value it holds at
        // the time of the call is the one recorded.
        auto assign_fptr = [&params, &cfunc, &codeinst, &unspec](Function *f) {
            ConstantArray *init = cast<ConstantArray>(cfunc.cfuncdata->getInitializer());
            SmallVector<Constant*,6> initvals;
            for (unsigned i = 0; i < init->getNumOperands(); ++i)
                initvals.push_back(init->getOperand(i));
            assert(initvals.size() == 6);
            assert(initvals[0]->isNullValue());
            if (codeinst) {
                Constant *llvmcodeinst = literal_pointer_val_slot(params, f->getParent(), (jl_value_t*)codeinst);
                initvals[0] = llvmcodeinst; // plast_codeinst
            }
            assert(initvals[2]->isNullValue());
            initvals[2] = unspec;
            cfunc.cfuncdata->setInitializer(ConstantArray::get(init->getType(), initvals));
            cfunc.theFptr->setInitializer(f);
        };
        Module *defM = nullptr;
        StringRef func;
        jl_method_instance_t *mi = jl_get_specialization1((jl_tupletype_t*)sigt, latestworld, 0);
        if (mi) {
            auto it = compiled_mi.find(mi);
            if (it != compiled_mi.end()) {
                codeinst = it->second;
                JL_GC_PROMISE_ROOTED(codeinst);
                // Every key of compiled_mi was taken from compiled_functions,
                // so this lookup is expected to succeed.
                auto defs = compiled_functions.find(codeinst);
                defM = std::get<0>(defs->second).getModuleUnlocked();
                const jl_llvm_functions_t &decls = std::get<1>(defs->second);
                func = decls.functionObject;
                StringRef specfunc = decls.specFunctionObject;
                jl_value_t *astrt = codeinst->rettype;
                if (astrt != (jl_value_t*)jl_bottom_type &&
                    jl_type_intersection(astrt, declrt) == jl_bottom_type) {
                    // Do not warn if the function never returns since it is
                    // occasionally required by the C API (typically error callbacks)
                    // even though we're likely to encounter memory errors in that case
                    jl_printf(JL_STDERR, "WARNING: cfunction: return type of %s does not match\n", name_from_method_instance(mi));
                }
                if (func == "jl_fptr_const_return") {
                    std::string gf_thunk_name = emit_abi_constreturn(M, params, declrt, sigt, cfunc.nargs, cfunc.specsig, codeinst->rettype_const);
                    auto F = M->getFunction(gf_thunk_name);
                    assert(F);
                    assign_fptr(F);
                    continue;
                }
                else if (func == "jl_fptr_args") {
                    assert(!specfunc.empty());
                    if (!cfunc.specsig && jl_subtype(astrt, declrt)) {
                        // The target can be used as-is, no adapter needed.
                        assign_fptr(IRLinker_copyFunctionProto(M, defM->getFunction(specfunc)));
                        continue;
                    }
                    assign_fptr(aot_abi_converter(params, M, declrt, sigt, cfunc.nargs, cfunc.specsig, codeinst, defM, func, specfunc, false));
                    continue;
                }
                else if (func == "jl_fptr_sparam" || func == "jl_f_opaque_closure_call") {
                    func = ""; // use jl_invoke instead for these, since we don't declare these prototypes
                }
                else {
                    assert(!specfunc.empty());
                    if (jl_egal(mi->specTypes, sigt) && jl_egal(declrt, astrt)) {
                        // Signature and return type are identical: use the
                        // specialized function directly.
                        assign_fptr(IRLinker_copyFunctionProto(M, defM->getFunction(specfunc)));
                        continue;
                    }
                    assign_fptr(aot_abi_converter(params, M, declrt, sigt, cfunc.nargs, cfunc.specsig, codeinst, defM, func, specfunc, true));
                    continue;
                }
            }
        }
        // Fallback: dispatch through the code instance when there is one,
        // otherwise use the unspecialized dispatcher.
        Function *f = codeinst ? aot_abi_converter(params, M, declrt, sigt, cfunc.nargs, cfunc.specsig, codeinst, defM, func, "", false) : unspec;
        // Do not `return` here: that would leave the function and skip every
        // remaining cfunction, leaving their initializers unset.
        assign_fptr(f);
    }
}
637+
502638

503639
// takes the running content that has collected in the shadow module and dump it to disk
504640
// this builds the object file portion of the sysimage files for fast startup
@@ -651,7 +787,11 @@ void *jl_emit_native_impl(jl_array_t *codeinfos, LLVMOrcThreadSafeModuleRef llvm
651787
orc::ThreadSafeModule result_m = jl_create_ts_module(name_from_method_instance(jl_get_ci_mi(codeinst)),
652788
params.tsctx, clone.getModuleUnlocked()->getDataLayout(),
653789
Triple(clone.getModuleUnlocked()->getTargetTriple()));
654-
jl_llvm_functions_t decls = jl_emit_codeinst(result_m, codeinst, src, params);
790+
jl_llvm_functions_t decls;
791+
if (jl_atomic_load_relaxed(&codeinst->invoke) == jl_fptr_const_return_addr)
792+
decls.functionObject = "jl_fptr_const_return";
793+
else
794+
decls = jl_emit_codeinst(result_m, codeinst, src, params);
655795
record_method_roots(method_roots, jl_get_ci_mi(codeinst));
656796
if (result_m)
657797
compiled_functions[codeinst] = {std::move(result_m), std::move(decls)};
@@ -671,6 +811,8 @@ void *jl_emit_native_impl(jl_array_t *codeinfos, LLVMOrcThreadSafeModuleRef llvm
671811
}
672812
// finally, make sure all referenced methods get fixed up, particularly if the user declined to compile them
673813
resolve_workqueue(params, method_roots, compiled_functions);
814+
// including generating cfunction thunks
815+
generate_cfunc_thunks(params, compiled_functions);
674816
aot_optimize_roots(params, method_roots, compiled_functions);
675817
params.temporary_roots = nullptr;
676818
JL_GC_POP();
@@ -728,9 +870,12 @@ void *jl_emit_native_impl(jl_array_t *codeinfos, LLVMOrcThreadSafeModuleRef llvm
728870
else if (func == "jl_fptr_sparam") {
729871
func_id = -2;
730872
}
731-
else if (decls.functionObject == "jl_f_opaque_closure_call") {
873+
else if (func == "jl_f_opaque_closure_call") {
732874
func_id = -4;
733875
}
876+
else if (func == "jl_fptr_const_return") {
877+
func_id = -5;
878+
}
734879
else {
735880
//Safe b/c context is locked by params
736881
data->jl_sysimg_fvars.push_back(cast<Function>(clone.getModuleUnlocked()->getNamedValue(func)));
@@ -2201,7 +2346,7 @@ extern "C" JL_DLLEXPORT_CODEGEN jl_code_info_t *jl_gdbdumpcode(jl_method_instanc
22012346
// for use in reflection from Julia.
22022347
// This is paired with jl_dump_function_ir and jl_dump_function_asm, either of which will free all memory allocated here
22032348
extern "C" JL_DLLEXPORT_CODEGEN
2204-
void jl_get_llvmf_defn_impl(jl_llvmf_dump_t* dump, jl_method_instance_t *mi, jl_code_info_t *src, char getwrapper, char optimize, const jl_cgparams_t params)
2349+
void jl_get_llvmf_defn_impl(jl_llvmf_dump_t *dump, jl_method_instance_t *mi, jl_code_info_t *src, char getwrapper, char optimize, const jl_cgparams_t params)
22052350
{
22062351
// emit this function into a new llvm module
22072352
dump->F = nullptr;
@@ -2223,7 +2368,31 @@ void jl_get_llvmf_defn_impl(jl_llvmf_dump_t* dump, jl_method_instance_t *mi, jl_
22232368
output.imaging_mode = jl_options.image_codegen;
22242369
output.temporary_roots = jl_alloc_array_1d(jl_array_any_type, 0);
22252370
JL_GC_PUSH1(&output.temporary_roots);
2226-
auto decls = jl_emit_code(m, mi, src, mi->specTypes, src->rettype, output);
2371+
jl_llvm_functions_t decls = jl_emit_code(m, mi, src, mi->specTypes, src->rettype, output);
2372+
// while not required, also emit the cfunc thunks, based on the
2373+
// inferred ABIs of their targets in the current latest world,
2374+
// since otherwise it is challenging to see all relevant codes
2375+
jl_compiled_functions_t compiled_functions;
2376+
size_t latestworld = jl_atomic_load_acquire(&jl_world_counter);
2377+
for (cfunc_decl_t &cfunc : output.cfuncs) {
2378+
jl_value_t *sigt = cfunc.sigt;
2379+
JL_GC_PROMISE_ROOTED(sigt);
2380+
jl_method_instance_t *mi = jl_get_specialization1((jl_tupletype_t*)sigt, latestworld, 0);
2381+
if (mi == nullptr)
2382+
continue;
2383+
jl_code_instance_t *codeinst = jl_type_infer(mi, latestworld, SOURCE_MODE_NOT_REQUIRED);
2384+
if (codeinst == nullptr || compiled_functions.count(codeinst))
2385+
continue;
2386+
orc::ThreadSafeModule decl_m = jl_create_ts_module("extern", ctx);
2387+
jl_llvm_functions_t decls;
2388+
if (jl_atomic_load_relaxed(&codeinst->invoke) == jl_fptr_const_return_addr)
2389+
decls.functionObject = "jl_fptr_const_return";
2390+
else
2391+
decls = jl_emit_codedecls(decl_m, codeinst, output);
2392+
compiled_functions[codeinst] = {std::move(decl_m), std::move(decls)};
2393+
}
2394+
generate_cfunc_thunks(output, compiled_functions);
2395+
compiled_functions.clear();
22272396
output.temporary_roots = nullptr;
22282397
JL_GC_POP(); // GC the global_targets array contents now since reflection doesn't need it
22292398

src/ccall.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1970,6 +1970,8 @@ static jl_cgval_t emit_ccall(jl_codectx_t &ctx, jl_value_t **args, size_t nargs)
19701970
return retval;
19711971
}
19721972

1973+
// Forward declaration only; being `static inline`, the definition is expected
// to appear elsewhere in this same translation unit (not visible here).
static inline Constant *literal_static_pointer_val(const void *p, Type *T);
1974+
19731975
jl_cgval_t function_sig_t::emit_a_ccall(
19741976
jl_codectx_t &ctx,
19751977
const native_sym_arg_t &symarg,

0 commit comments

Comments
 (0)