19 changes: 19 additions & 0 deletions torch_semiring_einsum/equation.py
@@ -350,14 +350,33 @@ def __init__(self, index_map, num_extra_vars, permutation):
         self.num_extra_vars = num_extra_vars
         self.permutation = permutation
 
+        n = max((1 + source_index for source_index, dest_index in index_map),
+                default=0)
+        self.source_to_dest = [None] * n
+        for source_index, dest_index in index_map:
+            self.source_to_dest[source_index] = dest_index
+
     def lookup(self, arg, var_values):
         index = [_COLON] * arg.dim()
         for source_index, dest_index in self.index_map:
+            assert dest_index == self.permutation[arg.ndim + self.num_extra_vars - len(var_values) + source_index]
             index[dest_index] = var_values[source_index]
         for i in range(self.num_extra_vars):
             index.append(None)
         return arg[tuple(index)].permute(self.permutation)
 
+    def view(self, arg):
+        for _ in range(self.num_extra_vars):
+            arg = arg.unsqueeze(-1)
+        return arg.permute(self.permutation)
+
+    def view_lookup(self, argv, var_values):
+        # TODO: generate this code in __init__ using ast
+        return argv[tuple(itertools.chain(
+            (Ellipsis,),
+            (_COLON if dest_index is None else var_value
+             for dest_index, var_value in itertools.zip_longest(self.source_to_dest, var_values))))]
+
 def create_reduce_info(input_vars, output_vars):
     r"""Pre-compile a data structure that will help reduce the variables
     given in ``input_vars`` to the variables in ``output_vars``."""
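The new `view`/`view_lookup` pair splits the work of `lookup`: `view` performs the unsqueeze-and-permute once per argument, and `view_lookup` then only indexes the already-permuted tensor. The assertion added to `lookup` checks the invariant this relies on: the permutation must move the dimensions indexed by `var_values` to the end. Below is a minimal sketch of the equivalence; the `index_map`, `permutation`, shapes, and the definition `_COLON = slice(None)` are assumptions made for illustration, not values taken from the library.

```python
# Illustrative only: index_map, permutation, and shapes are invented
# for this example; they are not taken from the library.
import itertools
import torch

_COLON = slice(None)

index_map = [(0, 1)]     # source variable 0 indexes dimension 1 of arg
num_extra_vars = 1
permutation = (0, 2, 1)  # must move the looked-up dims to the end

# source_to_dest, as precomputed in __init__:
n = max((1 + s for s, _ in index_map), default=0)
source_to_dest = [None] * n
for s, d in index_map:
    source_to_dest[s] = d

arg = torch.randn(2, 3)
var_values = (slice(0, 2),)  # var_values is a tuple of slices

# Old path (lookup): build a full index, then permute, once per block.
old = arg[(_COLON, var_values[0], None)].permute(permutation)

# New path: unsqueeze + permute once (view), then index only (view_lookup).
argv = arg.unsqueeze(-1).permute(permutation)
new = argv[tuple(itertools.chain(
    (Ellipsis,),
    (_COLON if d is None else v
     for d, v in itertools.zip_longest(source_to_dest, var_values))))]

assert torch.equal(old, new)
```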
7 changes: 5 additions & 2 deletions torch_semiring_einsum/extend.py
@@ -144,6 +144,9 @@ def semiring_einsum_forward_impl(equation, args, block_size, inputs,
         add_in_place, sum_block, multiply_in_place, reduce_info,
         include_indexes, output_dtypes=(None,)):
 
+    inputs_viewed = [arg_info.view(arg)
+                     for arg, arg_info in zip(inputs, reduce_info.lookup_info)]
+
     def generate_terms():
         summed_variable_indexes = reduce_info.get_summed_variable_indexes(
             equation,
@@ -154,11 +157,11 @@ def generate_terms():
             # var_values is a tuple of slices.
 
             def generate_factors():
-                for arg, arg_info in zip(inputs, reduce_info.lookup_info):
+                for argv, arg_info in zip(inputs_viewed, reduce_info.lookup_info):
                     # Get a slice of arg based on the current values of the
                     # reduced variables. The result has a shape of
                     # output_vars x reduced_vars.
-                    yield arg_info.lookup(arg, var_values)
+                    yield arg_info.view_lookup(argv, var_values)
 
             term_size = reduce_info.get_term_size(equation, args, var_values)
             # Multiply the args together.
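With `inputs_viewed` hoisted out of `generate_factors`, the unsqueeze/permute happens once per argument rather than once per block of `var_values`; only the `view_lookup` indexing remains inside the loop. A quick standalone check (illustrative, not from this repo) of the underlying assumption that these operations create views rather than copies:

```python
# Illustrative check: unsqueeze and permute return views that share
# storage with the original tensor, so hoisting them out of the
# per-block loop saves repeated metadata work without copying data.
import torch

arg = torch.randn(4, 5)
argv = arg.unsqueeze(-1).permute(0, 2, 1)
assert argv.data_ptr() == arg.data_ptr()  # same underlying storage
```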
12 changes: 8 additions & 4 deletions torch_semiring_einsum/log_backward.py
@@ -145,6 +145,10 @@ def log_einsum_backward(
     for i, arg in enumerate(args):
         if needs_grad[i]:
             reduce_info, output_lookup_info = equation.reduce_all_to_input[i]
+            args_viewed = [arg_info.view(arg)
+                          for arg, arg_info in zip(args, reduce_info.lookup_info)]
+            max_values_viewed = output_lookup_info.view(max_values)
+            C_viewed = output_lookup_info.view(C)
 
             # In this outer loop, we need to sum over all dimensions that
             # appear in the output but not in arg i. This is due to a basic
@@ -169,16 +173,16 @@ def generate_terms():
                 # This inner loop adds tensor slices together to get a
                 # term to be used in the outer loop.
                 def generate_factors():
-                    for arg, arg_info in zip(args, reduce_info.lookup_info):
-                        yield arg_info.lookup(arg, var_values)
+                    for argv, arg_info in zip(args_viewed, reduce_info.lookup_info):
+                        yield arg_info.view_lookup(argv, var_values)
 
                 term_size = reduce_info.get_term_size(equation, args, var_values)
                 term = reduce_in_place(
                     add_in_place,
                     generate_factors(),
                     lambda x: adjust_size(x, term_size))
                 # Subtract the maximum values to avoid overflow in exp().
-                term.sub_(output_lookup_info.lookup(max_values, var_values))
+                term.sub_(output_lookup_info.view_lookup(max_values_viewed, var_values))
                 term.exp_()
                 # TODO An advantage of splitting the outer loop into two
                 # nested loops is that this multiplication could be moved
@@ -188,7 +192,7 @@ def generate_factors():
                 # dimension in the output.
                 # If C is +inf here (because Z was 0), then this will
                 # result in nan, because term will be 0 and 0 * inf is nan.
-                term.mul_(output_lookup_info.lookup(C, var_values))
+                term.mul_(output_lookup_info.view_lookup(C_viewed, var_values))
                 yield sum_block(term, reduce_info.reduced_dims)
 
             arg_grad = reduce_in_place(add_in_place, generate_terms())
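The same hoisting is applied to `max_values` and `C`, which are now permuted once per argument and only sliced inside the loop. The surrounding `sub_`/`exp_`/`mul_` sequence is unchanged: subtracting the per-output maximum before `exp_` is the standard log-space guard against overflow. A small made-up illustration of why that subtraction matters:

```python
# Illustrative only: why max_values is subtracted before exp_().
import torch

x = torch.tensor([1000.0, 999.0])
print(torch.exp(x))      # tensor([inf, inf]) -- float32 overflow
m = x.max()
print(torch.exp(x - m))  # tensor([1.0000, 0.3679]) -- finite
```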