Skip to content

Commit 0071a04

Browse files
authored
Clarify Semantics of how Shape and Data works. (#97)
* Fixed #90 * Starting to clarify some semantic * With the semantics clarified, the consopts need to change a bit too * Updated the semantics to make it more clear * Added an example to Dense.Data() to clarify the semantics. Added tests for certain consopts that may be breaking -race * Added mmap example for FromMemory * Fixed ap.T and clarified shapes better in ap.go Added an example for T * Fixes SelectByIndices * Unnecessary checks for 0 removed, given that ProdInts have changed in function
1 parent a47ba92 commit 0071a04

11 files changed

+258
-40
lines changed

ap.go

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,12 @@ import (
88

99
// An AP is an access pattern. It tells the various ndarrays how to access their data through the use of strides
1010
// Through the AP, there are several definitions of things, most notably there are two very specific "special cases":
11-
// Scalar has Dims() of 0. However, its shape can take several forms:
12-
// - (1, 1)
11+
// Scalar has Dims() of 0.
1312
// - (1)
13+
// Scalarlikes are higher order tensors, but each with a size of 1. The Dims() are not 0.
14+
// - (1, 1)
15+
// - (1, 1, 1)
16+
// - (1, 1, 1, 1), etc
1417
// Vector has Dims() of 1, but its shape can take several forms:
1518
// - (x, 1)
1619
// - (1, x)
@@ -121,9 +124,12 @@ func (ap *AP) IsColVec() bool { return ap.shape.IsColVec() }
121124
// IsRowVec returns true when the access pattern has the shape (1, x)
122125
func (ap *AP) IsRowVec() bool { return ap.shape.IsRowVec() }
123126

124-
// IsScalar returns true if the access pattern indicates it's a scalar value
127+
// IsScalar returns true if the access pattern indicates it's a scalar value.
125128
func (ap *AP) IsScalar() bool { return ap.shape.IsScalar() }
126129

130+
// IsScalarEquiv returns true if the access pattern is equivalent to a scalar shape.
131+
func (ap *AP) IsScalarEquiv() bool { return ap.shape.IsScalarEquiv() }
132+
127133
// IsMatrix returns true if it's a matrix. This is mostly a convenience method. RowVec and ColVecs are also considered matrices
128134
func (ap *AP) IsMatrix() bool { return len(ap.shape) == 2 }
129135

@@ -297,6 +303,7 @@ func (ap *AP) S(size int, slices ...Slice) (newAP AP, ndStart, ndEnd int, err er
297303

298304
// T returns the transposed metadata based on the given input
299305
func (ap *AP) T(axes ...int) (retVal AP, a []int, err error) {
306+
300307
// prep axes
301308
if len(axes) > 0 && len(axes) != ap.Dims() {
302309
err = errors.Errorf(dimMismatch, ap.Dims(), len(axes))
@@ -312,6 +319,10 @@ func (ap *AP) T(axes ...int) (retVal AP, a []int, err error) {
312319
}
313320
a = axes
314321

322+
if ap.shape.IsScalarEquiv() {
323+
return ap.Clone(), a, noopError{}
324+
}
325+
315326
// if axes is 0, 1, 2, 3... then no op
316327
if monotonic, incr1 := IsMonotonicInts(axes); monotonic && incr1 && axes[0] == 0 {
317328
return ap.Clone(), a, noopError{}

ap_test.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,14 @@ func TestAccessPatternIsX(t *testing.T) {
112112

113113
ap = dummyScalar1()
114114
assert.True(ap.IsScalar())
115+
assert.True(ap.IsScalarEquiv())
115116
assert.False(ap.IsVector())
116117
assert.False(ap.IsColVec())
117118
assert.False(ap.IsRowVec())
118119

119120
ap = dummyScalar2()
120-
assert.True(ap.IsScalar())
121+
assert.False(ap.IsScalar())
122+
assert.True(ap.IsScalarEquiv())
121123
assert.False(ap.IsVector())
122124
assert.False(ap.IsColVec())
123125
assert.False(ap.IsRowVec())

consopt.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,17 @@ func FromScalar(x interface{}, argMask ...[]bool) ConsOpt {
112112
xvi.Set(reflect.ValueOf(x))
113113
uptr := unsafe.Pointer(xv.Pointer())
114114

115+
var v interface{}
116+
if !tt.Shape().IsScalar() {
117+
sl := reflect.MakeSlice(reflect.SliceOf(xt), 1, 1)
118+
zeroth := sl.Index(0)
119+
zeroth.Set(reflect.ValueOf(x))
120+
v = sl.Interface()
121+
}
115122
tt.array.Ptr = uptr
116123
tt.array.L = 1
117124
tt.array.C = 1
118-
tt.v = x
125+
tt.v = v
119126
tt.t = Dtype{xt}
120127
tt.mask = mask
121128

@@ -146,7 +153,6 @@ func FromMemory(ptr uintptr, memsize uintptr) ConsOpt {
146153
switch tt := t.(type) {
147154
case *Dense:
148155
tt.v = nil // if there were any underlying slices it should be GC'd
149-
150156
tt.array.Ptr = unsafe.Pointer(ptr)
151157
tt.array.L = int(memsize / tt.t.Size())
152158
tt.array.C = int(memsize / tt.t.Size())

consopt_test.go

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
package tensor
2+
3+
import (
4+
"fmt"
5+
"io/ioutil"
6+
"os"
7+
"syscall"
8+
"testing"
9+
"testing/quick"
10+
"unsafe"
11+
12+
"github.com/stretchr/testify/assert"
13+
)
14+
15+
type F64 float64
16+
17+
func newF64(f float64) *F64 { r := F64(f); return &r }
18+
19+
func (f *F64) Uintptr() uintptr { return uintptr(unsafe.Pointer(f)) }
20+
21+
func (f *F64) MemSize() uintptr { return 8 }
22+
23+
func (f *F64) Pointer() unsafe.Pointer { return unsafe.Pointer(f) }
24+
25+
func Test_FromMemory(t *testing.T) {
26+
fn := func(F float64) bool {
27+
f := newF64(F)
28+
T := New(WithShape(), Of(Float64), FromMemory(f.Uintptr(), f.MemSize()))
29+
data := T.Data().(float64)
30+
31+
if data != F {
32+
return false
33+
}
34+
return true
35+
}
36+
if err := quick.Check(fn, &quick.Config{MaxCount: 1000000}); err != nil {
37+
t.Logf("%v", err)
38+
}
39+
40+
f, err := ioutil.TempFile("", "test")
41+
if err != nil {
42+
t.Fatal(err)
43+
}
44+
// fill in with fake data
45+
backing := make([]byte, 8*1024*1024) // 1024*1024 matrix of float64
46+
asFloats := *(*[]float64)(unsafe.Pointer(&backing))
47+
asFloats = asFloats[: 1024*1024 : 1024*1024]
48+
asFloats[0] = 3.14
49+
asFloats[2] = 6.28
50+
asFloats[1024*1024-1] = 3.14
51+
asFloats[1024*1024-3] = 6.28
52+
f.Write(backing)
53+
54+
// defer cleanup
55+
defer os.Remove(f.Name())
56+
57+
// do the mmap stuff
58+
stat, err := f.Stat()
59+
if err != nil {
60+
t.Fatal(err)
61+
}
62+
63+
size := int(stat.Size())
64+
fd := int(f.Fd())
65+
bs, err := syscall.Mmap(fd, 0, size, syscall.PROT_READ, syscall.MAP_SHARED)
66+
if err != nil {
67+
t.Fatal(err)
68+
}
69+
defer func() {
70+
if err := syscall.Munmap(bs); err != nil {
71+
t.Error(err)
72+
}
73+
}()
74+
T := New(WithShape(1024, 1024), Of(Float64), FromMemory(uintptr(unsafe.Pointer(&bs[0])), uintptr(size)))
75+
76+
s := fmt.Sprintf("%v", T)
77+
expected := `⎡3.14 0 6.28 0 ... 0 0 0 0⎤
78+
⎢ 0 0 0 0 ... 0 0 0 0⎥
79+
⎢ 0 0 0 0 ... 0 0 0 0⎥
80+
⎢ 0 0 0 0 ... 0 0 0 0⎥
81+
.
82+
.
83+
.
84+
⎢ 0 0 0 0 ... 0 0 0 0⎥
85+
⎢ 0 0 0 0 ... 0 0 0 0⎥
86+
⎢ 0 0 0 0 ... 0 0 0 0⎥
87+
⎣ 0 0 0 0 ... 0 6.28 0 3.14⎦
88+
`
89+
if s != expected {
90+
t.Errorf("Expected mmap'd tensor to be exactly the same.")
91+
}
92+
93+
assert.True(t, T.IsManuallyManaged())
94+
}

defaultengine_matop_misc.go

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,6 @@ func (StdEng) denseRepeat(t, reuse DenseTensor, newShape Shape, axis, size int,
7272
outers = 1
7373
} else {
7474
outers = ProdInts(t.Shape()[0:axis])
75-
if outers == 0 {
76-
outers = 1
77-
}
7875
}
7976

8077
var stride, newStride int

defaultengine_selbyidx.go

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,8 @@ func (e StdEng) selectByIdx(axis int, indices []int, typ reflect.Type, dataA, da
7575
axStride := apA.strides[axis]
7676
retStride := apRet.strides[axis]
7777
var outerRetStride int
78-
if outer == 0 {
78+
if axis == 0 {
7979
// then it's the outermost
80-
outer = 1
8180
outerRetStride = apRet.strides[axis] * 2
8281
} else {
8382
outerRetStride = apRet.strides[axis-1]
@@ -185,9 +184,7 @@ func (e StdEng) selectByIndicesB(axis int, indices []int, typ reflect.Type, data
185184
axStride := apB.strides[axis]
186185
retStride := apRet.strides[axis]
187186
var outerRetStride int
188-
if outer == 0 {
189-
// then it's the outermost
190-
outer = 1
187+
if axis == 0 {
191188
outerRetStride = apRet.strides[axis] * 2
192189
} else {
193190
outerRetStride = apRet.strides[axis-1]

example_dense_basics_test.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package tensor
2+
3+
import (
4+
"fmt"
5+
)
6+
7+
// Data shows how the shape of the *Dense actually affects the return value of .Data().
8+
func ExampleDense_Data() {
9+
T := New(WithShape(2, 2), WithBacking([]float64{1, 2, 3, 4}))
10+
fmt.Printf("Basics:\n======\nAny kind of arrays: %v\n", T.Data())
11+
12+
fmt.Printf("\nScalar-like\n===========\n")
13+
T = New(WithShape(), FromScalar(3.14))
14+
fmt.Printf("WithShape(), FromScalar: %v\n", T.Data())
15+
16+
T = New(WithShape(), WithBacking([]float64{3.14}))
17+
fmt.Printf("WithShape(), With a slice of 1 as backing: %v\n", T.Data())
18+
19+
T = New(WithShape(1), FromScalar(3.14))
20+
fmt.Printf("WithShape(1), With an initial scalar: %v\n", T.Data())
21+
22+
T = New(WithShape(1, 1), WithBacking([]float64{3.14}))
23+
fmt.Printf("WithShape(1, 1), With an initial scalar: %v\n", T.Data())
24+
25+
T = New(WithShape(1, 1), FromScalar(3.14))
26+
fmt.Printf("WithShape(1, 1), With an initial scalar: %v\n", T.Data())
27+
28+
T.Reshape()
29+
fmt.Printf("After reshaping to (): %v\n", T.Data())
30+
31+
// Output:
32+
// Basics:
33+
// ======
34+
// Any kind of arrays: [1 2 3 4]
35+
//
36+
// Scalar-like
37+
// ===========
38+
// WithShape(), FromScalar: 3.14
39+
// WithShape(), With a slice of 1 as backing: 3.14
40+
// WithShape(1), With an initial scalar: [3.14]
41+
// WithShape(1, 1), With an initial scalar: [3.14]
42+
// WithShape(1, 1), With an initial scalar: [3.14]
43+
// After reshaping to (): 3.14
44+
45+
}

example_dense_matop_test.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,3 +289,80 @@ func ExampleRepeat_uncommonUses() {
289289
// Once again, observe that the 1st element ([2 5]) has been repeated 3 times, while the rest have been repeated twice
290290

291291
}
292+
293+
func ExampleT() {
294+
// Usual example of 2D matrix being transposed:
295+
M := New(WithBacking([]int{1, 2, 3, 4, 5, 6}), WithShape(2, 3))
296+
M2, err := T(M)
297+
if err != nil {
298+
fmt.Printf("Err: %v\n", err)
299+
}
300+
fmt.Printf("M:\n%v\nM2:\n%v\n", M, M2)
301+
302+
// T accepts optional parameters describing the permutation of axes.
303+
// In a 2D case, there are only two options: (0, 1) or (1, 0).
304+
// The latter is default if no parameters are passed in.
305+
// The former is a no-op as rearranging a matrix so that the 0th axis becomes the 0th axis
306+
// and the first axis becomes the first axis is not going to do anything.
307+
//
308+
// However, note that M3 is a different result.
309+
M3, err := T(M, 0, 1)
310+
if err != nil {
311+
fmt.Printf("Err: %v\n", err)
312+
}
313+
fmt.Printf("M3:\n%v\nM == M3: %t", M3, M == M3)
314+
315+
// Output:
316+
// M:
317+
// ⎡1 2 3⎤
318+
// ⎣4 5 6⎦
319+
//
320+
// M2:
321+
// ⎡1 4⎤
322+
// ⎢2 5⎥
323+
// ⎣3 6⎦
324+
//
325+
// M3:
326+
// ⎡1 2 3⎤
327+
// ⎣4 5 6⎦
328+
//
329+
// M == M3: false
330+
331+
}
332+
333+
func ExampleT_scalarlike() {
334+
// Be aware when dealing with scalarlike tensors
335+
// scalar/scalarlikes have no effect when calling T()
336+
// but the result is put into a new tensor
337+
S := New(WithBacking([]float32{3.14}), WithShape())
338+
S2, err := T(S)
339+
if err != nil {
340+
fmt.Printf("Err %v", err)
341+
}
342+
fmt.Printf("S: %v S2 %v S == S2: %t\n", S, S2, S == S2)
343+
344+
// however do note that scalars and scalarlikes are not the same thing.
345+
// for example, consider this:
346+
_, err = T(S, 1, 0)
347+
fmt.Printf("error when the axes are more than the shape's dims: %v\n", err)
348+
349+
// but if you have a tensor that is a scalar-like:
350+
S.Reshape(1, 1)
351+
S2, err = T(S, 1, 0)
352+
if err != nil {
353+
fmt.Printf("Err: %v\n", err)
354+
}
355+
fmt.Printf("S:\n%v\nS2:\n%v\nS == S2: %t\n", S, S2, S == S2)
356+
357+
// Output:
358+
// S: 3.14 S2 3.14 S == S2: false
359+
// error when the axes are more than the shape's dims: Dimension mismatch. Expected 0, got 2
360+
// S:
361+
// ⎡3.14⎤
362+
//
363+
// S2:
364+
// ⎡3.14⎤
365+
//
366+
// S == S2: false
367+
368+
}

shape.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ func (s Shape) TotalSize() int {
2626

2727
// CalcStrides calculates the default strides for a shape
2828
func (s Shape) CalcStrides() []int {
29-
if s.IsScalar() {
29+
if s.IsScalarEquiv() {
3030
return nil
3131
}
3232

@@ -52,7 +52,7 @@ func (s Shape) CalcStrides() []int {
5252
// CalcStridesWithMask is similar to CalcStrides, except that it has an argument, masks. It is used to mask out given dimensions
5353
// during calculation of stride
5454
func (s Shape) CalcStridesWithMask(mask []bool) []int {
55-
if s.IsScalar() {
55+
if s.IsScalarEquiv() {
5656
return nil
5757
}
5858

@@ -87,7 +87,7 @@ func (s Shape) CalcStridesWithMask(mask []bool) []int {
8787

8888
// CalcStridesColMajor is like CalcStrides, but assumes a col major layout
8989
func (s Shape) CalcStridesColMajor() []int {
90-
if s.IsScalar() {
90+
if s.IsScalarEquiv() {
9191
return nil
9292
}
9393

@@ -155,7 +155,7 @@ func (s Shape) Clone() Shape {
155155

156156
// IsScalar returns true if the access pattern indicates it's a scalar value
157157
func (s Shape) IsScalar() bool {
158-
return len(s) == 0 || (len(s) == 1 && s[0] == 1)
158+
return len(s) == 0
159159
}
160160

161161
// IsScalarEquiv returns true if the access pattern indicates it's a scalar-like value

0 commit comments

Comments
 (0)