Skip to content

Commit 746a09a

Browse files
committed
Many bug fixes, particularly with CUDA
Former-commit-id: da8fd10 [formerly 77e5ca1] Former-commit-id: 982b1a6
1 parent c4bed0b commit 746a09a

File tree

4 files changed

+94
-28
lines changed

4 files changed

+94
-28
lines changed

src/librapid/VERSION.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
#ifndef LIBRAPID_VERSION
2-
#define LIBRAPID_VERSION "0.5.6"
2+
#define LIBRAPID_VERSION "0.5.7"
33
#endif

src/librapid/array/arrayBase.hpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -147,11 +147,11 @@ namespace librapid {
147147
template<typename Derived, typename Device>
148148
class ArrayBase {
149149
public:
150-
using Scalar = typename internal::traits<Derived>::Scalar;
151-
using BaseScalar = typename internal::traits<Scalar>::BaseScalar;
152-
using This = ArrayBase<Derived, Device>;
153-
using Packet = typename internal::traits<Derived>::Packet;
154-
using StorageType = typename internal::traits<Derived>::StorageType;
150+
using Scalar = typename internal::traits<Derived>::Scalar;
151+
using BaseScalar = typename internal::traits<Scalar>::BaseScalar;
152+
using This = ArrayBase<Derived, Device>;
153+
using Packet = typename internal::traits<Derived>::Packet;
154+
using StorageType = typename internal::traits<Derived>::StorageType;
155155
static constexpr ui64 Flags = internal::traits<This>::Flags;
156156

157157
friend Derived;
@@ -265,8 +265,6 @@ void castKernel({1} *dst, {2} *src, i64 size) {{
265265
size /= sizeof(BaseScalar) * 8;
266266
}
267267

268-
fmt::print("Information: {}\n", typeid(BaseScalar).name());
269-
270268
memory::memcpy<BaseScalar, D, BaseScalar, Device>(
271269
res.storage().heap(), eval().storage().heap(), size);
272270
return res;
@@ -462,7 +460,7 @@ void castKernel({1} *dst, {2} *src, i64 size) {{
462460
}
463461

464462
LR_NODISCARD("Do not ignore the result of an evaluated calculation")
465-
auto eval() const { return derived(); }
463+
auto eval() const { return derived().eval(); }
466464

467465
template<typename OtherDerived>
468466
LR_FORCE_INLINE void loadFrom(i64 index, const OtherDerived &other) {
@@ -530,6 +528,12 @@ void castKernel({1} *dst, {2} *src, i64 size) {{
530528
m_extent.str(),
531529
other.extent().str());
532530

531+
// If device differs, we need to copy the data
532+
if constexpr (!std::is_same_v<Device,
533+
typename internal::traits<OtherDerived>::Device>) {
534+
return assignLazy(other.move<Device>());
535+
}
536+
533537
using Selector = functors::AssignOp<Derived, OtherDerived>;
534538
Selector::run(derived(), other.derived());
535539
return derived();

src/librapid/cuda/memUtils.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ To IGNORE this error, just define LIBRAPID_NO_THREAD_CHECK above LibRapid includ
9898
cudaMemcpyAsync(dst, src, sizeof(T) * size, cudaMemcpyDeviceToHost, cudaStream));
9999
} else if constexpr (std::is_same_v<d, device::GPU> &&
100100
std::is_same_v<d_, device::CPU>) {
101+
// fmt::print("Info: {} {} {} {}\n", (void *) dst[0], (void *) dst[1], (void *)src[0], (void *)src[1]);
102+
101103
// Host to Device
102104
cudaSafeCall(
103105
cudaMemcpyAsync(dst, src, sizeof(T) * size, cudaMemcpyHostToDevice, cudaStream));

src/librapid/math/vector.hpp

Lines changed: 79 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -82,20 +82,71 @@ namespace librapid {
8282
}
8383

8484
LR_FORCE_INLINE
85-
VecImpl operator>(const VecImpl &other) {
85+
VecImpl cmp(const VecImpl &other, const char *mode) const {
86+
// Mode:
87+
// 0: ==
88+
// 1: !=
89+
// 2: <
90+
// 3: <=
91+
// 4: >
92+
// 5: >=
93+
8694
VecImpl res(*this);
95+
i16 modeInt = *(i16 *)mode;
96+
fmt::print("Info: {:Lb}\n", modeInt);
97+
fmt::print("Info: {:Lb}\n", ('g' << 8) | 't');
8798
for (i64 i = 0; i < Dims; ++i) {
88-
if (res[i] > other[i]) {
89-
res[i] = 1;
90-
} else {
91-
res[i] = 0;
99+
switch (modeInt) {
100+
case 'e' | ('q' << 8):
101+
if (res[i] == other[i]) {
102+
res[i] = 1;
103+
} else {
104+
res[i] = 0;
105+
}
106+
break;
107+
case 'n' | ('e' << 8):
108+
if (res[i] != other[i]) {
109+
res[i] = 1;
110+
} else {
111+
res[i] = 0;
112+
}
113+
break;
114+
case 'l' | ('t' << 8):
115+
if (res[i] < other[i]) {
116+
res[i] = 1;
117+
} else {
118+
res[i] = 0;
119+
}
120+
break;
121+
case 'l' | ('e' << 8):
122+
if (res[i] <= other[i]) {
123+
res[i] = 1;
124+
} else {
125+
res[i] = 0;
126+
}
127+
break;
128+
case 'g' | ('t' << 8):
129+
if (res[i] > other[i]) {
130+
res[i] = 1;
131+
} else {
132+
res[i] = 0;
133+
}
134+
break;
135+
case 'g' | ('e' << 8):
136+
if (res[i] >= other[i]) {
137+
res[i] = 1;
138+
} else {
139+
res[i] = 0;
140+
}
141+
break;
142+
default: LR_ASSERT(false, "Invalid mode {}", mode);
92143
}
93144
}
94145
return res;
95146
}
96147

97148
LR_FORCE_INLINE
98-
VecImpl cmp(const VecImpl &other, char mode[2]) {
149+
VecImpl cmp(const Scalar &value, const char *mode) const {
99150
// Mode:
100151
// 0: ==
101152
// 1: !=
@@ -105,46 +156,48 @@ namespace librapid {
105156
// 5: >=
106157

107158
VecImpl res(*this);
108-
i16 modeInt = (mode[1] << 8) | mode[0];
159+
i16 modeInt = *(i16 *)mode;
160+
fmt::print("Info: {:Lb}\n", modeInt);
161+
fmt::print("Info: {:Lb}\n", ('g' << 8) | 't');
109162
for (i64 i = 0; i < Dims; ++i) {
110163
switch (modeInt) {
111-
case ('e' << 8) | 'q':
112-
if (res[i] == other[i]) {
164+
case 'e' | ('q' << 8):
165+
if (res[i] == value) {
113166
res[i] = 1;
114167
} else {
115168
res[i] = 0;
116169
}
117170
break;
118-
case ('n' << 8) | 'e':
119-
if (res[i] != other[i]) {
171+
case 'n' | ('e' << 8):
172+
if (res[i] != value) {
120173
res[i] = 1;
121174
} else {
122175
res[i] = 0;
123176
}
124177
break;
125-
case ('l' << 8) | 't':
126-
if (res[i] < other[i]) {
178+
case 'l' | ('t' << 8):
179+
if (res[i] < value) {
127180
res[i] = 1;
128181
} else {
129182
res[i] = 0;
130183
}
131184
break;
132-
case ('l' << 8) | 'e':
133-
if (res[i] <= other[i]) {
185+
case 'l' | ('e' << 8):
186+
if (res[i] <= value) {
134187
res[i] = 1;
135188
} else {
136189
res[i] = 0;
137190
}
138191
break;
139-
case ('g' << 8) | 't':
140-
if (res[i] > other[i]) {
192+
case 'g' | ('t' << 8):
193+
if (res[i] > value) {
141194
res[i] = 1;
142195
} else {
143196
res[i] = 0;
144197
}
145198
break;
146-
case ('g' << 8) | 'e':
147-
if (res[i] >= other[i]) {
199+
case 'g' | ('e' << 8):
200+
if (res[i] >= value) {
148201
res[i] = 1;
149202
} else {
150203
res[i] = 0;
@@ -163,6 +216,13 @@ namespace librapid {
163216
LR_FORCE_INLINE VecImpl operator==(const VecImpl &other) const { return cmp(other, "eq"); }
164217
LR_FORCE_INLINE VecImpl operator!=(const VecImpl &other) const { return cmp(other, "ne"); }
165218

219+
LR_FORCE_INLINE VecImpl operator<(const Scalar &other) const { return cmp(other, "lt"); }
220+
LR_FORCE_INLINE VecImpl operator<=(const Scalar &other) const { return cmp(other, "le"); }
221+
LR_FORCE_INLINE VecImpl operator>(const Scalar &other) const { return cmp(other, "gt"); }
222+
LR_FORCE_INLINE VecImpl operator>=(const Scalar &other) const { return cmp(other, "ge"); }
223+
LR_FORCE_INLINE VecImpl operator==(const Scalar &other) const { return cmp(other, "eq"); }
224+
LR_FORCE_INLINE VecImpl operator!=(const Scalar &other) const { return cmp(other, "ne"); }
225+
166226
LR_NODISCARD("") LR_INLINE Scalar mag2() const { return (m_data * m_data).sum(); }
167227
LR_NODISCARD("") LR_INLINE Scalar mag() const { return ::librapid::sqrt(mag2()); }
168228
LR_NODISCARD("") LR_INLINE Scalar invMag() const { return 1.0 / mag(); }

0 commit comments

Comments
 (0)