@@ -6,121 +6,108 @@ import (
6
6
"github.com/stretchr/testify/assert"
7
7
)
8
8
9
- func TestDense_SelectByIndices (t * testing.T ) {
10
- assert := assert .New (t )
11
-
12
- a := New (WithBacking ([]float64 {0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 , 16 , 17 , 18 , 19 , 20 , 21 , 22 , 23 }), WithShape (3 , 2 , 4 ))
13
- indices := New (WithBacking ([]int {1 , 1 }))
14
-
15
- e := StdEng {}
16
-
17
- a1 , err := e .SelectByIndices (a , indices , 1 )
18
- if err != nil {
19
- t .Errorf ("%v" , err )
20
- }
21
- correct1 := []float64 {4 , 5 , 6 , 7 , 4 , 5 , 6 , 7 , 12 , 13 , 14 , 15 , 12 , 13 , 14 , 15 , 20 , 21 , 22 , 23 , 20 , 21 , 22 , 23 }
22
- assert .Equal (correct1 , a1 .Data ())
23
-
24
- a0 , err := e .SelectByIndices (a , indices , 0 )
25
- if err != nil {
26
- t .Errorf ("%v" , err )
27
- }
28
- correct0 := []float64 {8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 }
29
- assert .Equal (correct0 , a0 .Data ())
9
+ type selByIndicesTest struct {
10
+ Name string
11
+ Data interface {}
12
+ Shape Shape
13
+ Indices []int
14
+ Axis int
15
+ WillErr bool
16
+
17
+ Correct interface {}
18
+ CorrectShape Shape
19
+ }
30
20
31
- a2 , err := e .SelectByIndices (a , indices , 2 )
32
- if err != nil {
33
- t .Errorf ("%v" , err )
34
- }
35
- correct2 := []float64 {1 , 1 , 5 , 5 , 9 , 9 , 13 , 13 , 17 , 17 , 21 , 21 }
36
- assert .Equal (correct2 , a2 .Data ())
21
+ var selByIndicesTests = []selByIndicesTest {
22
+ {Name : "3-tensor, axis 0" , Data : Range (Float64 , 0 , 24 ), Shape : Shape {3 , 2 , 4 }, Indices : []int {1 , 1 }, Axis : 0 , WillErr : false ,
23
+ Correct : []float64 {8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 }, CorrectShape : Shape {2 , 2 , 4 }},
37
24
38
- // !safe
39
- aUnsafe := a .Clone ().(* Dense )
40
- indices = New (WithBacking ([]int {1 , 1 , 1 }))
41
- aUnsafeSelect , err := e .SelectByIndices (aUnsafe , indices , 0 , UseUnsafe ())
42
- if err != nil {
43
- t .Errorf ("%v" , err )
44
- }
45
- correctUnsafe := []float64 {8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 }
46
- assert .Equal (correctUnsafe , aUnsafeSelect .Data ())
25
+ {Name : "3-tensor, axis 1" , Data : Range (Float64 , 0 , 24 ), Shape : Shape {3 , 2 , 4 }, Indices : []int {1 , 1 }, Axis : 1 , WillErr : false ,
26
+ Correct : []float64 {4 , 5 , 6 , 7 , 4 , 5 , 6 , 7 , 12 , 13 , 14 , 15 , 12 , 13 , 14 , 15 , 20 , 21 , 22 , 23 , 20 , 21 , 22 , 23 }, CorrectShape : Shape {3 , 2 , 4 }},
47
27
48
- // 3 indices, just to make sure the sanity of the algorithm
49
- indices = New (WithBacking ([]int {1 , 1 , 1 }))
50
- a1 , err = e .SelectByIndices (a , indices , 1 )
51
- if err != nil {
52
- t .Errorf ("%v" , err )
53
- }
54
- correct1 = []float64 {
55
- 4 , 5 , 6 , 7 ,
56
- 4 , 5 , 6 , 7 ,
57
- 4 , 5 , 6 , 7 ,
28
+ {Name : "3-tensor, axis 2" , Data : Range (Float64 , 0 , 24 ), Shape : Shape {3 , 2 , 4 }, Indices : []int {1 , 1 }, Axis : 2 , WillErr : false ,
29
+ Correct : []float64 {1 , 1 , 5 , 5 , 9 , 9 , 13 , 13 , 17 , 17 , 21 , 21 }, CorrectShape : Shape {3 , 2 , 2 }},
58
30
59
- 12 , 13 , 14 , 15 ,
60
- 12 , 13 , 14 , 15 ,
61
- 12 , 13 , 14 , 15 ,
31
+ {Name : "Vector, axis 0" , Data : Range (Int , 0 , 5 ), Shape : Shape {5 }, Indices : []int {1 , 1 }, Axis : 0 , WillErr : false ,
32
+ Correct : []int {1 , 1 }, CorrectShape : Shape {2 }},
62
33
63
- 20 , 21 , 22 , 23 ,
64
- 20 , 21 , 22 , 23 ,
65
- 20 , 21 , 22 , 23 ,
66
- }
67
- assert .Equal (correct1 , a1 .Data ())
34
+ {Name : "Vector, axis 1" , Data : Range (Int , 0 , 5 ), Shape : Shape {5 }, Indices : []int {1 , 1 }, Axis : 1 , WillErr : true ,
35
+ Correct : []int {1 , 1 }, CorrectShape : Shape {2 }},
36
+ {Name : "(4,2) Matrix, with (10) indices" , Data : Range (Float32 , 0 , 8 ), Shape : Shape {4 , 2 }, Indices : []int {1 , 1 , 1 , 1 , 0 , 2 , 2 , 2 , 2 , 0 }, Axis : 0 , WillErr : false ,
37
+ Correct : []float32 {2 , 3 , 2 , 3 , 2 , 3 , 2 , 3 , 0 , 1 , 4 , 5 , 4 , 5 , 4 , 5 , 4 , 5 , 0 , 1 }, CorrectShape : Shape {10 , 2 }},
38
+ {Name : "(2,1) Matrx (colvec)m with (10) indies" , Data : Range (Float64 , 0 , 2 ), Shape : Shape {2 , 1 }, Indices : []int {1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 }, Axis : 0 , WillErr : false ,
39
+ Correct : []float64 {1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 }, CorrectShape : Shape {10 },
40
+ },
41
+ }
68
42
69
- a0 , err = e .SelectByIndices (a , indices , 0 )
70
- if err != nil {
71
- t .Errorf ("%v" , err )
43
+ func TestDense_SelectByIndices (t * testing.T ) {
44
+ assert := assert .New (t )
45
+ for i , tc := range selByIndicesTests {
46
+ T := New (WithShape (tc .Shape ... ), WithBacking (tc .Data ))
47
+ indices := New (WithBacking (tc .Indices ))
48
+ ret , err := ByIndices (T , indices , tc .Axis )
49
+ if checkErr (t , tc .WillErr , err , tc .Name , i ) {
50
+ continue
51
+ }
52
+ assert .Equal (tc .Correct , ret .Data ())
53
+ assert .True (tc .CorrectShape .Eq (ret .Shape ()))
72
54
}
73
- correct0 = []float64 {8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 }
74
- assert .Equal (correct0 , a0 .Data ())
55
+ }
75
56
76
- a2 , err = e .SelectByIndices (a , indices , 2 )
77
- if err != nil {
78
- t .Errorf ("%v" , err )
79
- }
80
- correct2 = []float64 {1 , 1 , 1 , 5 , 5 , 5 , 9 , 9 , 9 , 13 , 13 , 13 , 17 , 17 , 17 , 21 , 21 , 21 }
81
- assert .Equal (correct2 , a2 .Data ())
57
+ var selByIndicesBTests = []struct {
58
+ selByIndicesTest
59
+
60
+ CorrectGrad interface {}
61
+ CorrectGradShape Shape
62
+ }{
63
+ {
64
+ selByIndicesTest : selByIndicesTests [0 ],
65
+ CorrectGrad : []float64 {0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 16 , 18 , 20 , 22 , 24 , 26 , 28 , 30 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 },
66
+ CorrectGradShape : Shape {3 , 2 , 4 },
67
+ },
68
+ {
69
+ selByIndicesTest : selByIndicesTests [1 ],
70
+ CorrectGrad : []float64 {0 , 0 , 0 , 0 , 8 , 10 , 12 , 14 , 0 , 0 , 0 , 0 , 24 , 26 , 28 , 30 , 0 , 0 , 0 , 0 , 40 , 42 , 44 , 46 },
71
+ CorrectGradShape : Shape {3 , 2 , 4 },
72
+ },
73
+ {
74
+ selByIndicesTest : selByIndicesTests [2 ],
75
+ CorrectGrad : []float64 {0 , 2 , 0 , 0 , 0 , 10 , 0 , 0 , 0 , 18 , 0 , 0 , 0 , 26 , 0 , 0 , 0 , 34 , 0 , 0 , 0 , 42 , 0 , 0 },
76
+ CorrectGradShape : Shape {3 , 2 , 4 },
77
+ },
78
+ {
79
+ selByIndicesTest : selByIndicesTests [3 ],
80
+ CorrectGrad : []int {0 , 2 , 0 , 0 , 0 },
81
+ CorrectGradShape : Shape {5 },
82
+ },
83
+ {
84
+ selByIndicesTest : selByIndicesTests [5 ],
85
+ CorrectGrad : []float32 {4 , 6 , 8 , 12 , 8 , 12 , 0 , 0 },
86
+ CorrectGradShape : Shape {4 , 2 },
87
+ },
88
+ {
89
+ selByIndicesTest : selByIndicesTests [6 ],
90
+ CorrectGrad : []float64 {0 , 10 },
91
+ CorrectGradShape : Shape {2 , 1 },
92
+ },
82
93
}
83
94
84
95
func TestDense_SelectByIndicesB (t * testing.T ) {
85
96
86
- a := New (WithBacking ([]float64 {0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 , 16 , 17 , 18 , 19 , 20 , 21 , 22 , 23 }), WithShape (3 , 2 , 4 ))
87
- indices := New (WithBacking ([]int {1 , 1 }))
88
-
89
- t .Logf ("a\n %v" , a )
90
-
91
- e := StdEng {}
92
-
93
- a1 , err := e .SelectByIndices (a , indices , 1 )
94
- if err != nil {
95
- t .Errorf ("%v" , err )
96
- }
97
- t .Logf ("a1\n %v" , a1 )
98
-
99
- a1Grad , err := e .SelectByIndicesB (a , a1 , indices , 1 )
100
- if err != nil {
101
- t .Errorf ("%v" , err )
102
- }
103
- t .Logf ("a1Grad \n %v" , a1Grad )
104
-
105
- a0 , err := e .SelectByIndices (a , indices , 0 )
106
- if err != nil {
107
- t .Errorf ("%v" , err )
108
- }
109
- t .Logf ("a0\n %v" , a0 )
110
- a0Grad , err := e .SelectByIndicesB (a , a0 , indices , 0 )
111
- if err != nil {
112
- t .Errorf ("%v" , err )
97
+ assert := assert .New (t )
98
+ for i , tc := range selByIndicesBTests {
99
+ T := New (WithShape (tc .Shape ... ), WithBacking (tc .Data ))
100
+ indices := New (WithBacking (tc .Indices ))
101
+ ret , err := ByIndices (T , indices , tc .Axis )
102
+ if checkErr (t , tc .WillErr , err , tc .Name , i ) {
103
+ continue
104
+ }
105
+ grad , err := ByIndicesB (T , ret , indices , tc .Axis )
106
+ if checkErr (t , tc .WillErr , err , tc .Name , i ) {
107
+ continue
108
+ }
109
+ assert .Equal (tc .CorrectGrad , grad .Data (), "%v" , tc .Name )
110
+ assert .True (tc .CorrectGradShape .Eq (grad .Shape ()), "%v - Grad shape should be %v. Got %v instead" , tc .Name , tc .CorrectGradShape , grad .Shape ())
113
111
}
114
- t .Logf ("a0Grad\n %v" , a0Grad )
115
112
116
- a2 , err := e .SelectByIndices (a , indices , 2 )
117
- if err != nil {
118
- t .Errorf ("%v" , err )
119
- }
120
- t .Logf ("\n %v" , a2 )
121
- a2Grad , err := e .SelectByIndicesB (a , a2 , indices , 2 )
122
- if err != nil {
123
- t .Errorf ("%v" , err )
124
- }
125
- t .Logf ("a2Grad\n %v" , a2Grad )
126
113
}
0 commit comments