Description
We need to extend our existing ufunc support, currently provided by the `Elemwise` Op, with a new Op that covers NumPy's gufuncs (generalized ufuncs). Let's call this proposed Op `Blockwise`.
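A gufunc signature such as `(),(n)->(n)` names the trailing "core" dimensions each input consumes and each output produces; all leading dimensions are "loop" (batch) dimensions that broadcast like ordinary ufunc arguments. The following is a minimal sketch of the static shape inference a `Blockwise` Op would need, assuming NumPy's signature syntax with named dimensions only. `parse_signature` and `gufunc_output_shapes` are hypothetical helpers, not existing Aesara or NumPy APIs.

```python
import re

import numpy as np


def parse_signature(signature):
    """Split a signature like "(),(n)->(n)" into per-argument core-dim names."""
    inputs, outputs = signature.replace(" ", "").split("->")

    def groups(part):
        # Each "(...)" group becomes a tuple of dimension names; "()" becomes ()
        return [tuple(d for d in g.split(",") if d) for g in re.findall(r"\(([^)]*)\)", part)]

    return groups(inputs), groups(outputs)


def gufunc_output_shapes(signature, *input_shapes):
    """Infer output shapes: broadcast the batch dims, match the core dims by name."""
    in_core, out_core = parse_signature(signature)
    if len(in_core) != len(input_shapes):
        raise ValueError("wrong number of inputs for signature")
    sizes = {}
    batch_shapes = []
    for shape, core in zip(input_shapes, in_core):
        n_batch = len(shape) - len(core)
        if n_batch < 0:
            raise ValueError(f"shape {shape} has too few dims for core dims {core}")
        batch_shapes.append(tuple(shape[:n_batch]))
        # Every occurrence of a core-dim name must have the same size
        for name, size in zip(core, shape[n_batch:]):
            if sizes.setdefault(name, size) != size:
                raise ValueError(f"core dimension {name!r} has inconsistent sizes")
    batch = np.broadcast_shapes(*batch_shapes)
    return [batch + tuple(sizes[name] for name in core) for core in out_core]
```

For example, `gufunc_output_shapes("(),(n)->(n)", (2,), (2, 2))` gives `[(2, 2)]`, the shape of the multinomial draws in the illustration further down. This sketch assumes every output core dimension also appears in some input; signatures whose output sizes cannot be inferred from the inputs would need extra information.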
This new `Blockwise` Op would allow us to generalize at least a few existing Ops, and it would provide a very direct bridge to the equivalent functionality in Numba and JAX (e.g. Numba's direct support for gufuncs via `guvectorize`, and JAX's `vmap`). It would also allow us to implement a very convenient `np.vectorize`-like helper function.
The implementation of a `Blockwise` Op will likely reuse much of the logic in `Elemwise` and `RandomVariable`. At a high level, the gradient methods of the former could be extended to account for non-scalar elements, while the latter already demonstrates some of the relevant shape logic.
The `RandomVariable` Op already works much like a gufunc, because it supports generic "base" random variable Ops that map inputs of distinct shapes to potentially non-scalar outputs. A good example is the `MultinomialRV` Op, whose gufunc-like signature would be `(),(n)->(n)`.
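To make the shape logic concrete, here is a minimal pure-NumPy sketch of the looping a `Blockwise` Op's `perform` might do: broadcast only the batch dimensions of the inputs, call a core function once per batch index, and stack the results. It assumes a single output and NumPy's named-dimension signature syntax; `blockwise_apply` is a hypothetical helper, not an existing Aesara API.

```python
import re

import numpy as np


def blockwise_apply(core_fn, signature, *inputs):
    """Apply `core_fn` over the broadcast batch dims of `inputs` (single output)."""
    in_part, out_part = signature.replace(" ", "").split("->")
    in_core = [tuple(d for d in g.split(",") if d) for g in re.findall(r"\(([^)]*)\)", in_part)]
    if len(re.findall(r"\(([^)]*)\)", out_part)) != 1:
        raise NotImplementedError("this sketch only handles one output")

    arrays = [np.asarray(x) for x in inputs]
    # Leading dims are batch dims; the trailing len(core) dims are core dims
    batch_shape = np.broadcast_shapes(*(a.shape[: a.ndim - len(c)] for a, c in zip(arrays, in_core)))
    # Broadcast the batch dims only, keeping each input's own core dims
    arrays = [
        np.broadcast_to(a, batch_shape + a.shape[a.ndim - len(c):])
        for a, c in zip(arrays, in_core)
    ]
    # One core-function call per batch index (np.ndindex() yields () once for no batch dims)
    results = [core_fn(*(a[idx] for a in arrays)) for idx in np.ndindex(*batch_shape)]
    core_out_shape = np.shape(results[0]) if results else ()
    return np.asarray(results).reshape(batch_shape + core_out_shape)
```

For instance, `blockwise_apply(np.dot, "(n),(n)->()", x, y)` with `x` of shape `(3, 4)` and `y` of shape `(4,)` returns the three row-wise dot products. The Python-level loop is only for illustration; a backend such as Numba (`guvectorize`) or JAX (`vmap`) would be expected to replace it with a compiled or vectorized loop.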
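Because random draws vary between runs, a deterministic function with the same `(),(n)->(n)` signature can make the batch/core split easier to see. The sketch below assumes that the multinomial mean, `n * p`, is a fair stand-in for `MultinomialRV`'s shape behavior; `multinomial_mean` is a hypothetical name used only for illustration.

```python
import numpy as np


# E[X] = n * p for a multinomial draw: a scalar trial count and a length-n
# probability vector map to a length-n vector, i.e. "(),(n)->(n)"
def multinomial_mean(n, p):
    return n * p


multinomial_mean_vect = np.vectorize(multinomial_mean, signature="(),(n)->(n)")

# Batch dims (2,) and (2,) broadcast to (2,); each row gets its own n
means = multinomial_mean_vect(np.r_[1, 2], np.array([[0.5, 0.5], [0.4, 0.6]]))
# means -> [[0.5, 0.5], [0.8, 1.2]]

# A single probability vector (batch shape ()) is broadcast across three counts
means_b = multinomial_mean_vect(np.r_[1, 2, 3], np.array([0.25, 0.75]))
# means_b.shape -> (3, 2)
```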
Here's an illustration:
```python
import numba
import numpy as np
import aesara.tensor as at

# The outputs shown below are sample random draws; they will vary between runs.

#
# Existing `gufunc`-like functionality provided by `RandomVariable`
#
X_rv = at.random.multinomial(np.r_[1, 2], np.array([[0.5, 0.5], [0.4, 0.6]]))
X_rv.eval()
# array([[0, 1],
#        [2, 0]])

#
# A NumPy `vectorize` equivalent (this doesn't create a `gufunc`)
#
multinomial_vect = np.vectorize(np.random.multinomial, signature="(),(n)->(n)")
multinomial_vect(np.r_[1, 2], np.array([[0.5, 0.5], [0.4, 0.6]]))
# array([[1, 0],
#        [2, 0]])

#
# A Numba example that creates a NumPy `gufunc`
#
@numba.guvectorize([(numba.int64, numba.float64[:], numba.int64[:])], "(),(n)->(n)")
def multinomial_numba(a, b, out):
    # `a` is a scalar trial count, `b` a probability vector; fill `out` in place
    out[:] = np.random.multinomial(a, b)

multinomial_numba(np.r_[1, 2], np.array([[0.5, 0.5], [0.4, 0.6]]))
# array([[1, 0],
#        [1, 1]])
```

See the originating discussion here.