Commit 650721c

add global clear_workspaces
1 parent 04ff4c4 commit 650721c

5 files changed: +26 -4 lines changed

example/test_linear.py

Lines changed: 2 additions & 0 deletions
@@ -116,3 +116,5 @@ def dequantize(qweight, qzeros, scales, group_size: int = 128):
 print(f'abs_diff {abs_diff:4f}, '
       f'rel_diff {rel_diff:4f}, '
       f'outliers {outliers:4f}')
+
+tm.Linear.clear_workspaces()

src/turbomind/api/python/bind.cpp

Lines changed: 2 additions & 1 deletion
@@ -342,5 +342,6 @@ PYBIND11_MODULE(_turbomind_ext, m)
             auto _out = TorchTensorToTurbomindTensor(out);
             auto stream = reinterpret_cast<cudaStream_t>(stream_id);
             return self->forward(*_in, *_out, stream);
-        });
+        })
+        .def_static("clear_workspaces", &turbomind::Linear::clearWorkspaces);
 }
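Note: the semicolon after the forward lambda's closing "})" moves down one line so the new .def_static call can be chained onto the same binding expression, which exposes clearWorkspaces to Python without requiring an instance. A minimal, self-contained pybind11 sketch of the same def_static pattern (the module and class names below are placeholders, not turbomind's):

// def_static_sketch.cpp -- illustrative only; "Cache" and "example_ext" are
// placeholder names, not part of turbomind.
#include <pybind11/pybind11.h>

namespace py = pybind11;

struct Cache {
    static void clear() { /* drop any statically cached resources */ }
};

PYBIND11_MODULE(example_ext, m)
{
    py::class_<Cache>(m, "Cache")
        .def(py::init<>())
        // def_static binds the function as a Python staticmethod, so it is
        // callable as example_ext.Cache.clear() with no instance.
        .def_static("clear", &Cache::clear);
}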

src/turbomind/api/python/linear.cc

Lines changed: 16 additions & 3 deletions
@@ -283,6 +283,10 @@ struct Linear::Impl {
         }
     }

+    static void clearWorkspaces() {
+        workspace_cache_.clear();
+    }
+
 private:
     static gemm::Workspace& getWorkspace(int device_id, cudaStream_t stream)
     {
@@ -295,9 +299,13 @@ struct Linear::Impl {
         }

         // create a new workspace if cache missed
-        auto workspace = std::shared_ptr<gemm::Workspace>(new gemm::Workspace, [](gemm::Workspace* p) {
-            cudaFreeAsync(p->barriers, 0);
-            cudaFreeAsync(p->partials, 0);
+        auto workspace = std::shared_ptr<gemm::Workspace>(new gemm::Workspace, [device_id](gemm::Workspace* p) {
+            int old{};
+            check_cuda_error(cudaGetDevice(&old));
+            check_cuda_error(cudaSetDevice(device_id));
+            check_cuda_error(cudaFree(p->barriers));
+            check_cuda_error(cudaFree(p->partials));
+            check_cuda_error(cudaSetDevice(old));
         });

         workspace->barriers_size = gemm::Gemm::kBarriersSize;
@@ -349,4 +357,9 @@ void Linear::forward(const Tensor& in, Tensor& out, cudaStream_t stream)
 {
     impl_->forward(in, out, stream);
 }
+
+void Linear::clearWorkspaces() {
+    Linear::Impl::clearWorkspaces();
+}
+
 } // namespace turbomind
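Two changes land in linear.cc: the workspace deleter now switches to the workspace's device before freeing (and frees synchronously with cudaFree instead of cudaFreeAsync on the default stream), and a static clearWorkspaces() empties the static workspace cache, which the new public Linear::clearWorkspaces() forwards to. A minimal sketch of the device-scoped free pattern, assuming a placeholder DeviceBuffer type rather than the repo's gemm::Workspace:

// device_scoped_free_sketch.cu -- illustrative only; DeviceBuffer is a
// placeholder, not turbomind's gemm::Workspace.
#include <cuda_runtime.h>
#include <memory>

struct DeviceBuffer {
    void* data{};
};

std::shared_ptr<DeviceBuffer> make_buffer(int device_id, size_t bytes)
{
    int old{};
    cudaGetDevice(&old);
    cudaSetDevice(device_id);

    auto* buf = new DeviceBuffer{};
    cudaMalloc(&buf->data, bytes);
    cudaSetDevice(old);

    // The deleter captures device_id so the free runs on the device that owns
    // the allocation, then restores whatever device the caller had current.
    return std::shared_ptr<DeviceBuffer>(buf, [device_id](DeviceBuffer* p) {
        int prev{};
        cudaGetDevice(&prev);
        cudaSetDevice(device_id);
        cudaFree(p->data);
        cudaSetDevice(prev);
        delete p;
    });
}

Capturing the device id in the deleter matters because a shared_ptr's deleter can run later on any thread, with any device current at that point.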

src/turbomind/api/python/linear.h

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,8 @@ class Linear {
     void forward(const Tensor& in, Tensor& out, cudaStream_t stream = nullptr);
     ~Linear() {}

+    static void clearWorkspaces();
+
 private:
     struct Impl;
     std::shared_ptr<Impl> impl_;
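Since the workspace cache lives in the Impl struct, the public class only needs this static declaration; the definition in linear.cc forwards to Impl. A minimal sketch of routing a static operation through a pimpl struct, using placeholder names (Widget is not turbomind's API):

// pimpl_static_sketch.cpp -- illustrative only; Widget/reset are placeholders.
#include <memory>

class Widget {
public:
    static void reset();         // public static API, callable with no instance

private:
    struct Impl;                 // defined out of line in the .cc file
    std::shared_ptr<Impl> impl_;
};

// --- normally in the .cc file ---
struct Widget::Impl {
    // static state shared by all Widget instances would live here
    static void reset() { /* clear the shared cache */ }
};

void Widget::reset()
{
    Widget::Impl::reset();
}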

turbomind/linear.py

Lines changed: 4 additions & 0 deletions
@@ -188,3 +188,7 @@ def __call__(self, x: torch.Tensor):

     def to_half(x: torch.Tensor):
         return x.to(torch.half)
+
+    @classmethod
+    def clear_workspaces(cls):
+        return _turbomind_ext.Linear.clear_workspaces()
