Skip to content

Commit dd4ab8f

Browse files
authored
Fix#111 (#112)
* Fixed #90 * Fixed #111 * Boyscout commit to fix python/numpy testing
1 parent e3b127e commit dd4ab8f

File tree

5 files changed

+159
-18
lines changed

5 files changed

+159
-18
lines changed

defaultengine_mapreduce.go

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,15 +178,27 @@ func (e StdEng) OptimizedReduce(a Tensor, axis int, firstFn, lastFn, defaultFn,
178178
}
179179

180180
func (e StdEng) Sum(a Tensor, along ...int) (retVal Tensor, err error) {
181-
return e.reduce("Sum", execution.MonotonicSum, execution.SumMethods, a, along...)
181+
a2 := a
182+
if v, ok := a.(View); ok && v.IsMaterializable() {
183+
a2 = v.Materialize()
184+
}
185+
return e.reduce("Sum", execution.MonotonicSum, execution.SumMethods, a2, along...)
182186
}
183187

184188
func (e StdEng) Min(a Tensor, along ...int) (retVal Tensor, err error) {
185-
return e.reduce("Min", execution.MonotonicMin, execution.MinMethods, a, along...)
189+
a2 := a
190+
if v, ok := a.(View); ok && v.IsMaterializable() {
191+
a2 = v.Materialize()
192+
}
193+
return e.reduce("Min", execution.MonotonicMin, execution.MinMethods, a2, along...)
186194
}
187195

188196
func (e StdEng) Max(a Tensor, along ...int) (retVal Tensor, err error) {
189-
return e.reduce("Max", execution.MonotonicMax, execution.MaxMethods, a, along...)
197+
a2 := a
198+
if v, ok := a.(View); ok && v.IsMaterializable() {
199+
a2 = v.Materialize()
200+
}
201+
return e.reduce("Max", execution.MonotonicMax, execution.MaxMethods, a2, along...)
190202
}
191203

192204
func (e StdEng) reduce(

dense_io_test.go

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package tensor
33
import (
44
"bytes"
55
"encoding/gob"
6+
"io/ioutil"
67
"os"
78
"os/exec"
89
"regexp"
@@ -30,6 +31,19 @@ func TestSaveLoadNumpy(t *testing.T) {
3031
T1D.WriteNpy(f1D)
3132
f1D.Close()
3233

34+
defer func() {
35+
// cleanup
36+
err := os.Remove("test.npy")
37+
if err != nil {
38+
t.Error(err)
39+
}
40+
41+
err = os.Remove("test1D.npy")
42+
if err != nil {
43+
t.Error(err)
44+
}
45+
}()
46+
3347
script := "import numpy as np\nx = np.load('test.npy')\nprint(x)\nx = np.load('test1D.npy')\nprint(x)"
3448
// Configurable python command, in order to be able to use python or python3
3549
pythonCommand := os.Getenv("PYTHON_COMMAND")
@@ -42,6 +56,10 @@ func TestSaveLoadNumpy(t *testing.T) {
4256
if err != nil {
4357
t.Error(err)
4458
}
59+
stderr, err := cmd.StderrPipe()
60+
if err != nil {
61+
t.Error(err)
62+
}
4563

4664
go func() {
4765
defer stdin.Close()
@@ -56,29 +74,21 @@ func TestSaveLoadNumpy(t *testing.T) {
5674
t.Logf("Do you have a python with numpy installed? You can change the python interpreter by setting the environment variable PYTHON_COMMAND. Current value: PYTHON_COMMAND=%s", pythonCommand)
5775
}
5876

77+
importError := `ImportError: No module named numpy`
78+
slurpErr, _ := ioutil.ReadAll(stderr)
79+
if ok, _ := regexp.Match(importError, slurpErr); ok {
80+
t.Skipf("Skipping numpy test. It would appear that you do not have Numpy installed.")
81+
}
82+
5983
if err := cmd.Wait(); err != nil {
60-
t.Error(err)
84+
t.Errorf("%q", err.Error())
6185
}
6286

6387
expected := `\[\[\s*1\.\s*5\.\]\n \[\s*10\.\s*-1\.\]\]\n`
6488
if ok, _ := regexp.Match(expected, buf.Bytes()); !ok {
6589
t.Errorf("Did not successfully read numpy file, \n%q\n%q", buf.String(), expected)
6690
}
6791

68-
if buf.String() != expected {
69-
}
70-
71-
// cleanup
72-
err = os.Remove("test.npy")
73-
if err != nil {
74-
t.Error(err)
75-
}
76-
77-
err = os.Remove("test1D.npy")
78-
if err != nil {
79-
t.Error(err)
80-
}
81-
8292
// ok now to test if it can read
8393
T2 := new(Dense)
8494
buf = new(bytes.Buffer)

dense_reduction_test.go

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

example_dense_matop_test.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,68 @@ func ExampleDense_Slice_viewMutation() {
8585
//
8686
}
8787

88+
func ExampleView() {
89+
// Slicing creates a "view" on the original tensor
90+
T := New(WithBacking(Range(Int, 0, 16)), WithShape(4, 4))
91+
fmt.Printf("T:\n%v\n", T)
92+
V, _ := T.Slice(makeRS(1, 3), makeRS(1, 3))
93+
fmt.Printf("V:\n%v\n", V)
94+
95+
// Now we modify V's 0th value
96+
V.(*Dense).Set(0, 1000)
97+
fmt.Printf("V[0] = 1000:\n%v\n", V)
98+
fmt.Printf("T is also mutated:\n%v\n", T)
99+
100+
// Now we materialize the views
101+
fmt.Printf("V is Materializable: %v\n", V.IsMaterializable())
102+
T2 := V.Materialize()
103+
fmt.Printf("T2 == V:\n%v\n", T2)
104+
105+
// Once materialized, it is decoupled from the original tensor
106+
T2.(*Dense).Set(0, 999)
107+
fmt.Printf("T2 is mutated:\n%v\nBut T is not mutated:\n%v\nNeither is V:\n%v", T2, T, V)
108+
// Output:
109+
// T:
110+
// ⎡ 0 1 2 3⎤
111+
// ⎢ 4 5 6 7⎥
112+
// ⎢ 8 9 10 11⎥
113+
// ⎣12 13 14 15⎦
114+
//
115+
// V:
116+
// ⎡ 5 6⎤
117+
// ⎣ 9 10⎦
118+
//
119+
// V[0] = 1000:
120+
// ⎡1000 6⎤
121+
// ⎣ 9 10⎦
122+
//
123+
// T is also mutated:
124+
// ⎡ 0 1 2 3⎤
125+
// ⎢ 4 1000 6 7⎥
126+
// ⎢ 8 9 10 11⎥
127+
// ⎣ 12 13 14 15⎦
128+
//
129+
// V is Materializable: true
130+
// T2 == V:
131+
// ⎡1000 6⎤
132+
// ⎣ 9 10⎦
133+
//
134+
// T2 is mutated:
135+
// ⎡999 6⎤
136+
// ⎣ 9 10⎦
137+
//
138+
// But T is not mutated:
139+
// ⎡ 0 1 2 3⎤
140+
// ⎢ 4 1000 6 7⎥
141+
// ⎢ 8 9 10 11⎥
142+
// ⎣ 12 13 14 15⎦
143+
//
144+
// Neither is V:
145+
// ⎡1000 6⎤
146+
// ⎣ 9 10⎦
147+
148+
}
149+
88150
func ExampleDense_Hstack() {
89151
var T, T1, T2, T3 *Dense
90152
var err error

example_mapreduce_test.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,27 @@ func ExampleSum() {
3131
// Summed along (1, 0): 6
3232
}
3333

34+
func ExampleSum_sliced() {
35+
T := New(WithBacking([]float64{0, 1, 2, 3}), WithShape(2, 2))
36+
fmt.Printf("T:\n%v\n", T)
37+
38+
V, _ := T.Slice(nil, sli(1))
39+
fmt.Printf("V:\n%v\n", V)
40+
41+
Σ, _ := Sum(V)
42+
fmt.Printf("Σ: %v", Σ)
43+
44+
// Output:
45+
// T:
46+
// ⎡0 1⎤
47+
// ⎣2 3⎦
48+
//
49+
// V:
50+
// [1 3]
51+
// Σ: 4
52+
53+
}
54+
3455
func ExampleArgmax() {
3556
T := New(WithBacking([]float64{0, 100, 200, 3}), WithShape(2, 2))
3657
fmt.Printf("T:\n%v\n", T)
@@ -49,6 +70,28 @@ func ExampleArgmax() {
4970
// Argmax is *tensor.Dense of int
5071
}
5172

73+
func ExampleArgmax_sliced() {
74+
T := New(WithBacking([]float64{0, 100, 200, 3}), WithShape(2, 2))
75+
fmt.Printf("T:\n%v\n", T)
76+
77+
// slice creates a view
78+
V, _ := T.Slice(nil, sli(1))
79+
80+
// argmax along the x-axis
81+
am, _ := Argmax(V, 0)
82+
fmt.Printf("Argmax: %v\n", am)
83+
fmt.Printf("Argmax is %T of %v", am, am.Dtype())
84+
85+
// Output:
86+
// T:
87+
// ⎡ 0 100⎤
88+
// ⎣200 3⎦
89+
//
90+
// Argmax: 0
91+
// Argmax is *tensor.Dense of int
92+
93+
}
94+
5295
func ExampleArgmin() {
5396
T := New(WithBacking([]float64{0, 100, 200, 3}), WithShape(2, 2))
5497
fmt.Printf("T:\n%v\n", T)

0 commit comments

Comments
 (0)