diff --git a/typed_python/compiler/binary_shared_object.py b/typed_python/compiler/binary_shared_object.py index 5c2e5765..beda16fd 100644 --- a/typed_python/compiler/binary_shared_object.py +++ b/typed_python/compiler/binary_shared_object.py @@ -36,17 +36,31 @@ def __init__(self, binarySharedObject, diskPath, functionPointers, serializedGlo class BinarySharedObject: """Models a shared object library (.so) loadable on linux systems.""" - def __init__(self, binaryForm, functionTypes, serializedGlobalVariableDefinitions, globalDependencies): + def __init__(self, + binaryForm, + functionTypes, + serializedGlobalVariableDefinitions, + globalDependencies, + functionComplexities, + functionIRs, + serializedFunctionDefinitions + ): """ Args: binaryForm: a bytes object containing the actual compiled code for the module serializedGlobalVariableDefinitions: a map from name to GlobalVariableDefinition - globalDependencies: a dict from function linkname to the list of global variables it depends on + globalDependencies: a dict from function name to the list of global variables it depends on + functionComplexities: a dict from function name to the total number of llvm instructions in the function (used for inlining) + functionIRs: a dict from function name to the llvm IR Functions (used for inlining) + serializedFunctionDefinitions: a dict from function name to the serialized native_ast.Function (used for inlining) """ self.binaryForm = binaryForm self.functionTypes = functionTypes self.serializedGlobalVariableDefinitions = serializedGlobalVariableDefinitions self.globalDependencies = globalDependencies + self.functionComplexities = functionComplexities + self.functionIRs = functionIRs + self.serializedFunctionDefinitions = serializedFunctionDefinitions self.hash = sha_hash(binaryForm) @property @@ -54,14 +68,32 @@ def definedSymbols(self): return self.functionTypes.keys() @staticmethod - def fromDisk(path, serializedGlobalVariableDefinitions, functionNameToType, globalDependencies): + def
fromDisk(path, + serializedGlobalVariableDefinitions, + functionNameToType, + globalDependencies, + functionComplexities, + functionIRs, + serializedFunctionDefinitions): with open(path, "rb") as f: binaryForm = f.read() - return BinarySharedObject(binaryForm, functionNameToType, serializedGlobalVariableDefinitions, globalDependencies) + return BinarySharedObject(binaryForm, + functionNameToType, + serializedGlobalVariableDefinitions, + globalDependencies, + functionComplexities, + functionIRs, + serializedFunctionDefinitions) @staticmethod - def fromModule(module, serializedGlobalVariableDefinitions, functionNameToType, globalDependencies): + def fromModule(module, + serializedGlobalVariableDefinitions, + functionNameToType, + globalDependencies, + functionComplexities, + functionIRs, + serializedFunctionDefinitions): target_triple = llvm.get_process_triple() target = llvm.Target.from_triple(target_triple) target_machine_shared_object = target.create_target_machine(reloc='pic', codemodel='default') @@ -82,7 +114,13 @@ def fromModule(module, serializedGlobalVariableDefinitions, functionNameToType, ) with open(os.path.join(tf, "module.so"), "rb") as so_file: - return BinarySharedObject(so_file.read(), functionNameToType, serializedGlobalVariableDefinitions, globalDependencies) + return BinarySharedObject(so_file.read(), + functionNameToType, + serializedGlobalVariableDefinitions, + globalDependencies, + functionComplexities, + functionIRs, + serializedFunctionDefinitions) def load(self, storageDir): """Instantiate this .so in temporary storage and return a dict from symbol -> integer function pointer""" diff --git a/typed_python/compiler/compiler_cache.py b/typed_python/compiler/compiler_cache.py index c3a193b8..90e54298 100644 --- a/typed_python/compiler/compiler_cache.py +++ b/typed_python/compiler/compiler_cache.py @@ -15,12 +15,14 @@ import os import uuid import shutil +import llvmlite.ir from typing import Optional, List from 
typed_python.compiler.binary_shared_object import LoadedBinarySharedObject, BinarySharedObject from typed_python.compiler.directed_graph import DirectedGraph from typed_python.compiler.typed_call_target import TypedCallTarget +import typed_python.compiler.native_ast as native_ast from typed_python.SerializationContext import SerializationContext from typed_python import Dict, ListOf @@ -67,6 +69,8 @@ def __init__(self, cacheDir): self.targetsLoaded: Dict[str, TypedCallTarget] = {} # the set of link_names for functions with linked and validated globals (i.e. ready to be run). self.targetsValidated = set() + # the total number of instructions for each link_name + self.targetComplexity = Dict(str, int)() # link_name -> link_name self.function_dependency_graph = DirectedGraph() # dict from link_name to list of global names (should be llvm keys in serialisedGlobalDefinitions) @@ -90,6 +94,21 @@ def getTarget(self, func_name: str) -> TypedCallTarget: self.loadForSymbol(link_name) return self.targetsLoaded[link_name] + def getIR(self, func_name: str) -> llvmlite.ir.Function: + if not self.hasSymbol(func_name): + raise ValueError(f'symbol not found for func_name {func_name}') + link_name = self._select_link_name(func_name) + module_hash = self.link_name_to_module_hash[link_name] + return self.loadedBinarySharedObjects[module_hash].binarySharedObject.functionIRs[func_name] + + def getDefinition(self, func_name: str) -> native_ast.Function: + if not self.hasSymbol(func_name): + raise ValueError(f'symbol not found for func_name {func_name}') + link_name = self._select_link_name(func_name) + module_hash = self.link_name_to_module_hash[link_name] + serialized_definition = self.loadedBinarySharedObjects[module_hash].binarySharedObject.serializedFunctionDefinitions[func_name] + return SerializationContext().deserialize(serialized_definition) + def _generate_link_name(self, func_name: str, module_hash: str) -> str: return func_name + "." 
+ module_hash @@ -126,6 +145,14 @@ def loadForSymbol(self, linkName: str) -> None: if not self.loadedBinarySharedObjects[moduleHash].validateGlobalVariables(definitionsToLink): raise RuntimeError('failed to validate globals when loading:', linkName) + def complexityForSymbol(self, func_name: str) -> int: + """Get the total number of LLVM instructions for a given symbol.""" + try: + link_name = self._select_link_name(func_name) + return self.targetComplexity[link_name] + except KeyError as e: + raise ValueError(f'No complexity value cached for {func_name}') from e + def loadModuleByHash(self, moduleHash: str) -> None: """Load a module by name. @@ -139,23 +166,23 @@ def loadModuleByHash(self, moduleHash: str) -> None: # TODO (Will) - store these names as module consts, use one .dat only with open(os.path.join(targetDir, "type_manifest.dat"), "rb") as f: - # func_name -> typedcalltarget callTargets = SerializationContext().deserialize(f.read()) - with open(os.path.join(targetDir, "globals_manifest.dat"), "rb") as f: serializedGlobalVarDefs = SerializationContext().deserialize(f.read()) - with open(os.path.join(targetDir, "native_type_manifest.dat"), "rb") as f: functionNameToNativeType = SerializationContext().deserialize(f.read()) - with open(os.path.join(targetDir, "submodules.dat"), "rb") as f: submodules = SerializationContext().deserialize(f.read(), ListOf(str)) - with open(os.path.join(targetDir, "function_dependencies.dat"), "rb") as f: dependency_edgelist = SerializationContext().deserialize(f.read()) - with open(os.path.join(targetDir, "global_dependencies.dat"), "rb") as f: globalDependencies = SerializationContext().deserialize(f.read()) + with open(os.path.join(targetDir, "function_complexities.dat"), "rb") as f: + functionComplexities = SerializationContext().deserialize(f.read()) + with open(os.path.join(targetDir, "function_irs.dat"), "rb") as f: + functionIRs = SerializationContext().deserialize(f.read()) + with open(os.path.join(targetDir, 
"function_definitions.dat"), "rb") as f: + functionDefinitions = SerializationContext().deserialize(f.read()) # load the submodules first for submodule in submodules: @@ -167,7 +194,10 @@ def loadModuleByHash(self, moduleHash: str) -> None: modulePath, serializedGlobalVarDefs, functionNameToNativeType, - globalDependencies + globalDependencies, + functionComplexities, + functionIRs, + functionDefinitions ).loadFromPath(modulePath) self.loadedBinarySharedObjects[moduleHash] = loaded @@ -177,8 +207,11 @@ def loadModuleByHash(self, moduleHash: str) -> None: assert link_name not in self.targetsLoaded self.targetsLoaded[link_name] = callTarget - link_name_global_dependencies = {self._generate_link_name(x, moduleHash): y for x, y in globalDependencies.items()} + for func_name, complexity in functionComplexities.items(): + link_name = self._generate_link_name(func_name, moduleHash) + self.targetComplexity[link_name] = complexity + link_name_global_dependencies = {self._generate_link_name(x, moduleHash): y for x, y in globalDependencies.items()} assert not any(key in self.global_dependencies for key in link_name_global_dependencies) self.global_dependencies.update(link_name_global_dependencies) @@ -222,6 +255,10 @@ def addModule(self, binarySharedObject, nameToTypedCallTarget, linkDependencies, path = self.writeModuleToDisk(binarySharedObject, hashToUse, nameToTypedCallTarget, dependentHashes, link_name_dependency_edgelist) + for func_name, complexity in binarySharedObject.functionComplexities.items(): + link_name = self._generate_link_name(func_name, hashToUse) + self.targetComplexity[link_name] = complexity + self.loadedBinarySharedObjects[hashToUse] = ( binarySharedObject.loadFromPath(os.path.join(path, "module.so")) ) @@ -314,6 +351,15 @@ def writeModuleToDisk(self, binarySharedObject, hashToUse, nameToTypedCallTarget with open(os.path.join(tempTargetDir, "global_dependencies.dat"), "wb") as f: 
f.write(SerializationContext().serialize(binarySharedObject.globalDependencies)) + with open(os.path.join(tempTargetDir, "function_complexities.dat"), "wb") as f: + f.write(SerializationContext().serialize(binarySharedObject.functionComplexities)) + + with open(os.path.join(tempTargetDir, "function_irs.dat"), "wb") as f: + f.write(SerializationContext().serialize(binarySharedObject.functionIRs)) + + with open(os.path.join(tempTargetDir, "function_definitions.dat"), "wb") as f: + f.write(SerializationContext().serialize(binarySharedObject.serializedFunctionDefinitions)) + try: os.rename(tempTargetDir, targetDir) except IOError: diff --git a/typed_python/compiler/compiler_cache_test.py b/typed_python/compiler/compiler_cache_test.py index 639e929e..3f45cccd 100644 --- a/typed_python/compiler/compiler_cache_test.py +++ b/typed_python/compiler/compiler_cache_test.py @@ -16,6 +16,7 @@ import threading import os import pytest + from typed_python.test_util import evaluateExprInFreshProcess MAIN_MODULE = """ diff --git a/typed_python/compiler/llvm_compiler.py b/typed_python/compiler/llvm_compiler.py index f33e5edb..b9e31bd6 100644 --- a/typed_python/compiler/llvm_compiler.py +++ b/typed_python/compiler/llvm_compiler.py @@ -123,7 +123,10 @@ def buildSharedObject(self, functions): mod, serializedGlobalVariableDefinitions, module.functionNameToType, - module.globalDependencies + module.globalDependencies, + {name: self.converter.totalFunctionComplexity(name) for name in functions}, + {name: self.converter._functions_by_name[name] for name in functions}, + {name: SerializationContext().serialize(self.converter._function_definitions[name]) for name in functions}, ) def function_pointer_by_name(self, name): diff --git a/typed_python/compiler/native_ast_to_llvm.py b/typed_python/compiler/native_ast_to_llvm.py index bbd2027d..310881de 100644 --- a/typed_python/compiler/native_ast_to_llvm.py +++ b/typed_python/compiler/native_ast_to_llvm.py @@ -18,6 +18,7 @@ from 
typed_python.compiler.global_variable_definition import GlobalVariableDefinition from typed_python.compiler.module_definition import ModuleDefinition from typing import Dict + llvm_i8ptr = llvmlite.ir.IntType(8).as_pointer() llvm_i8 = llvmlite.ir.IntType(8) llvm_i32 = llvmlite.ir.IntType(32) @@ -642,6 +643,7 @@ def namedCallTargetToLLVM(self, target: native_ast.NamedCallTarget) -> TypedLLVM 2. The function is in function_definitions, in which case we grab the function definition and make an inlining decision. 3. We have a compiler cache, and the function is in it. We add to external_function_references. """ + assert isinstance(target, native_ast.NamedCallTarget) if target.external: if target.name not in self.external_function_references: func_type = llvmlite.ir.FunctionType( @@ -673,24 +675,29 @@ def namedCallTargetToLLVM(self, target: native_ast.NamedCallTarget) -> TypedLLVM func = self.external_function_references[target.name] else: - # TODO (Will): decide whether to inline cached code assert self.compilerCache is not None and self.compilerCache.hasSymbol(target.name) # this function is defined in a shared object that we've loaded from a prior - # invocation - if target.name not in self.external_function_references: - func_type = llvmlite.ir.FunctionType( - type_to_llvm_type(target.output_type), - [type_to_llvm_type(x) for x in target.arg_types], - var_arg=target.varargs - ) + # invocation. Again, first make an inlining decision. 
+ if ( + self.compilerCache.complexityForSymbol(target.name) < CROSS_MODULE_INLINE_COMPLEXITY + ): + self.converter.generateDefinition(target.name) + func = self.converter.repeatFunctionInModule(target.name, self.module) + else: + if target.name not in self.external_function_references: + func_type = llvmlite.ir.FunctionType( + type_to_llvm_type(target.output_type), + [type_to_llvm_type(x) for x in target.arg_types], + var_arg=target.varargs + ) - assert target.name not in self.converter._function_definitions, target.name + assert target.name not in self.converter._function_definitions, target.name - self.external_function_references[target.name] = ( - llvmlite.ir.Function(self.module, func_type, target.name) - ) + self.external_function_references[target.name] = ( + llvmlite.ir.Function(self.module, func_type, target.name) + ) - func = self.external_function_references[target.name] + func = self.external_function_references[target.name] return TypedLLVMValue( func, @@ -1528,6 +1535,18 @@ def totalFunctionComplexity(self, name): return res + def generateDefinition(self, name: str) -> None: + """Pull the native_ast definition and llvm IR function for `name` from the compiler + cache, and add them to _function_definitions and _functions_by_name. + """ + assert self.compilerCache is not None + + definition = self.compilerCache.getDefinition(name) + llvm_func = self.compilerCache.getIR(name) + + self._functions_by_name[name] = llvm_func + self._function_definitions[name] = definition + def repeatFunctionInModule(self, name, module): """Request that the function given by 'name' be inlined into 'module'. 
@@ -1580,7 +1599,6 @@ def add_functions(self, names_to_definitions): [type_to_llvm_type(x[1]) for x in function.args] ) self._functions_by_name[name] = llvmlite.ir.Function(module, func_type, name) - self._functions_by_name[name].linkage = 'external' self._function_definitions[name] = function @@ -1664,6 +1682,7 @@ def add_functions(self, names_to_definitions): # want to repeat its definition in this particular module. for name in self._inlineRequests: names_to_definitions[name] = self._function_definitions[name] + self._inlineRequests.clear() # define a function that accepts a pointer and fills it out with a table of pointer values @@ -1674,7 +1693,6 @@ def add_functions(self, names_to_definitions): output=native_ast.Void, args=[native_ast.Void.pointer().pointer()] ) - return ModuleDefinition( str(module), functionTypes, diff --git a/typed_python/compiler/python_to_native_converter.py b/typed_python/compiler/python_to_native_converter.py index f8a99f40..131ece7d 100644 --- a/typed_python/compiler/python_to_native_converter.py +++ b/typed_python/compiler/python_to_native_converter.py @@ -17,7 +17,7 @@ from typed_python.hash import Hash from types import ModuleType -from typing import Dict +from typing import Dict, Optional from typed_python import Class import typed_python.python_ast as python_ast import typed_python._types as _types @@ -288,10 +288,10 @@ def deleteTarget(self, linkName): self._targets.pop(linkName) def setTarget(self, linkName, target): - assert(isinstance(target, TypedCallTarget)) + assert (isinstance(target, TypedCallTarget)) self._targets[linkName] = target - def getTarget(self, linkName) -> TypedCallTarget: + def getTarget(self, linkName) -> Optional[TypedCallTarget]: if linkName in self._targets: return self._targets[linkName]