Skip to content

Commit d5ff158

Browse files
authored
Fix by indices bug (#106)
There was a subtle bug in `ByIndices`. The tests have also been updated to detect a wider class of bugs.
1 parent 4ce03d1 commit d5ff158

File tree

3 files changed

+111
-107
lines changed

3 files changed

+111
-107
lines changed

api_matop.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,13 +127,20 @@ func Diag(t Tensor) (retVal Tensor, err error) {
127127
// ByIndices allows for selection of value of `a` byt the indices listed in the `indices` tensor.
128128
// The `indices` tensor has to be a vector-like tensor of ints.
129129
func ByIndices(a, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
130+
if axis >= a.Shape().Dims() {
131+
return nil, errors.Errorf("Cannot select by indices on axis %d. Input only has %d dims", axis, a.Shape().Dims())
132+
}
130133
if sbi, ok := a.Engine().(ByIndiceser); ok {
131134
return sbi.SelectByIndices(a, indices, axis, opts...)
132135
}
133136
return nil, errors.Errorf("Unable to select by indices. Egnine %T does not support that.", a.Engine())
134137
}
135138

139+
// ByIndicesB is the backpropagation of ByIndices.
136140
func ByIndicesB(a, b, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
141+
if axis >= a.Shape().Dims() {
142+
return nil, errors.Errorf("Cannot select by indices on axis %d. Input only has %d dims", axis, a.Shape().Dims())
143+
}
137144
if sbi, ok := a.Engine().(ByIndiceser); ok {
138145
return sbi.SelectByIndicesB(a, b, indices, axis, opts...)
139146
}

defaultengine_selbyidx.go

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,13 @@ func (e StdEng) selectByIdx(axis int, indices []int, typ reflect.Type, dataA, da
8686
dstCoord := make([]int, apRet.shape.Dims())
8787

8888
if isInnermost {
89-
prevStride := apA.strides[axis-1]
90-
retPrevStride := apRet.strides[axis-1]
89+
prevAxis := axis - 1
90+
if prevAxis < 0 {
91+
// this may be the case if input is a vector
92+
prevAxis = 0
93+
}
94+
prevStride := apA.strides[prevAxis]
95+
retPrevStride := apRet.strides[prevAxis]
9196
for i, idx := range indices {
9297
srcCoord[axis] = idx
9398
dstCoord[axis] = i
@@ -194,8 +199,13 @@ func (e StdEng) selectByIndicesB(axis int, indices []int, typ reflect.Type, data
194199
srcCoord := make([]int, apRet.shape.Dims())
195200

196201
if isInnermost {
197-
retPrevStride := apB.strides[axis-1]
198-
prevStride := apRet.strides[axis-1]
202+
prevAxis := axis - 1
203+
if prevAxis < 0 {
204+
// this may be the case if input is a vector
205+
prevAxis = 0
206+
}
207+
retPrevStride := apB.strides[prevAxis]
208+
prevStride := apRet.strides[prevAxis]
199209
for i, idx := range indices {
200210
dstCoord[axis] = idx
201211
srcCoord[axis] = i

dense_selbyidx_test.go

Lines changed: 90 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -6,121 +6,108 @@ import (
66
"github.com/stretchr/testify/assert"
77
)
88

9-
func TestDense_SelectByIndices(t *testing.T) {
10-
assert := assert.New(t)
11-
12-
a := New(WithBacking([]float64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}), WithShape(3, 2, 4))
13-
indices := New(WithBacking([]int{1, 1}))
14-
15-
e := StdEng{}
16-
17-
a1, err := e.SelectByIndices(a, indices, 1)
18-
if err != nil {
19-
t.Errorf("%v", err)
20-
}
21-
correct1 := []float64{4, 5, 6, 7, 4, 5, 6, 7, 12, 13, 14, 15, 12, 13, 14, 15, 20, 21, 22, 23, 20, 21, 22, 23}
22-
assert.Equal(correct1, a1.Data())
23-
24-
a0, err := e.SelectByIndices(a, indices, 0)
25-
if err != nil {
26-
t.Errorf("%v", err)
27-
}
28-
correct0 := []float64{8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15}
29-
assert.Equal(correct0, a0.Data())
9+
type selByIndicesTest struct {
10+
Name string
11+
Data interface{}
12+
Shape Shape
13+
Indices []int
14+
Axis int
15+
WillErr bool
16+
17+
Correct interface{}
18+
CorrectShape Shape
19+
}
3020

31-
a2, err := e.SelectByIndices(a, indices, 2)
32-
if err != nil {
33-
t.Errorf("%v", err)
34-
}
35-
correct2 := []float64{1, 1, 5, 5, 9, 9, 13, 13, 17, 17, 21, 21}
36-
assert.Equal(correct2, a2.Data())
21+
var selByIndicesTests = []selByIndicesTest{
22+
{Name: "3-tensor, axis 0", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 0, WillErr: false,
23+
Correct: []float64{8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15}, CorrectShape: Shape{2, 2, 4}},
3724

38-
// !safe
39-
aUnsafe := a.Clone().(*Dense)
40-
indices = New(WithBacking([]int{1, 1, 1}))
41-
aUnsafeSelect, err := e.SelectByIndices(aUnsafe, indices, 0, UseUnsafe())
42-
if err != nil {
43-
t.Errorf("%v", err)
44-
}
45-
correctUnsafe := []float64{8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15}
46-
assert.Equal(correctUnsafe, aUnsafeSelect.Data())
25+
{Name: "3-tensor, axis 1", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 1, WillErr: false,
26+
Correct: []float64{4, 5, 6, 7, 4, 5, 6, 7, 12, 13, 14, 15, 12, 13, 14, 15, 20, 21, 22, 23, 20, 21, 22, 23}, CorrectShape: Shape{3, 2, 4}},
4727

48-
// 3 indices, just to make sure the sanity of the algorithm
49-
indices = New(WithBacking([]int{1, 1, 1}))
50-
a1, err = e.SelectByIndices(a, indices, 1)
51-
if err != nil {
52-
t.Errorf("%v", err)
53-
}
54-
correct1 = []float64{
55-
4, 5, 6, 7,
56-
4, 5, 6, 7,
57-
4, 5, 6, 7,
28+
{Name: "3-tensor, axis 2", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 2, WillErr: false,
29+
Correct: []float64{1, 1, 5, 5, 9, 9, 13, 13, 17, 17, 21, 21}, CorrectShape: Shape{3, 2, 2}},
5830

59-
12, 13, 14, 15,
60-
12, 13, 14, 15,
61-
12, 13, 14, 15,
31+
{Name: "Vector, axis 0", Data: Range(Int, 0, 5), Shape: Shape{5}, Indices: []int{1, 1}, Axis: 0, WillErr: false,
32+
Correct: []int{1, 1}, CorrectShape: Shape{2}},
6233

63-
20, 21, 22, 23,
64-
20, 21, 22, 23,
65-
20, 21, 22, 23,
66-
}
67-
assert.Equal(correct1, a1.Data())
34+
{Name: "Vector, axis 1", Data: Range(Int, 0, 5), Shape: Shape{5}, Indices: []int{1, 1}, Axis: 1, WillErr: true,
35+
Correct: []int{1, 1}, CorrectShape: Shape{2}},
36+
{Name: "(4,2) Matrix, with (10) indices", Data: Range(Float32, 0, 8), Shape: Shape{4, 2}, Indices: []int{1, 1, 1, 1, 0, 2, 2, 2, 2, 0}, Axis: 0, WillErr: false,
37+
Correct: []float32{2, 3, 2, 3, 2, 3, 2, 3, 0, 1, 4, 5, 4, 5, 4, 5, 4, 5, 0, 1}, CorrectShape: Shape{10, 2}},
38+
{Name: "(2,1) Matrx (colvec)m with (10) indies", Data: Range(Float64, 0, 2), Shape: Shape{2, 1}, Indices: []int{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, Axis: 0, WillErr: false,
39+
Correct: []float64{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, CorrectShape: Shape{10},
40+
},
41+
}
6842

69-
a0, err = e.SelectByIndices(a, indices, 0)
70-
if err != nil {
71-
t.Errorf("%v", err)
43+
func TestDense_SelectByIndices(t *testing.T) {
44+
assert := assert.New(t)
45+
for i, tc := range selByIndicesTests {
46+
T := New(WithShape(tc.Shape...), WithBacking(tc.Data))
47+
indices := New(WithBacking(tc.Indices))
48+
ret, err := ByIndices(T, indices, tc.Axis)
49+
if checkErr(t, tc.WillErr, err, tc.Name, i) {
50+
continue
51+
}
52+
assert.Equal(tc.Correct, ret.Data())
53+
assert.True(tc.CorrectShape.Eq(ret.Shape()))
7254
}
73-
correct0 = []float64{8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15}
74-
assert.Equal(correct0, a0.Data())
55+
}
7556

76-
a2, err = e.SelectByIndices(a, indices, 2)
77-
if err != nil {
78-
t.Errorf("%v", err)
79-
}
80-
correct2 = []float64{1, 1, 1, 5, 5, 5, 9, 9, 9, 13, 13, 13, 17, 17, 17, 21, 21, 21}
81-
assert.Equal(correct2, a2.Data())
57+
var selByIndicesBTests = []struct {
58+
selByIndicesTest
59+
60+
CorrectGrad interface{}
61+
CorrectGradShape Shape
62+
}{
63+
{
64+
selByIndicesTest: selByIndicesTests[0],
65+
CorrectGrad: []float64{0, 0, 0, 0, 0, 0, 0, 0, 16, 18, 20, 22, 24, 26, 28, 30, 0, 0, 0, 0, 0, 0, 0, 0},
66+
CorrectGradShape: Shape{3, 2, 4},
67+
},
68+
{
69+
selByIndicesTest: selByIndicesTests[1],
70+
CorrectGrad: []float64{0, 0, 0, 0, 8, 10, 12, 14, 0, 0, 0, 0, 24, 26, 28, 30, 0, 0, 0, 0, 40, 42, 44, 46},
71+
CorrectGradShape: Shape{3, 2, 4},
72+
},
73+
{
74+
selByIndicesTest: selByIndicesTests[2],
75+
CorrectGrad: []float64{0, 2, 0, 0, 0, 10, 0, 0, 0, 18, 0, 0, 0, 26, 0, 0, 0, 34, 0, 0, 0, 42, 0, 0},
76+
CorrectGradShape: Shape{3, 2, 4},
77+
},
78+
{
79+
selByIndicesTest: selByIndicesTests[3],
80+
CorrectGrad: []int{0, 2, 0, 0, 0},
81+
CorrectGradShape: Shape{5},
82+
},
83+
{
84+
selByIndicesTest: selByIndicesTests[5],
85+
CorrectGrad: []float32{4, 6, 8, 12, 8, 12, 0, 0},
86+
CorrectGradShape: Shape{4, 2},
87+
},
88+
{
89+
selByIndicesTest: selByIndicesTests[6],
90+
CorrectGrad: []float64{0, 10},
91+
CorrectGradShape: Shape{2, 1},
92+
},
8293
}
8394

8495
func TestDense_SelectByIndicesB(t *testing.T) {
8596

86-
a := New(WithBacking([]float64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}), WithShape(3, 2, 4))
87-
indices := New(WithBacking([]int{1, 1}))
88-
89-
t.Logf("a\n%v", a)
90-
91-
e := StdEng{}
92-
93-
a1, err := e.SelectByIndices(a, indices, 1)
94-
if err != nil {
95-
t.Errorf("%v", err)
96-
}
97-
t.Logf("a1\n%v", a1)
98-
99-
a1Grad, err := e.SelectByIndicesB(a, a1, indices, 1)
100-
if err != nil {
101-
t.Errorf("%v", err)
102-
}
103-
t.Logf("a1Grad \n%v", a1Grad)
104-
105-
a0, err := e.SelectByIndices(a, indices, 0)
106-
if err != nil {
107-
t.Errorf("%v", err)
108-
}
109-
t.Logf("a0\n%v", a0)
110-
a0Grad, err := e.SelectByIndicesB(a, a0, indices, 0)
111-
if err != nil {
112-
t.Errorf("%v", err)
97+
assert := assert.New(t)
98+
for i, tc := range selByIndicesBTests {
99+
T := New(WithShape(tc.Shape...), WithBacking(tc.Data))
100+
indices := New(WithBacking(tc.Indices))
101+
ret, err := ByIndices(T, indices, tc.Axis)
102+
if checkErr(t, tc.WillErr, err, tc.Name, i) {
103+
continue
104+
}
105+
grad, err := ByIndicesB(T, ret, indices, tc.Axis)
106+
if checkErr(t, tc.WillErr, err, tc.Name, i) {
107+
continue
108+
}
109+
assert.Equal(tc.CorrectGrad, grad.Data(), "%v", tc.Name)
110+
assert.True(tc.CorrectGradShape.Eq(grad.Shape()), "%v - Grad shape should be %v. Got %v instead", tc.Name, tc.CorrectGradShape, grad.Shape())
113111
}
114-
t.Logf("a0Grad\n%v", a0Grad)
115112

116-
a2, err := e.SelectByIndices(a, indices, 2)
117-
if err != nil {
118-
t.Errorf("%v", err)
119-
}
120-
t.Logf("\n%v", a2)
121-
a2Grad, err := e.SelectByIndicesB(a, a2, indices, 2)
122-
if err != nil {
123-
t.Errorf("%v", err)
124-
}
125-
t.Logf("a2Grad\n%v", a2Grad)
126113
}

0 commit comments

Comments
 (0)