Skip to content

Commit 8727128

Browse files
Merge remote-tracking branch 'origin/main'
# Conflicts: # test/test_graphnet.py
2 parents ec61eff + 62f28aa commit 8727128

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

proxtorch/operators/graphnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ def prox(self, x: torch.Tensor, tau: float) -> torch.Tensor:
1515

1616
def _smooth(self, x: torch.Tensor) -> torch.Tensor:
1717
# The last channel is the for the l1 norm
18-
grad = self.gradient(x)[:-1]/(1-self.l1_ratio)
18+
grad = self.gradient(x)[:-1] / (1 - self.l1_ratio)
1919
# sum of squares of the gradients
20-
norm = torch.sum(grad ** 2)
20+
norm = torch.sum(grad**2)
2121
return 0.5 * norm * self.alpha * (1 - self.l1_ratio)
2222

2323
def _nonsmooth(self, x: torch.Tensor) -> torch.Tensor:

0 commit comments

Comments
 (0)