2
2
import Base. Broadcast: BroadcastStyle
3
3
using Base. Broadcast: AbstractArrayStyle, Broadcasted, DefaultArrayStyle
4
4
5
+ # combine_sizes moved from StaticArrays after https://github.com/JuliaArrays/StaticArrays.jl/pull/1008
6
+ # see also https://github.com/JuliaArrays/HybridArrays.jl/issues/50
7
+ @generated function combine_sizes (s:: Tuple{Vararg{Size}} )
8
+ sizes = [sz. parameters[1 ] for sz ∈ s. parameters]
9
+ ndims = 0
10
+ for i = 1 : length (sizes)
11
+ ndims = max (ndims, length (sizes[i]))
12
+ end
13
+ newsize = StaticArrays. StaticDimension[Dynamic () for _ = 1 : ndims]
14
+ for i = 1 : length (sizes)
15
+ s = sizes[i]
16
+ for j = 1 : length (s)
17
+ if s[j] isa Dynamic
18
+ continue
19
+ elseif newsize[j] isa Dynamic || newsize[j] == 1
20
+ newsize[j] = s[j]
21
+ elseif newsize[j] ≠ s[j] && s[j] ≠ 1
22
+ throw (DimensionMismatch (" Tried to broadcast on inputs sized $sizes " ))
23
+ end
24
+ end
25
+ end
26
+ quote
27
+ Base. @_inline_meta
28
+ Size ($ (tuple (newsize... )))
29
+ end
30
+ end
31
+
32
+ function broadcasted_index (oldsize, newindex)
33
+ index = ones (Int, length (oldsize))
34
+ for i = 1 : length (oldsize)
35
+ if oldsize[i] != 1
36
+ index[i] = newindex[i]
37
+ end
38
+ end
39
+ return LinearIndices (oldsize)[index... ]
40
+ end
41
+
42
+ scalar_getindex (x) = x
43
+ scalar_getindex (x:: Ref ) = x[]
44
+
5
45
# Add a new BroadcastStyle for StaticArrays, derived from AbstractArrayStyle
6
46
# A constructor that changes the style parameter N (array dimension) is also required
7
47
struct HybridArrayStyle{N} <: AbstractArrayStyle{N} end
@@ -22,7 +62,7 @@ BroadcastStyle(::HybridArray{M}, ::StaticArrays.StaticArrayStyle{0}) where {M} =
22
62
@inline function Base. copy (B:: Broadcasted{HybridArrayStyle{M}} ) where M
23
63
flat = Broadcast. flatten (B); as = flat. args; f = flat. f
24
64
argsizes = StaticArrays. broadcast_sizes (as... )
25
- destsize = StaticArrays . combine_sizes (argsizes)
65
+ destsize = combine_sizes (argsizes)
26
66
if Length (destsize) === Length {StaticArrays.Dynamic()} ()
27
67
# destination dimension cannot be determined statically; fall back to generic broadcast
28
68
return HybridArray {StaticArrays.size_tuple(destsize)} (copy (convert (Broadcasted{DefaultArrayStyle{M}}, B)))
35
75
@inline function _copyto! (dest, B:: Broadcasted{HybridArrayStyle{M}} ) where M
36
76
flat = Broadcast. flatten (B); as = flat. args; f = flat. f
37
77
argsizes = StaticArrays. broadcast_sizes (as... )
38
- destsize = StaticArrays . combine_sizes ((Size (dest), argsizes... ))
78
+ destsize = combine_sizes ((Size (dest), argsizes... ))
39
79
if Length (destsize) === Length {StaticArrays.Dynamic()} ()
40
80
# destination dimension cannot be determined statically; fall back to generic broadcast!
41
81
return copyto! (dest, convert (Broadcasted{DefaultArrayStyle{M}}, B))
68
108
69
109
make_expr (i) = begin
70
110
if ! (a[i] <: AbstractArray )
71
- return :(StaticArrays . scalar_getindex (a[$ i]))
111
+ return :(scalar_getindex (a[$ i]))
72
112
elseif hasdynamic (Tuple{sizes[i]. .. })
73
113
return :(a[$ i][$ (current_ind... )])
74
114
else
75
- :(a[$ i][$ (StaticArrays . broadcasted_index (sizes[i], current_ind))])
115
+ :(a[$ i][$ (broadcasted_index (sizes[i], current_ind))])
76
116
end
77
117
end
78
118
0 commit comments