Skip to content

Add function complexities and intermediates to cache #442

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: will-compiler-cache-partial-load-multi-module
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 44 additions & 6 deletions typed_python/compiler/binary_shared_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,32 +36,64 @@ def __init__(self, binarySharedObject, diskPath, functionPointers, serializedGlo
class BinarySharedObject:
"""Models a shared object library (.so) loadable on linux systems."""

def __init__(self, binaryForm, functionTypes, serializedGlobalVariableDefinitions, globalDependencies):
def __init__(self,
binaryForm,
functionTypes,
serializedGlobalVariableDefinitions,
globalDependencies,
functionComplexities,
functionIRs,
serializedFunctionDefinitions
):
"""
Args:
binaryForm: a bytes object containing the actual compiled code for the module
serializedGlobalVariableDefinitions: a map from name to GlobalVariableDefinition
globalDependencies: a dict from function linkname to the list of global variables it depends on
globalDependencies: a dict from function name to the list of global variables it depends on
functionComplexities: a dict from function name to the total number of llvm instructions in the function (used for inlining)
functionIRs: a dict from function name to the llvm IR Functions (used for inlining)
serializedFunctionDefinitions: a dict from function name to the serialized native_ast.Functions (used for inlining)
"""
self.binaryForm = binaryForm
self.functionTypes = functionTypes
self.serializedGlobalVariableDefinitions = serializedGlobalVariableDefinitions
self.globalDependencies = globalDependencies
self.functionComplexities = functionComplexities
self.functionIRs = functionIRs
self.serializedFunctionDefinitions = serializedFunctionDefinitions
self.hash = sha_hash(binaryForm)

@property
def definedSymbols(self):
    """Return the keys of `functionTypes`, i.e. the function names this object holds types for."""
    symbol_names = self.functionTypes.keys()
    return symbol_names

@staticmethod
def fromDisk(path, serializedGlobalVariableDefinitions, functionNameToType, globalDependencies):
def fromDisk(path,
serializedGlobalVariableDefinitions,
functionNameToType,
globalDependencies,
functionComplexities,
functionIRs,
serializedFunctionDefinitions):
with open(path, "rb") as f:
binaryForm = f.read()

return BinarySharedObject(binaryForm, functionNameToType, serializedGlobalVariableDefinitions, globalDependencies)
return BinarySharedObject(binaryForm,
functionNameToType,
serializedGlobalVariableDefinitions,
globalDependencies,
functionComplexities,
functionIRs,
serializedFunctionDefinitions)

@staticmethod
def fromModule(module, serializedGlobalVariableDefinitions, functionNameToType, globalDependencies):
def fromModule(module,
serializedGlobalVariableDefinitions,
functionNameToType,
globalDependencies,
functionComplexities,
functionIRs,
serializedFunctionDefinitions):
target_triple = llvm.get_process_triple()
target = llvm.Target.from_triple(target_triple)
target_machine_shared_object = target.create_target_machine(reloc='pic', codemodel='default')
Expand All @@ -82,7 +114,13 @@ def fromModule(module, serializedGlobalVariableDefinitions, functionNameToType,
)

with open(os.path.join(tf, "module.so"), "rb") as so_file:
return BinarySharedObject(so_file.read(), functionNameToType, serializedGlobalVariableDefinitions, globalDependencies)
return BinarySharedObject(so_file.read(),
functionNameToType,
serializedGlobalVariableDefinitions,
globalDependencies,
functionComplexities,
functionIRs,
serializedFunctionDefinitions)

def load(self, storageDir):
"""Instantiate this .so in temporary storage and return a dict from symbol -> integer function pointer"""
Expand Down
62 changes: 54 additions & 8 deletions typed_python/compiler/compiler_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
import os
import uuid
import shutil
import llvmlite.ir

from typing import Optional, List

from typed_python.compiler.binary_shared_object import LoadedBinarySharedObject, BinarySharedObject
from typed_python.compiler.directed_graph import DirectedGraph
from typed_python.compiler.typed_call_target import TypedCallTarget
import typed_python.compiler.native_ast as native_ast
from typed_python.SerializationContext import SerializationContext
from typed_python import Dict, ListOf

Expand Down Expand Up @@ -67,6 +69,8 @@ def __init__(self, cacheDir):
self.targetsLoaded: Dict[str, TypedCallTarget] = {}
# the set of link_names for functions with linked and validated globals (i.e. ready to be run).
self.targetsValidated = set()
# the total number of instructions for each link_name
self.targetComplexity = Dict(str, int)()
# link_name -> link_name
self.function_dependency_graph = DirectedGraph()
# dict from link_name to list of global names (should be llvm keys in serialisedGlobalDefinitions)
Expand All @@ -90,6 +94,21 @@ def getTarget(self, func_name: str) -> TypedCallTarget:
self.loadForSymbol(link_name)
return self.targetsLoaded[link_name]

def getIR(self, func_name: str) -> llvmlite.ir.Function:
    """Return the llvm IR function stored for `func_name` in its loaded shared object.

    Args:
        func_name: the function name to look up in the cache.

    Raises:
        ValueError: if `hasSymbol(func_name)` is false.
    """
    if self.hasSymbol(func_name):
        # Resolve func_name -> link_name -> module hash -> loaded shared object.
        module_hash = self.link_name_to_module_hash[self._select_link_name(func_name)]
        shared_object = self.loadedBinarySharedObjects[module_hash].binarySharedObject
        # The IR table is keyed by the plain function name, not the link name.
        return shared_object.functionIRs[func_name]

    raise ValueError(f'symbol not found for func_name {func_name}')

def getDefinition(self, func_name: str) -> native_ast.Function:
    """Return the deserialized native_ast definition stored for `func_name`.

    Args:
        func_name: the function name to look up in the cache.

    Raises:
        ValueError: if `hasSymbol(func_name)` is false.
    """
    if not self.hasSymbol(func_name):
        raise ValueError(f'symbol not found for func_name {func_name}')

    # Resolve func_name -> link_name -> module hash -> loaded shared object.
    module_hash = self.link_name_to_module_hash[self._select_link_name(func_name)]
    shared_object = self.loadedBinarySharedObjects[module_hash].binarySharedObject

    # Definitions are stored serialized, keyed by the plain function name.
    payload = shared_object.serializedFunctionDefinitions[func_name]
    return SerializationContext().deserialize(payload)

def _generate_link_name(self, func_name: str, module_hash: str) -> str:
return func_name + "." + module_hash

Expand Down Expand Up @@ -126,6 +145,14 @@ def loadForSymbol(self, linkName: str) -> None:
if not self.loadedBinarySharedObjects[moduleHash].validateGlobalVariables(definitionsToLink):
raise RuntimeError('failed to validate globals when loading:', linkName)

def complexityForSymbol(self, func_name: str) -> int:
    """Return the cached total LLVM instruction count for `func_name`.

    Args:
        func_name: the function name whose complexity is wanted.

    Raises:
        ValueError: if selecting the link name or reading `targetComplexity`
            raises KeyError (the KeyError is chained as the cause).
    """
    try:
        complexity = self.targetComplexity[self._select_link_name(func_name)]
    except KeyError as lookup_error:
        raise ValueError(f'No complexity value cached for {func_name}') from lookup_error
    return complexity

def loadModuleByHash(self, moduleHash: str) -> None:
"""Load a module by hash.

Expand All @@ -139,23 +166,23 @@ def loadModuleByHash(self, moduleHash: str) -> None:

# TODO (Will) - store these names as module consts, use one .dat only
with open(os.path.join(targetDir, "type_manifest.dat"), "rb") as f:
# func_name -> typedcalltarget
callTargets = SerializationContext().deserialize(f.read())

with open(os.path.join(targetDir, "globals_manifest.dat"), "rb") as f:
serializedGlobalVarDefs = SerializationContext().deserialize(f.read())

with open(os.path.join(targetDir, "native_type_manifest.dat"), "rb") as f:
functionNameToNativeType = SerializationContext().deserialize(f.read())

with open(os.path.join(targetDir, "submodules.dat"), "rb") as f:
submodules = SerializationContext().deserialize(f.read(), ListOf(str))

with open(os.path.join(targetDir, "function_dependencies.dat"), "rb") as f:
dependency_edgelist = SerializationContext().deserialize(f.read())

with open(os.path.join(targetDir, "global_dependencies.dat"), "rb") as f:
globalDependencies = SerializationContext().deserialize(f.read())
with open(os.path.join(targetDir, "function_complexities.dat"), "rb") as f:
functionComplexities = SerializationContext().deserialize(f.read())
with open(os.path.join(targetDir, "function_irs.dat"), "rb") as f:
functionIRs = SerializationContext().deserialize(f.read())
with open(os.path.join(targetDir, "function_definitions.dat"), "rb") as f:
functionDefinitions = SerializationContext().deserialize(f.read())

# load the submodules first
for submodule in submodules:
Expand All @@ -167,7 +194,10 @@ def loadModuleByHash(self, moduleHash: str) -> None:
modulePath,
serializedGlobalVarDefs,
functionNameToNativeType,
globalDependencies
globalDependencies,
functionComplexities,
functionIRs,
functionDefinitions
).loadFromPath(modulePath)

self.loadedBinarySharedObjects[moduleHash] = loaded
Expand All @@ -177,8 +207,11 @@ def loadModuleByHash(self, moduleHash: str) -> None:
assert link_name not in self.targetsLoaded
self.targetsLoaded[link_name] = callTarget

link_name_global_dependencies = {self._generate_link_name(x, moduleHash): y for x, y in globalDependencies.items()}
for func_name, complexity in functionComplexities.items():
link_name = self._generate_link_name(func_name, moduleHash)
self.targetComplexity[link_name] = complexity

link_name_global_dependencies = {self._generate_link_name(x, moduleHash): y for x, y in globalDependencies.items()}
assert not any(key in self.global_dependencies for key in link_name_global_dependencies)

self.global_dependencies.update(link_name_global_dependencies)
Expand Down Expand Up @@ -222,6 +255,10 @@ def addModule(self, binarySharedObject, nameToTypedCallTarget, linkDependencies,

path = self.writeModuleToDisk(binarySharedObject, hashToUse, nameToTypedCallTarget, dependentHashes, link_name_dependency_edgelist)

for func_name, complexity in binarySharedObject.functionComplexities.items():
link_name = self._generate_link_name(func_name, hashToUse)
self.targetComplexity[link_name] = complexity

self.loadedBinarySharedObjects[hashToUse] = (
binarySharedObject.loadFromPath(os.path.join(path, "module.so"))
)
Expand Down Expand Up @@ -314,6 +351,15 @@ def writeModuleToDisk(self, binarySharedObject, hashToUse, nameToTypedCallTarget
with open(os.path.join(tempTargetDir, "global_dependencies.dat"), "wb") as f:
f.write(SerializationContext().serialize(binarySharedObject.globalDependencies))

with open(os.path.join(tempTargetDir, "function_complexities.dat"), "wb") as f:
f.write(SerializationContext().serialize(binarySharedObject.functionComplexities))

with open(os.path.join(tempTargetDir, "function_irs.dat"), "wb") as f:
f.write(SerializationContext().serialize(binarySharedObject.functionIRs))

with open(os.path.join(tempTargetDir, "function_definitions.dat"), "wb") as f:
f.write(SerializationContext().serialize(binarySharedObject.serializedFunctionDefinitions))

try:
os.rename(tempTargetDir, targetDir)
except IOError:
Expand Down
1 change: 1 addition & 0 deletions typed_python/compiler/compiler_cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import threading
import os
import pytest

from typed_python.test_util import evaluateExprInFreshProcess

MAIN_MODULE = """
Expand Down
5 changes: 4 additions & 1 deletion typed_python/compiler/llvm_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,10 @@ def buildSharedObject(self, functions):
mod,
serializedGlobalVariableDefinitions,
module.functionNameToType,
module.globalDependencies
module.globalDependencies,
{name: self.converter.totalFunctionComplexity(name) for name in functions},
{name: self.converter._functions_by_name[name] for name in functions},
{name: SerializationContext().serialize(self.converter._function_definitions[name]) for name in functions},
)

def function_pointer_by_name(self, name):
Expand Down
48 changes: 33 additions & 15 deletions typed_python/compiler/native_ast_to_llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typed_python.compiler.global_variable_definition import GlobalVariableDefinition
from typed_python.compiler.module_definition import ModuleDefinition
from typing import Dict

llvm_i8ptr = llvmlite.ir.IntType(8).as_pointer()
llvm_i8 = llvmlite.ir.IntType(8)
llvm_i32 = llvmlite.ir.IntType(32)
Expand Down Expand Up @@ -642,6 +643,7 @@ def namedCallTargetToLLVM(self, target: native_ast.NamedCallTarget) -> TypedLLVM
2. The function is in function_definitions, in which case we grab the function definition and make an inlining decision.
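3. We have a compiler cache, and the function is in it. We make an inlining decision from its cached complexity: either repeat the definition in this module, or add to external_function_references.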
"""
assert isinstance(target, native_ast.NamedCallTarget)
if target.external:
if target.name not in self.external_function_references:
func_type = llvmlite.ir.FunctionType(
Expand Down Expand Up @@ -673,24 +675,29 @@ def namedCallTargetToLLVM(self, target: native_ast.NamedCallTarget) -> TypedLLVM

func = self.external_function_references[target.name]
else:
# TODO (Will): decide whether to inline cached code
assert self.compilerCache is not None and self.compilerCache.hasSymbol(target.name)
# this function is defined in a shared object that we've loaded from a prior
# invocation
if target.name not in self.external_function_references:
func_type = llvmlite.ir.FunctionType(
type_to_llvm_type(target.output_type),
[type_to_llvm_type(x) for x in target.arg_types],
var_arg=target.varargs
)
# invocation. Again, first make an inlining decision.
if (
self.compilerCache.complexityForSymbol(target.name) < CROSS_MODULE_INLINE_COMPLEXITY
):
self.converter.generateDefinition(target.name)
func = self.converter.repeatFunctionInModule(target.name, self.module)
else:
if target.name not in self.external_function_references:
func_type = llvmlite.ir.FunctionType(
type_to_llvm_type(target.output_type),
[type_to_llvm_type(x) for x in target.arg_types],
var_arg=target.varargs
)

assert target.name not in self.converter._function_definitions, target.name
assert target.name not in self.converter._function_definitions, target.name

self.external_function_references[target.name] = (
llvmlite.ir.Function(self.module, func_type, target.name)
)
self.external_function_references[target.name] = (
llvmlite.ir.Function(self.module, func_type, target.name)
)

func = self.external_function_references[target.name]
func = self.external_function_references[target.name]

return TypedLLVMValue(
func,
Expand Down Expand Up @@ -1528,6 +1535,18 @@ def totalFunctionComplexity(self, name):

return res

def generateDefinition(self, name: str) -> None:
    """Register the cached definition and IR for `name` with this converter.

    Fetches the native definition and the llvm IR function for `name` from
    the compiler cache, then records them in `_function_definitions` and
    `_functions_by_name` respectively.

    Args:
        name: the function name to fetch from the compiler cache.
    """
    cache = self.compilerCache
    assert cache is not None

    # Fetch both pieces before touching converter state, so a failed lookup
    # leaves the two tables unchanged.
    cached_definition = cache.getDefinition(name)
    cached_ir = cache.getIR(name)

    self._functions_by_name[name] = cached_ir
    self._function_definitions[name] = cached_definition

def repeatFunctionInModule(self, name, module):
"""Request that the function given by 'name' be inlined into 'module'.

Expand Down Expand Up @@ -1580,7 +1599,6 @@ def add_functions(self, names_to_definitions):
[type_to_llvm_type(x[1]) for x in function.args]
)
self._functions_by_name[name] = llvmlite.ir.Function(module, func_type, name)

self._functions_by_name[name].linkage = 'external'
self._function_definitions[name] = function

Expand Down Expand Up @@ -1664,6 +1682,7 @@ def add_functions(self, names_to_definitions):
# want to repeat its definition in this particular module.
for name in self._inlineRequests:
names_to_definitions[name] = self._function_definitions[name]

self._inlineRequests.clear()

# define a function that accepts a pointer and fills it out with a table of pointer values
Expand All @@ -1674,7 +1693,6 @@ def add_functions(self, names_to_definitions):
output=native_ast.Void,
args=[native_ast.Void.pointer().pointer()]
)

return ModuleDefinition(
str(module),
functionTypes,
Expand Down
6 changes: 3 additions & 3 deletions typed_python/compiler/python_to_native_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from typed_python.hash import Hash
from types import ModuleType
from typing import Dict
from typing import Dict, Optional
from typed_python import Class
import typed_python.python_ast as python_ast
import typed_python._types as _types
Expand Down Expand Up @@ -288,10 +288,10 @@ def deleteTarget(self, linkName):
self._targets.pop(linkName)

def setTarget(self, linkName, target):
assert(isinstance(target, TypedCallTarget))
assert (isinstance(target, TypedCallTarget))
self._targets[linkName] = target

def getTarget(self, linkName) -> TypedCallTarget:
def getTarget(self, linkName) -> Optional[TypedCallTarget]:
if linkName in self._targets:
return self._targets[linkName]

Expand Down