Commit 81e8a5f

Merge pull request #16 from teddykoker/fix15
Fix CUDA Leak + Input Validation
2 parents: d480043 + 7fdfa94

File tree

3 files changed: +11 -7 lines changed

README.md

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ pip install torchsort
 To build the CUDA extension you will need the CUDA toolchain installed. If you
 want to build in an environment without a CUDA runtime (e.g. docker), you will
 need to export the environment variable
-`TORCH_CUDA_ARCH_LIST="Pascal;Volta;Turing"` before installing.
+`TORCH_CUDA_ARCH_LIST="Pascal;Volta;Turing;Ampere"` before installing.
 
 ## Usage
 

setup.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ def ext_modules():
 
 setup(
     name="torchsort",
-    version="0.1.2",
+    version="0.1.3",
     description="Differentiable sorting and ranking in PyTorch",
     author="Teddy Koker",
     url="https://github.com/teddykoker/torchsort",

torchsort/ops.py

Lines changed: 9 additions & 5 deletions
@@ -35,12 +35,16 @@
 def soft_rank(values, regularization="l2", regularization_strength=1.0):
     if len(values.shape) != 2:
         raise ValueError(f"'values' should be a 2d-tensor but got {values.shape}")
+    if regularization not in ["l2", "kl"]:
+        raise ValueError(f"'regularization' should be a 'l2' or 'kl'")
     return SoftRank.apply(values, regularization, regularization_strength)
 
 
 def soft_sort(values, regularization="l2", regularization_strength=1.0):
     if len(values.shape) != 2:
         raise ValueError(f"'values' should be a 2d-tensor but got {values.shape}")
+    if regularization not in ["l2", "kl"]:
+        raise ValueError(f"'regularization' should be a 'l2' or 'kl'")
     return SoftSort.apply(values, regularization, regularization_strength)
 
 
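A quick illustration of the new validation (a minimal sketch; assumes torchsort is installed, and the tensor below is arbitrary):

import torch
import torchsort

values = torch.randn(2, 5)

# "l2" (the default) and "kl" are the only accepted regularizers.
ranks = torchsort.soft_rank(values, regularization="kl")

# Anything else now fails fast with a ValueError, before reaching
# the underlying autograd Function.
try:
    torchsort.soft_rank(values, regularization="l1")
except ValueError as e:
    print(e)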

@@ -90,19 +94,19 @@ def forward(ctx, tensor, regularization="l2", regularization_strength=1.0):
         if ctx.regularization == "l2":
             dual_sol = isotonic_l2[s.device.type](s - w)
             ret = (s - dual_sol).gather(1, inv_permutation)
-            ctx.factor = 1.0
+            factor = torch.tensor(1.0, device=s.device)
         else:
             dual_sol = isotonic_kl[s.device.type](s, torch.log(w))
             ret = torch.exp((s - dual_sol).gather(1, inv_permutation))
-            ctx.factor = ret
+            factor = ret
 
-        ctx.save_for_backward(s, dual_sol, permutation, inv_permutation)
+        ctx.save_for_backward(factor, s, dual_sol, permutation, inv_permutation)
         return ret
 
     @staticmethod
     def backward(ctx, grad_output):
-        grad = (grad_output * ctx.factor).clone()
-        s, dual_sol, permutation, inv_permutation = ctx.saved_tensors
+        factor, s, dual_sol, permutation, inv_permutation = ctx.saved_tensors
+        grad = (grad_output * factor).clone()
         if ctx.regularization == "l2":
             grad -= isotonic_l2_backward[s.device.type](
                 s, dual_sol, grad.gather(1, permutation)
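For context on why this fixes the leak: assigning a result tensor to a plain attribute like ctx.factor creates a reference cycle (output holds grad_fn, grad_fn holds ctx, ctx holds output) that can keep CUDA memory alive until the garbage collector runs, whereas save_for_backward lets autograd manage the saved tensors' lifetime. A minimal sketch of the pattern (the Square function below is hypothetical, not part of torchsort):

import torch
from torch.autograd import Function

class Square(Function):
    @staticmethod
    def forward(ctx, x):
        y = x * x
        # Leaky pattern (the kind this commit removes):
        #     ctx.y = y
        # A tensor stored as a plain ctx attribute can form a reference
        # cycle with the backward node, delaying CUDA memory release.
        ctx.save_for_backward(x)  # autograd manages the tensor's lifetime
        return y

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output

x = torch.randn(3, requires_grad=True)
Square.apply(x).sum().backward()
print(x.grad)  # equals 2 * x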
