From f78b4fc64c3d377678bd94162b701aa0e71f5fe0 Mon Sep 17 00:00:00 2001 From: loki-veera Date: Wed, 4 Sep 2024 12:20:01 +0200 Subject: [PATCH] fix np to torch in bumpy torch --- src/optimize_2d_momentum_bumpy_torch.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/optimize_2d_momentum_bumpy_torch.py b/src/optimize_2d_momentum_bumpy_torch.py index c444536..cd7f335 100644 --- a/src/optimize_2d_momentum_bumpy_torch.py +++ b/src/optimize_2d_momentum_bumpy_torch.py @@ -33,7 +33,7 @@ def bumpy_function(pos: th.Tensor) -> th.Tensor: import matplotlib.pyplot as plt from matplotlib import cm - # TODO: use jax to find the gradient. + # TODO: use torch to find the gradient. nx, ny = (1001, 1001) x = th.linspace(-3, 3, nx) @@ -57,7 +57,7 @@ def bumpy_function(pos: th.Tensor) -> th.Tensor: step_total = 100 pos_list = [start_pos] - velocity_vec = np.array((0.0, 0.0)) + velocity_vec = th.tensor((0.0, 0.0)) # TODO: Implement gradient descent with momentum. for pos in pos_list: @@ -69,5 +69,5 @@ def bumpy_function(pos: th.Tensor) -> th.Tensor: np.array(my), np.array(mz), pos_list, - "writer_grad_bumpy_plot_jax", + "writer_grad_bumpy_plot_torch", )