Skip to content

Commit 603fd0b

Browse files
authored
Merge pull request #51 from GPUEngineering/hf/slice-ptr-matrices
hotfix / include ptrMatrices in matrix-axis slices
2 parents 43da320 + 454b4a5 commit 603fd0b

File tree

4 files changed

+60
-10
lines changed

4 files changed

+60
-10
lines changed

CHANGELOG.md

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,21 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
<!-- ---------------------
9+
v1.3.2
10+
--------------------- -->
11+
## v1.3.2 - 8-11-2024
12+
13+
### Fixed
14+
15+
- When slicing a `DTensor` along `axis=2`, update the pointer to matrices
16+
- We got rid of warning `DTensor<T>::createRandomTensor`
17+
818

919
<!-- ---------------------
1020
v1.3.1
1121
--------------------- -->
12-
## v1.3.1 - 8-11-2024
22+
## v1.3.1 - 7-11-2024
1323

1424
### Fixed
1525

ci/script.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,10 @@ tests() {
3333
if [ -z "${hwInfoOrin}" ]; then
3434

3535
# -- run compute sanitizer
36-
cd ./build/test
36+
pushd ./build/test
3737
mem=$(/usr/local/cuda/bin/compute-sanitizer --tool memcheck --leak-check=full ./device_test)
3838
grep "0 errors" <<< "$mem"
39-
cd ../..
39+
popd
4040

4141
# ------------------------------------
4242
# Run example executable

include/tensor.cuh

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -560,9 +560,8 @@ DTensor<T> DTensor<T>::createRandomTensor(size_t numRows, size_t numCols, size_t
560560
auto randVec = generateIntRandomVector(numRows * numCols * numMats, low, hi);
561561
DTensor<T> a(randVec, numRows, numCols, numMats);
562562
return a;
563-
} else {
564-
throw std::invalid_argument("[createRandomTensor] unsupported type T");
565563
}
564+
throw std::invalid_argument("[createRandomTensor] unsupported type T");
566565
}
567566

568567
template<typename T>
@@ -640,11 +639,13 @@ template<typename T>
640639
DTensor<T>::DTensor(const DTensor<T> &other, size_t axis, size_t from, size_t to) {
641640
if (from > to) throw std::invalid_argument("from > to");
642641
size_t offset = 0, len = to - from + 1;
642+
m_d_ptrMatrices = nullptr;
643643
if (axis == 2) {
644644
offset = other.m_numRows * other.m_numCols * from;
645645
m_numRows = other.m_numRows;
646646
m_numCols = other.m_numCols;
647647
m_numMats = len;
648+
m_d_ptrMatrices = other.m_d_ptrMatrices + from;
648649
} else if (axis == 1) {
649650
offset = other.m_numRows * from;
650651
m_numRows = other.m_numRows;
@@ -659,10 +660,6 @@ DTensor<T>::DTensor(const DTensor<T> &other, size_t axis, size_t from, size_t to
659660
m_d_data = other.m_d_data + offset;
660661
m_doDestroyData = false;
661662
m_doDestroyPtrMatrices = false;
662-
if (axis != 2) {
663-
// m_d_ptrMatrices is not needed for vectors and matrices
664-
m_d_ptrMatrices = nullptr;
665-
}
666663
}
667664

668665
template<typename T>

test/testTensor.cu

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ void tensorSlicingConstructorAxis2() {
194194
EXPECT_EQ(3, tensSlice.numCols());
195195
EXPECT_EQ(2, tensSlice.numMats());
196196
EXPECT_EQ(tens.raw(), tensSlice.raw()); // it is indeed a slice
197+
EXPECT_TRUE(tensSlice.ptrMatrices() != nullptr);
197198
}
198199

199200
TEST_F(TensorTest, tensorSlicingConstructorAxis2) {
@@ -215,6 +216,7 @@ void tensorSlicingConstructorAxis1() {
215216
EXPECT_EQ(2, tenzSlice.numRows());
216217
EXPECT_EQ(2, tenzSlice.numCols());
217218
EXPECT_EQ(1, tenzSlice.numMats());
219+
EXPECT_TRUE(tenzSlice.ptrMatrices() == nullptr);
218220
std::vector<T> expected = {3, 4, 5, 6};
219221
std::vector<T> tenzSliceDown(4);
220222
tenzSlice.download(tenzSliceDown);
@@ -229,7 +231,7 @@ TEST_F(TensorTest, tensorSlicingConstructorAxis1) {
229231

230232
/* ---------------------------------------
231233
* Tensor: Slicing constructor
232-
* axis = 0 (columns)
234+
* axis = 0 (rows)
233235
* --------------------------------------- */
234236

235237
TEMPLATE_WITH_TYPE_T
@@ -240,6 +242,7 @@ void tensorSlicingConstructorAxis0() {
240242
EXPECT_EQ(2, tenzSlice.numRows());
241243
EXPECT_EQ(1, tenzSlice.numCols());
242244
EXPECT_EQ(1, tenzSlice.numMats());
245+
EXPECT_TRUE(tenzSlice.ptrMatrices() == nullptr);
243246
std::vector<T> expected = {3, 4};
244247
std::vector<T> tenzSliceDown(2);
245248
tenzSlice.download(tenzSliceDown);
@@ -738,6 +741,46 @@ TEST_F(TensorTest, tensorAddAB) {
738741
tensorAddAB<float>();
739742
}
740743

744+
/* ---------------------------------------
745+
* Tensor: slice ptrMatrices
746+
* axis = 2 (matrices)
747+
* --------------------------------------- */
748+
749+
TEMPLATE_WITH_TYPE_T
750+
void tensorSliceAxis2PtrMatrices() {
751+
std::vector<T> dataA = TENSOR_DATA_234A;
752+
DTensor<T> d_A(dataA, 2, 3, 4);
753+
DTensor<T> d_ASlice(d_A, 2, 2, 3);
754+
EXPECT_TRUE(d_ASlice.ptrMatrices() == d_A.ptrMatrices() + 2);
755+
}
756+
757+
TEST_F(TensorTest, tensorSliceAxis2PtrMatrices) {
758+
tensorSliceAxis2PtrMatrices<float>();
759+
tensorSliceAxis2PtrMatrices<double>();
760+
tensorSliceAxis2PtrMatrices<int>();
761+
}
762+
763+
/* ---------------------------------------
764+
* Tensor: slice ptrMatrices
765+
* axis = 0 and 1
766+
* --------------------------------------- */
767+
768+
TEMPLATE_WITH_TYPE_T
769+
void tensorSliceAxis01PtrMatrices() {
770+
std::vector<T> dataA = TENSOR_DATA_234A;
771+
DTensor<T> d_A(dataA, 2, 3, 4);
772+
DTensor<T> d_ASlice0(d_A, 0, 0, 1);
773+
EXPECT_TRUE(!d_ASlice0.ptrMatrices());
774+
DTensor<T> d_ASlice1(d_A, 1, 0, 2);
775+
EXPECT_TRUE(!d_ASlice0.ptrMatrices());
776+
}
777+
778+
TEST_F(TensorTest, tensorSliceAxis01PtrMatrices) {
779+
tensorSliceAxis01PtrMatrices<float>();
780+
tensorSliceAxis01PtrMatrices<double>();
781+
tensorSliceAxis01PtrMatrices<int>();
782+
}
783+
741784
/* ---------------------------------------
742785
* Tensor: getRows
743786
* --------------------------------------- */

0 commit comments

Comments
 (0)