|
def soft_rank(values, regularization="l2", regularization_strength=1.0):
    """Differentiable (soft) ranking of each row of a 2-d tensor.

    Args:
      values: 2-d tensor of shape (batch, n); each row is ranked independently.
      regularization: regularizer to use, either "l2" or "kl".
      regularization_strength: strength of the regularization (float).

    Returns:
      Whatever ``SoftRank.apply`` produces for the validated inputs
      (presumably a tensor of soft ranks — defined elsewhere in this file).

    Raises:
      ValueError: if ``values`` is not 2-d, or ``regularization`` is not
        one of "l2" / "kl".
    """
    if len(values.shape) != 2:
        raise ValueError(f"'values' should be a 2d-tensor but got {values.shape}")
    # Validate eagerly so the user gets a clear message (including the bad
    # value) instead of an obscure failure inside the autograd Function.
    if regularization not in ("l2", "kl"):
        raise ValueError(
            f"'regularization' should be 'l2' or 'kl' but got {regularization}"
        )
    return SoftRank.apply(values, regularization, regularization_strength)
|
39 | 41 |
|
40 | 42 |
|
def soft_sort(values, regularization="l2", regularization_strength=1.0):
    """Differentiable (soft) sorting of each row of a 2-d tensor.

    Args:
      values: 2-d tensor of shape (batch, n); each row is sorted independently.
      regularization: regularizer to use, either "l2" or "kl".
      regularization_strength: strength of the regularization (float).

    Returns:
      Whatever ``SoftSort.apply`` produces for the validated inputs
      (presumably a tensor of softly sorted values — defined elsewhere
      in this file).

    Raises:
      ValueError: if ``values`` is not 2-d, or ``regularization`` is not
        one of "l2" / "kl".
    """
    if len(values.shape) != 2:
        raise ValueError(f"'values' should be a 2d-tensor but got {values.shape}")
    # Validate eagerly so the user gets a clear message (including the bad
    # value) instead of an obscure failure inside the autograd Function.
    if regularization not in ("l2", "kl"):
        raise ValueError(
            f"'regularization' should be 'l2' or 'kl' but got {regularization}"
        )
    return SoftSort.apply(values, regularization, regularization_strength)
|
45 | 49 |
|
46 | 50 |
|
@@ -90,19 +94,19 @@ def forward(ctx, tensor, regularization="l2", regularization_strength=1.0):
|
90 | 94 | if ctx.regularization == "l2":
|
91 | 95 | dual_sol = isotonic_l2[s.device.type](s - w)
|
92 | 96 | ret = (s - dual_sol).gather(1, inv_permutation)
|
93 |
| - ctx.factor = 1.0 |
| 97 | + factor = torch.tensor(1.0, device=s.device) |
94 | 98 | else:
|
95 | 99 | dual_sol = isotonic_kl[s.device.type](s, torch.log(w))
|
96 | 100 | ret = torch.exp((s - dual_sol).gather(1, inv_permutation))
|
97 |
| - ctx.factor = ret |
| 101 | + factor = ret |
98 | 102 |
|
99 |
| - ctx.save_for_backward(s, dual_sol, permutation, inv_permutation) |
| 103 | + ctx.save_for_backward(factor, s, dual_sol, permutation, inv_permutation) |
100 | 104 | return ret
|
101 | 105 |
|
102 | 106 | @staticmethod
|
103 | 107 | def backward(ctx, grad_output):
|
104 |
| - grad = (grad_output * ctx.factor).clone() |
105 |
| - s, dual_sol, permutation, inv_permutation = ctx.saved_tensors |
| 108 | + factor, s, dual_sol, permutation, inv_permutation = ctx.saved_tensors |
| 109 | + grad = (grad_output * factor).clone() |
106 | 110 | if ctx.regularization == "l2":
|
107 | 111 | grad -= isotonic_l2_backward[s.device.type](
|
108 | 112 | s, dual_sol, grad.gather(1, permutation)
|
|
0 commit comments