Skip to content

Commit 4198ecf

Browse files
committed
Better memory management in DTensor<T>::allocateOnDevice
Free all allocated memory if allocation fails. Fix code formatting in testTensor. allocateOnDevice made void.
1 parent 75576ac commit 4198ecf

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

include/tensor.cuh

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -202,9 +202,8 @@ private:
202202
* Allocate `size` number of `T` data on the device.
203203
* @param size number of data elements to allocate
204204
* @param zero sets allocated data to `0`
205-
* @return
206205
*/
207-
bool allocateOnDevice(size_t size, bool zero = false);
206+
void allocateOnDevice(size_t size, bool zero = false);
208207

209208
/**
210209
* Create column-major `std::vector` from a row-major one.
@@ -837,23 +836,24 @@ void DTensor<T>::applyLeftGivensRotation(size_t i, size_t j, const T *c, const T
837836
}
838837

839838
template<typename T>
840-
inline bool DTensor<T>::allocateOnDevice(size_t size, bool zero) {
841-
if (size <= 0) return false;
839+
inline void DTensor<T>::allocateOnDevice(size_t size, bool zero) {
840+
cudaError_t cudaStatus;
841+
if (size <= 0) return;
842842
destroy();
843843
m_doDestroyData = true;
844844
size_t buffer_size = size * sizeof(T);
845-
bool cudaStatus = cudaMalloc(&m_d_data, buffer_size);
846-
if (cudaStatus != cudaSuccess) return false;
845+
gpuErrChk(cudaMalloc(&m_d_data, buffer_size));
847846
if (zero) gpuErrChk(cudaMemset(m_d_data, 0, buffer_size)); // set to zero all elements
848847

849848
if (numMats() > 1) {
850849
m_doDestroyPtrMatrices = true;
851850
cudaStatus = cudaMalloc(&m_d_ptrMatrices, numMats() * sizeof(T *));
851+
if (cudaStatus != cudaSuccess) {
852+
gpuErrChk(cudaFree(m_d_data));
853+
}
852854
} else {
853855
m_doDestroyPtrMatrices = false;
854856
}
855-
856-
return (cudaStatus != cudaSuccess);
857857
}
858858

859859
template<typename T>

test/testTensor.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ TEMPLATE_WITH_TYPE_T
123123
void tensorMoveConstructor() {
124124
DTensor<T> zero(2, 3, 4, true);
125125
DTensor<T> x(std::move(zero));
126-
DTensor<T> y(DTensor < T > {100, 10, 1000});
126+
DTensor<T> y(DTensor<T> {100, 10, 1000});
127127
}
128128

129129
TEST_F(TensorTest, tensorMoveConstructor) {

0 commit comments

Comments
 (0)