From 39dcf7ce57769224ff48628e33a9fd7303cef762 Mon Sep 17 00:00:00 2001 From: William Grant Date: Mon, 16 Jan 2023 18:34:43 -0500 Subject: [PATCH] Add function complexities and intermediates to cache. In order to inline cached functions we must store the native and llvm layer intermediates, along with the complexities. This will reduce cache load time but should increase compiled function speed and predictability. --- typed_python/compiler/binary_shared_object.py | 50 +++++++++++++-- typed_python/compiler/compiler_cache.py | 62 ++++++++++++++++--- typed_python/compiler/compiler_cache_test.py | 1 + typed_python/compiler/llvm_compiler.py | 5 +- typed_python/compiler/native_ast_to_llvm.py | 48 +++++++++----- .../compiler/python_to_native_converter.py | 6 +- 6 files changed, 139 insertions(+), 33 deletions(-) diff --git a/typed_python/compiler/binary_shared_object.py b/typed_python/compiler/binary_shared_object.py index 5c2e5765..beda16fd 100644 --- a/typed_python/compiler/binary_shared_object.py +++ b/typed_python/compiler/binary_shared_object.py @@ -36,17 +36,31 @@ def __init__(self, binarySharedObject, diskPath, functionPointers, serializedGlo class BinarySharedObject: """Models a shared object library (.so) loadable on linux systems.""" - def __init__(self, binaryForm, functionTypes, serializedGlobalVariableDefinitions, globalDependencies): + def __init__(self, + binaryForm, + functionTypes, + serializedGlobalVariableDefinitions, + globalDependencies, + functionComplexities, + functionIRs, + serializedFunctionDefinitions + ): """ Args: binaryForm: a bytes object containing the actual compiled code for the module serializedGlobalVariableDefinitions: a map from name to GlobalVariableDefinition - globalDependencies: a dict from function linkname to the list of global variables it depends on + globalDependencies: a dict from function name to the list of global variables it depends on + functionComplexities: a dict from function name to the total number of llvm instructions in the function (used for inlining) + functionIRs: a dict from function name to the llvm IR Functions (used for inlining) + functionDefinitions: a dict from function name to the native_ast.Functions (used for inlining) """ self.binaryForm = binaryForm self.functionTypes = functionTypes self.serializedGlobalVariableDefinitions = serializedGlobalVariableDefinitions self.globalDependencies = globalDependencies + self.functionComplexities = functionComplexities + self.functionIRs = functionIRs + self.serializedFunctionDefinitions = serializedFunctionDefinitions self.hash = sha_hash(binaryForm) @property @@ -54,14 +68,32 @@ def definedSymbols(self): return self.functionTypes.keys() @staticmethod - def fromDisk(path, serializedGlobalVariableDefinitions, functionNameToType, globalDependencies): + def fromDisk(path, + serializedGlobalVariableDefinitions, + functionNameToType, + globalDependencies, + functionComplexities, + functionIRs, + serializedFunctionDefinitions): with open(path, "rb") as f: binaryForm = f.read() - return BinarySharedObject(binaryForm, functionNameToType, serializedGlobalVariableDefinitions, globalDependencies) + return BinarySharedObject(binaryForm, + functionNameToType, + serializedGlobalVariableDefinitions, + globalDependencies, + functionComplexities, + functionIRs, + serializedFunctionDefinitions) @staticmethod - def fromModule(module, serializedGlobalVariableDefinitions, functionNameToType, globalDependencies): + def fromModule(module, + serializedGlobalVariableDefinitions, + functionNameToType, + globalDependencies, + functionComplexities, + functionIRs, + serializedFunctionDefinitions): target_triple = llvm.get_process_triple() target = llvm.Target.from_triple(target_triple) target_machine_shared_object = target.create_target_machine(reloc='pic', codemodel='default') @@ -82,7 +114,13 @@ def fromModule(module, serializedGlobalVariableDefinitions, functionNameToType, ) with open(os.path.join(tf, "module.so"), "rb") as so_file: - return BinarySharedObject(so_file.read(), functionNameToType, serializedGlobalVariableDefinitions, globalDependencies) + return BinarySharedObject(so_file.read(), + functionNameToType, + serializedGlobalVariableDefinitions, + globalDependencies, + functionComplexities, + functionIRs, + serializedFunctionDefinitions) def load(self, storageDir): """Instantiate this .so in temporary storage and return a dict from symbol -> integer function pointer""" diff --git a/typed_python/compiler/compiler_cache.py b/typed_python/compiler/compiler_cache.py index c3a193b8..90e54298 100644 --- a/typed_python/compiler/compiler_cache.py +++ b/typed_python/compiler/compiler_cache.py @@ -15,12 +15,14 @@ import os import uuid import shutil +import llvmlite.ir from typing import Optional, List from typed_python.compiler.binary_shared_object import LoadedBinarySharedObject, BinarySharedObject from typed_python.compiler.directed_graph import DirectedGraph from typed_python.compiler.typed_call_target import TypedCallTarget +import typed_python.compiler.native_ast as native_ast from typed_python.SerializationContext import SerializationContext from typed_python import Dict, ListOf @@ -67,6 +69,8 @@ def __init__(self, cacheDir): self.targetsLoaded: Dict[str, TypedCallTarget] = {} # the set of link_names for functions with linked and validated globals (i.e. ready to be run). self.targetsValidated = set() + # the total number of instructions for each link_name + self.targetComplexity = Dict(str, int)() # link_name -> link_name self.function_dependency_graph = DirectedGraph() # dict from link_name to list of global names (should be llvm keys in serialisedGlobalDefinitions) @@ -90,6 +94,21 @@ def getTarget(self, func_name: str) -> TypedCallTarget: self.loadForSymbol(link_name) return self.targetsLoaded[link_name] + def getIR(self, func_name: str) -> llvmlite.ir.Function: + if not self.hasSymbol(func_name): + raise ValueError(f'symbol not found for func_name {func_name}') + link_name = self._select_link_name(func_name) + module_hash = self.link_name_to_module_hash[link_name] + return self.loadedBinarySharedObjects[module_hash].binarySharedObject.functionIRs[func_name] + + def getDefinition(self, func_name: str) -> native_ast.Function: + if not self.hasSymbol(func_name): + raise ValueError(f'symbol not found for func_name {func_name}') + link_name = self._select_link_name(func_name) + module_hash = self.link_name_to_module_hash[link_name] + serialized_definition = self.loadedBinarySharedObjects[module_hash].binarySharedObject.serializedFunctionDefinitions[func_name] + return SerializationContext().deserialize(serialized_definition) + def _generate_link_name(self, func_name: str, module_hash: str) -> str: return func_name + "." + module_hash @@ -126,6 +145,14 @@ def loadForSymbol(self, linkName: str) -> None: if not self.loadedBinarySharedObjects[moduleHash].validateGlobalVariables(definitionsToLink): raise RuntimeError('failed to validate globals when loading:', linkName) + def complexityForSymbol(self, func_name: str) -> int: + """Get the total number of LLVM instructions for a given symbol.""" + try: + link_name = self._select_link_name(func_name) + return self.targetComplexity[link_name] + except KeyError as e: + raise ValueError(f'No complexity value cached for {func_name}') from e + def loadModuleByHash(self, moduleHash: str) -> None: """Load a module by name. @@ -139,23 +166,23 @@ def loadModuleByHash(self, moduleHash: str) -> None: # TODO (Will) - store these names as module consts, use one .dat only with open(os.path.join(targetDir, "type_manifest.dat"), "rb") as f: - # func_name -> typedcalltarget callTargets = SerializationContext().deserialize(f.read()) - with open(os.path.join(targetDir, "globals_manifest.dat"), "rb") as f: serializedGlobalVarDefs = SerializationContext().deserialize(f.read()) - with open(os.path.join(targetDir, "native_type_manifest.dat"), "rb") as f: functionNameToNativeType = SerializationContext().deserialize(f.read()) - with open(os.path.join(targetDir, "submodules.dat"), "rb") as f: submodules = SerializationContext().deserialize(f.read(), ListOf(str)) - with open(os.path.join(targetDir, "function_dependencies.dat"), "rb") as f: dependency_edgelist = SerializationContext().deserialize(f.read()) - with open(os.path.join(targetDir, "global_dependencies.dat"), "rb") as f: globalDependencies = SerializationContext().deserialize(f.read()) + with open(os.path.join(targetDir, "function_complexities.dat"), "rb") as f: + functionComplexities = SerializationContext().deserialize(f.read()) + with open(os.path.join(targetDir, "function_irs.dat"), "rb") as f: + functionIRs = SerializationContext().deserialize(f.read()) + with open(os.path.join(targetDir, "function_definitions.dat"), "rb") as f: + functionDefinitions = SerializationContext().deserialize(f.read()) # load the submodules first for submodule in submodules: @@ -167,7 +194,10 @@ def loadModuleByHash(self, moduleHash: str) -> None: modulePath, serializedGlobalVarDefs, functionNameToNativeType, - globalDependencies + globalDependencies, + functionComplexities, + functionIRs, + functionDefinitions ).loadFromPath(modulePath) self.loadedBinarySharedObjects[moduleHash] = loaded @@ -177,8 +207,11 @@ def loadModuleByHash(self, moduleHash: str) -> None: assert link_name not in self.targetsLoaded self.targetsLoaded[link_name] = callTarget - link_name_global_dependencies = {self._generate_link_name(x, moduleHash): y for x, y in globalDependencies.items()} + for func_name, complexity in functionComplexities.items(): + link_name = self._generate_link_name(func_name, moduleHash) + self.targetComplexity[link_name] = complexity + link_name_global_dependencies = {self._generate_link_name(x, moduleHash): y for x, y in globalDependencies.items()} assert not any(key in self.global_dependencies for key in link_name_global_dependencies) self.global_dependencies.update(link_name_global_dependencies) @@ -222,6 +255,10 @@ def addModule(self, binarySharedObject, nameToTypedCallTarget, linkDependencies, path = self.writeModuleToDisk(binarySharedObject, hashToUse, nameToTypedCallTarget, dependentHashes, link_name_dependency_edgelist) + for func_name, complexity in binarySharedObject.functionComplexities.items(): + link_name = self._generate_link_name(func_name, hashToUse) + self.targetComplexity[link_name] = complexity + self.loadedBinarySharedObjects[hashToUse] = ( binarySharedObject.loadFromPath(os.path.join(path, "module.so")) ) @@ -314,6 +351,15 @@ def writeModuleToDisk(self, binarySharedObject, hashToUse, nameToTypedCallTarget with open(os.path.join(tempTargetDir, "global_dependencies.dat"), "wb") as f: f.write(SerializationContext().serialize(binarySharedObject.globalDependencies)) + with open(os.path.join(tempTargetDir, "function_complexities.dat"), "wb") as f: + f.write(SerializationContext().serialize(binarySharedObject.functionComplexities)) + + with open(os.path.join(tempTargetDir, "function_irs.dat"), "wb") as f: + f.write(SerializationContext().serialize(binarySharedObject.functionIRs)) + + with open(os.path.join(tempTargetDir, "function_definitions.dat"), "wb") as f: + f.write(SerializationContext().serialize(binarySharedObject.serializedFunctionDefinitions)) + try: os.rename(tempTargetDir, targetDir) except IOError: diff --git a/typed_python/compiler/compiler_cache_test.py b/typed_python/compiler/compiler_cache_test.py index 639e929e..3f45cccd 100644 --- a/typed_python/compiler/compiler_cache_test.py +++ b/typed_python/compiler/compiler_cache_test.py @@ -16,6 +16,7 @@ import threading import os import pytest + from typed_python.test_util import evaluateExprInFreshProcess MAIN_MODULE = """ diff --git a/typed_python/compiler/llvm_compiler.py b/typed_python/compiler/llvm_compiler.py index f33e5edb..b9e31bd6 100644 --- a/typed_python/compiler/llvm_compiler.py +++ b/typed_python/compiler/llvm_compiler.py @@ -123,7 +123,10 @@ def buildSharedObject(self, functions): mod, serializedGlobalVariableDefinitions, module.functionNameToType, - module.globalDependencies + module.globalDependencies, + {name: self.converter.totalFunctionComplexity(name) for name in functions}, + {name: self.converter._functions_by_name[name] for name in functions}, + {name: SerializationContext().serialize(self.converter._function_definitions[name]) for name in functions}, ) def function_pointer_by_name(self, name): diff --git a/typed_python/compiler/native_ast_to_llvm.py b/typed_python/compiler/native_ast_to_llvm.py index bbd2027d..310881de 100644 --- a/typed_python/compiler/native_ast_to_llvm.py +++ b/typed_python/compiler/native_ast_to_llvm.py @@ -18,6 +18,7 @@ from typed_python.compiler.global_variable_definition import GlobalVariableDefinition from typed_python.compiler.module_definition import ModuleDefinition from typing import Dict + llvm_i8ptr = llvmlite.ir.IntType(8).as_pointer() llvm_i8 = llvmlite.ir.IntType(8) llvm_i32 = llvmlite.ir.IntType(32) @@ -642,6 +643,7 @@ def namedCallTargetToLLVM(self, target: native_ast.NamedCallTarget) -> TypedLLVM 2. The function is in function_definitions, in which case we grab the function definition and make an inlining decision. 3. We have a compiler cache, and the function is in it. We add to external_function_references. """ + assert isinstance(target, native_ast.NamedCallTarget) if target.external: if target.name not in self.external_function_references: func_type = llvmlite.ir.FunctionType( @@ -673,24 +675,29 @@ def namedCallTargetToLLVM(self, target: native_ast.NamedCallTarget) -> TypedLLVM func = self.external_function_references[target.name] else: - # TODO (Will): decide whether to inline cached code assert self.compilerCache is not None and self.compilerCache.hasSymbol(target.name) # this function is defined in a shared object that we've loaded from a prior - # invocation - if target.name not in self.external_function_references: - func_type = llvmlite.ir.FunctionType( - type_to_llvm_type(target.output_type), - [type_to_llvm_type(x) for x in target.arg_types], - var_arg=target.varargs - ) + # invocation. Again, first make an inlining decision. + if ( + self.compilerCache.complexityForSymbol(target.name) < CROSS_MODULE_INLINE_COMPLEXITY + ): + self.converter.generateDefinition(target.name) + func = self.converter.repeatFunctionInModule(target.name, self.module) + else: + if target.name not in self.external_function_references: + func_type = llvmlite.ir.FunctionType( + type_to_llvm_type(target.output_type), + [type_to_llvm_type(x) for x in target.arg_types], + var_arg=target.varargs + ) - assert target.name not in self.converter._function_definitions, target.name + assert target.name not in self.converter._function_definitions, target.name - self.external_function_references[target.name] = ( - llvmlite.ir.Function(self.module, func_type, target.name) - ) + self.external_function_references[target.name] = ( + llvmlite.ir.Function(self.module, func_type, target.name) + ) - func = self.external_function_references[target.name] + func = self.external_function_references[target.name] return TypedLLVMValue( func, @@ -1528,6 +1535,18 @@ def totalFunctionComplexity(self, name): return res + def generateDefinition(self, name: str) -> None: + """Pull the TypedCallTarget matching `name` from the cache, and use to rebuild + the function definition. Add to _function_definitions and _functions_by_name. + """ + assert self.compilerCache is not None + + definition = self.compilerCache.getDefinition(name) + llvm_func = self.compilerCache.getIR(name) + + self._functions_by_name[name] = llvm_func + self._function_definitions[name] = definition + def repeatFunctionInModule(self, name, module): """Request that the function given by 'name' be inlined into 'module'. @@ -1580,7 +1599,6 @@ def add_functions(self, names_to_definitions): [type_to_llvm_type(x[1]) for x in function.args] ) self._functions_by_name[name] = llvmlite.ir.Function(module, func_type, name) - self._functions_by_name[name].linkage = 'external' self._function_definitions[name] = function @@ -1664,6 +1682,7 @@ def add_functions(self, names_to_definitions): # want to repeat its definition in this particular module. for name in self._inlineRequests: names_to_definitions[name] = self._function_definitions[name] + self._inlineRequests.clear() # define a function that accepts a pointer and fills it out with a table of pointer values @@ -1674,7 +1693,6 @@ def add_functions(self, names_to_definitions): output=native_ast.Void, args=[native_ast.Void.pointer().pointer()] ) - return ModuleDefinition( str(module), functionTypes, diff --git a/typed_python/compiler/python_to_native_converter.py b/typed_python/compiler/python_to_native_converter.py index f8a99f40..131ece7d 100644 --- a/typed_python/compiler/python_to_native_converter.py +++ b/typed_python/compiler/python_to_native_converter.py @@ -17,7 +17,7 @@ from typed_python.hash import Hash from types import ModuleType -from typing import Dict +from typing import Dict, Optional from typed_python import Class import typed_python.python_ast as python_ast import typed_python._types as _types @@ -288,10 +288,10 @@ def deleteTarget(self, linkName): self._targets.pop(linkName) def setTarget(self, linkName, target): - assert(isinstance(target, TypedCallTarget)) + assert (isinstance(target, TypedCallTarget)) self._targets[linkName] = target - def getTarget(self, linkName) -> TypedCallTarget: + def getTarget(self, linkName) -> Optional[TypedCallTarget]: if linkName in self._targets: return self._targets[linkName]