Skip to content

Commit e3b127e

Browse files
authored
handleReuse: add safe flag to skip expensive call to BorrowInt (#107)
* handleReuse: add unsafe flag to skip expensive call to BorrowInt * handleReuse: add safe flag to skip expensive call to BorrowInt
1 parent d5ff158 commit e3b127e

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

dense_linalg.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ func (t *Dense) MatVecMul(other Tensor, opts ...FuncOpt) (retVal *Dense, err err
8282
// check whether retVal has the same size as the resulting matrix would be: mx1
8383
fo := ParseFuncOpts(opts...)
8484
defer returnOpOpt(fo)
85-
if retVal, err = handleReuse(fo.Reuse(), expectedShape); err != nil {
85+
if retVal, err = handleReuse(fo.Reuse(), expectedShape, fo.Safe()); err != nil {
8686
err = errors.Wrapf(err, opFail, "MatVecMul")
8787
return
8888
}
@@ -131,7 +131,7 @@ func (t *Dense) MatMul(other Tensor, opts ...FuncOpt) (retVal *Dense, err error)
131131

132132
fo := ParseFuncOpts(opts...)
133133
defer returnOpOpt(fo)
134-
if retVal, err = handleReuse(fo.Reuse(), expectedShape); err != nil {
134+
if retVal, err = handleReuse(fo.Reuse(), expectedShape, fo.Safe()); err != nil {
135135
err = errors.Wrapf(err, opFail, "MatMul")
136136
return
137137
}
@@ -170,7 +170,7 @@ func (t *Dense) Outer(other Tensor, opts ...FuncOpt) (retVal *Dense, err error)
170170

171171
fo := ParseFuncOpts(opts...)
172172
defer returnOpOpt(fo)
173-
if retVal, err = handleReuse(fo.Reuse(), expectedShape); err != nil {
173+
if retVal, err = handleReuse(fo.Reuse(), expectedShape, fo.Safe()); err != nil {
174174
err = errors.Wrapf(err, opFail, "Outer")
175175
return
176176
}
@@ -380,13 +380,15 @@ func (t *Dense) SVD(uv, full bool) (s, u, v *Dense, err error) {
380380
/* UTILITY FUNCTIONS */
381381

382382
// handleReuse extracts a *Dense from Tensor, and checks the shape of the reuse Tensor
383-
func handleReuse(reuse Tensor, expectedShape Shape) (retVal *Dense, err error) {
383+
func handleReuse(reuse Tensor, expectedShape Shape, safe bool) (retVal *Dense, err error) {
384384
if reuse != nil {
385385
if retVal, err = assertDense(reuse); err != nil {
386386
err = errors.Wrapf(err, opFail, "handling reuse")
387387
return
388388
}
389-
389+
if !safe {
390+
return
391+
}
390392
if err = reuseCheckShape(retVal, expectedShape); err != nil {
391393
err = errors.Wrapf(err, "Unable to process reuse *Dense Tensor. Shape error.")
392394
return

0 commit comments

Comments
 (0)