Skip to content

Commit 36f6381

Browse files
authored
Assign operation (#5)
* Add basic `assign` operation Note that it doesn't properly support updates, so only unmasked accumulations work correctly * Add vector_compare and matrix_compare test utility * Create select_by_indices and rename apply_mask->select_by_mask select_by_indices is needed to properly implement assignment without accumulation * Improve `update` to work with `assign` `assign` has indices that function similar to a mask, so providing a mask during update causes a double mask challenge of figuring out what to keep, etc. Add tests to cover all the cases. Also allow scalar value in `build` to make an iso-valued tensor. * Add support for assign with scalar input Allow BinaryOps to specify 0 or 1 to indicate the output matches the dtype of input arg 0 or 1
1 parent df7e629 commit 36f6381

File tree

9 files changed

+894
-288
lines changed

9 files changed

+894
-288
lines changed

mlir_graphblas/implementations.py

Lines changed: 183 additions & 44 deletions
Large diffs are not rendered by default.

mlir_graphblas/operations.py

Lines changed: 151 additions & 31 deletions
Large diffs are not rendered by default.

mlir_graphblas/operators.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@ def __init__(self, func, *, input=None, output=None):
6060
If input is defined, it must be one of (bool, int, float) to
6161
indicate the restricted allowable input dtypes
6262
If output is defined, it must be one of (bool, int, float) to
63-
indicate the output will always be of that type
63+
indicate the output will always be of that type or an
64+
an integer (0, 1, ...) to indicate that the output dtype
65+
will match the dtype of argument 0, 1, ...
6466
"""
6567
super().__init__(func.__name__)
6668
self.func = func
@@ -70,7 +72,8 @@ def __init__(self, func, *, input=None, output=None):
7072
self.input = input
7173
# Validate output
7274
if output is not None:
73-
assert output in {bool, int, float}
75+
if type(output) is not int:
76+
assert output in {bool, int, float}
7477
self.output = output
7578

7679
@classmethod
@@ -94,9 +97,19 @@ def validate_input(self, input_val):
9497
elif self.input is float and not val_dtype.is_float():
9598
raise GrbDomainMismatch("input must be float type")
9699

97-
def get_output_type(self, input_dtype):
100+
def get_output_type(self, left_input_dtype, right_input_dtype=None):
98101
if self.output is None:
99-
return input_dtype
102+
if right_input_dtype is None:
103+
return left_input_dtype
104+
if left_input_dtype != right_input_dtype:
105+
raise TypeError(f"Unable to infer output type from {left_input_dtype} and {right_input_dtype}")
106+
return left_input_dtype
107+
elif self.output == 0:
108+
return left_input_dtype
109+
elif self.output == 1:
110+
if right_input_dtype is None:
111+
raise TypeError("No type provided for expected 2nd input argument")
112+
return right_input_dtype
100113
return self._type_convert[self.output]
101114

102115

@@ -132,6 +145,13 @@ def name_of_op(x, y, input_dtype):
132145
def __call__(self, x, y):
133146
dtype = self._dtype_of(x)
134147
dtype2 = self._dtype_of(y)
148+
if self.output == 0:
149+
self.validate_input(x)
150+
return self.func(x, y, dtype)
151+
if self.output == 1:
152+
self.validate_input(y)
153+
return self.func(x, y, dtype2)
154+
# If we reached this point, inputs must have the same dtype
135155
if dtype is not dtype2:
136156
raise TypeError(f"Types must match, {dtype} != {dtype2}")
137157
self.validate_input(x)
@@ -417,12 +437,12 @@ def oneb(x, y, dtype):
417437
BinaryOp.pair = BinaryOp.oneb
418438

419439

420-
@BinaryOp._register
440+
@BinaryOp._register(output=0) # dtype matches x
421441
def first(x, y, dtype):
422442
return x
423443

424444

425-
@BinaryOp._register
445+
@BinaryOp._register(output=1) # dtype matches y
426446
def second(x, y, dtype):
427447
return y
428448

mlir_graphblas/tensor.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -246,8 +246,8 @@ def __repr__(self):
246246
return f'Vector<{self.dtype.gb_name}, size={self.shape[0]}>'
247247

248248
@classmethod
249-
def new(cls, dtype, size: int):
250-
return cls(dtype, (size,))
249+
def new(cls, dtype, size: int, *, intermediate_result=False):
250+
return cls(dtype, (size,), intermediate_result=intermediate_result)
251251

252252
def resize(self, size: int):
253253
raise NotImplementedError()
@@ -271,6 +271,7 @@ def build(self, indices, values, *, dup=None, sparsity=None):
271271
272272
indices: list or numpy array of int
273273
values: list or numpy array with matching dtype as declared in `.new()`
274+
can also be a scalar value to make the Vector iso-valued
274275
dup: BinaryOp used to combined entries with the same index
275276
NOTE: this is currently not support; passing dup will raise an error
276277
sparsity: list of string or DimLevelType
@@ -287,7 +288,17 @@ def build(self, indices, values, *, dup=None, sparsity=None):
287288
if not isinstance(indices, np.ndarray):
288289
indices = np.array(indices, dtype=np.uint64)
289290
if not isinstance(values, np.ndarray):
290-
values = np.array(values, dtype=self.dtype.np_type)
291+
if hasattr(values, '__len__'):
292+
values = np.array(values, dtype=self.dtype.np_type)
293+
else:
294+
if type(values) is Scalar:
295+
if values.dtype != self.dtype:
296+
raise TypeError("Scalar value must have same dtype as Vector")
297+
if values.nvals() == 0:
298+
# Empty Scalar means nothing to build
299+
return
300+
values = values.extract_element()
301+
values = np.ones(indices.shape, dtype=self.dtype.np_type) * values
291302
if sparsity is None:
292303
sparsity = [DimLevelType.compressed]
293304
self._to_sparse_tensor(indices, values, sparsity=sparsity, ordering=[0])
@@ -329,8 +340,8 @@ def is_colwise(self):
329340
return tuple(self._ordering) != self.permutation
330341

331342
@classmethod
332-
def new(cls, dtype, nrows: int, ncols: int):
333-
return cls(dtype, (nrows, ncols))
343+
def new(cls, dtype, nrows: int, ncols: int, *, intermediate_result=False):
344+
return cls(dtype, (nrows, ncols), intermediate_result=intermediate_result)
334345

335346
def diag(self, k: int):
336347
raise NotImplementedError()
@@ -356,13 +367,15 @@ def nvals(self):
356367

357368
return nvals(self)
358369

359-
def build(self, row_indices, col_indices, values, *, dup=None, sparsity=None, colwise=False):
370+
def build(self, row_indices, col_indices, values, *,
371+
dup=None, sparsity=None, colwise=False):
360372
"""
361373
Build the underlying MLIRSparseTensor structure from COO.
362374
363375
row_indices: list or numpy array of int
364376
col_indices: list or numpy array of int
365377
values: list or numpy array with matching dtype as declared in `.new()`
378+
can also be a scalar value to make the Vector iso-valued
366379
dup: BinaryOp used to combined entries with the same (row, col) coordinate
367380
NOTE: this is currently not support; passing dup will raise an error
368381
sparsity: list of string or DimLevelType
@@ -383,7 +396,17 @@ def build(self, row_indices, col_indices, values, *, dup=None, sparsity=None, co
383396
col_indices = np.array(col_indices, dtype=np.uint64)
384397
indices = np.stack([row_indices, col_indices], axis=1)
385398
if not isinstance(values, np.ndarray):
386-
values = np.array(values, dtype=self.dtype.np_type)
399+
if hasattr(values, '__len__'):
400+
values = np.array(values, dtype=self.dtype.np_type)
401+
else:
402+
if type(values) is Scalar:
403+
if values.dtype != self.dtype:
404+
raise TypeError("Scalar value must have same dtype as Matrix")
405+
if values.nvals() == 0:
406+
# Empty Scalar means nothing to build
407+
return
408+
values = values.extract_element()
409+
values = np.ones(indices.shape, dtype=self.dtype.np_type) * values
387410
ordering = [1, 0] if colwise else [0, 1]
388411
if sparsity is None:
389412
sparsity = [DimLevelType.dense, DimLevelType.compressed]

0 commit comments

Comments
 (0)