From 2dd05a55df6ab9a85dd0140e7780de801cb31623 Mon Sep 17 00:00:00 2001
From: DanielSun11
Date: Sat, 31 May 2025 00:30:00 +0800
Subject: [PATCH 1/4] fix logit grad

---
 paddle/phi/kernels/funcs/activation_functor.h | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h
index be391396658fa3..0f25ca778cea39 100644
--- a/paddle/phi/kernels/funcs/activation_functor.h
+++ b/paddle/phi/kernels/funcs/activation_functor.h
@@ -1215,7 +1215,7 @@ struct LogitGradFunctor {
     // logit(x)' = 1/(x*(1-x))
     dx.device(d) =
         (x < static_cast<T>(eps) || x > static_cast<T>(1.0 - eps))
-            .select(p.constant(static_cast<T>(0)),
+            .select(p.constant(static_cast<T>(NAN)),
                     dout * (static_cast<T>(1) / ((static_cast<T>(1) - x) * x)));
   }
 };
@@ -3352,6 +3352,7 @@ struct CudaLogitGradFunctor : public BaseActivationFunctor<T> {
   float eps;
   MT zero = static_cast<MT>(0.0f);
   MT one = static_cast<MT>(1.0f);
+  MT nan = static_cast<MT>(NAN);
 
   typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
     return {{"eps", &eps}};
@@ -3360,7 +3361,7 @@ struct CudaLogitGradFunctor : public BaseActivationFunctor<T> {
   __device__ __forceinline__ T operator()(const T dout, const T arg_x) const {
     MT x = static_cast<MT>(arg_x);
     MT dx = (x < static_cast<MT>(eps) || x > one - static_cast<MT>(eps))
-                ? zero
+                ? nan
                 : (static_cast<MT>(dout) / (x * (one - x)));
     return static_cast<T>(dx);
   }
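
Patch 1 switches the out-of-domain logit gradient from 0 to NaN, in line with
torch.logit. A minimal dygraph sketch of the end-state behavior this series
targets, mirroring the TestLogitAPI_NAN_Val case added in patch 2 (concrete
values assume eps is left unset and default float32):

    import paddle

    # With eps unset, inputs outside [0, 1] propagate NaN gradients
    # instead of silently clamping them to 0.
    x = paddle.to_tensor([-0.1, 0.5, 1.1], stop_gradient=False)
    y = paddle.logit(x)   # forward: [nan, 0., nan]
    y.sum().backward()
    print(x.grad)         # after this fix: [nan, 4., nan]; before: [0., 4., 0.]
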
From 62ab1203873eed51a4fbae895f0fb434bdc02bad Mon Sep 17 00:00:00 2001
From: DanielSun11
Date: Mon, 2 Jun 2025 23:30:38 +0800
Subject: [PATCH 2/4] fix kernel bug and add unittest

---
 paddle/phi/kernels/funcs/activation_functor.h | 29 +++++++++---
 test/legacy_test/test_logit_op.py             | 46 ++++++++++++++-----
 2 files changed, 57 insertions(+), 18 deletions(-)

diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h
index 0f25ca778cea39..06e99c4d2d43c7 100644
--- a/paddle/phi/kernels/funcs/activation_functor.h
+++ b/paddle/phi/kernels/funcs/activation_functor.h
@@ -1213,10 +1213,17 @@ struct LogitGradFunctor {
 
   template <typename Device, typename X, typename dOut, typename dX, typename P>
   void operator()(Device d, X x, dOut dout, dX dx, P p, float eps) const {
     // logit(x)' = 1/(x*(1-x))
-    dx.device(d) =
-        (x < static_cast<T>(eps) || x > static_cast<T>(1.0 - eps))
-            .select(p.constant(static_cast<T>(NAN)),
-                    dout * (static_cast<T>(1) / ((static_cast<T>(1) - x) * x)));
+    if (!eps) {
+      dx.device(d) = (x < static_cast<T>(eps) || x > static_cast<T>(1.0 - eps))
+                         .select(p.constant(static_cast<T>(NAN)),
+                                 dout * (static_cast<T>(1) /
+                                         ((static_cast<T>(1) - x) * x)));
+    } else {
+      dx.device(d) = (x < static_cast<T>(eps) || x > static_cast<T>(1.0 - eps))
+                         .select(p.constant(static_cast<T>(0)),
+                                 dout * (static_cast<T>(1) /
+                                         ((static_cast<T>(1) - x) * x)));
+    }
   }
 };
@@ -3360,9 +3367,17 @@ struct CudaLogitGradFunctor : public BaseActivationFunctor<T> {
   // logit(x)' = 1/(x*(1-x))
   __device__ __forceinline__ T operator()(const T dout, const T arg_x) const {
     MT x = static_cast<MT>(arg_x);
-    MT dx = (x < static_cast<MT>(eps) || x > one - static_cast<MT>(eps))
-                ? nan
-                : (static_cast<MT>(dout) / (x * (one - x)));
+    MT dx;
+    if (!eps) {
+      dx = (x < static_cast<MT>(eps) || x > one - static_cast<MT>(eps))
+               ? nan
+               : (static_cast<MT>(dout) / (x * (one - x)));
+    } else {
+      dx = (x < static_cast<MT>(eps) || x > one - static_cast<MT>(eps))
+               ? zero
+               : (static_cast<MT>(dout) / (x * (one - x)));
+    }
+
     return static_cast<T>(dx);
   }
   static constexpr ActBwdOpFwdDeps FwdDeps() {
diff --git a/test/legacy_test/test_logit_op.py b/test/legacy_test/test_logit_op.py
index b4c04c4b63e263..9fca33352797e5 100644
--- a/test/legacy_test/test_logit_op.py
+++ b/test/legacy_test/test_logit_op.py
@@ -65,7 +65,11 @@ def test_check_output(self):
 
     def test_check_grad(self):
         self.check_grad(
-            ['X'], ['Out'], user_defined_grads=[self.x_grad], check_pir=True
+            ['X'],
+            ['Out'],
+            user_defined_grads=[self.x_grad],
+            max_relative_error=0.01,
+            check_pir=True,
         )
 
 
@@ -78,11 +82,6 @@ def set_attrs(self):
     def test_check_output(self):
         self.check_output(check_pir=True)
 
-    def test_check_grad(self):
-        self.check_grad(
-            ['X'], ['Out'], user_defined_grads=[self.x_grad], check_pir=True
-        )
-
 
 class TestLogitOpFp16(TestLogitOp):
     def set_attrs(self):
@@ -93,11 +92,6 @@ def set_attrs(self):
     def test_check_output(self):
         self.check_output(check_pir=True)
 
-    def test_check_grad(self):
-        self.check_grad(
-            ['X'], ['Out'], user_defined_grads=[self.x_grad], check_pir=True
-        )
-
 
 @unittest.skipIf(
     not core.is_compiled_with_cuda()
@@ -202,5 +196,35 @@ def test_errors(self):
             self.assertRaises(TypeError, paddle.logit, x, dtype='int32')
 
 
+class TestLogitAPI_NAN_Val(unittest.TestCase):
+    def setUp(self):
+        self.init_input_output()
+        self.place = [paddle.CPUPlace()]
+        if paddle.base.core.is_compiled_with_cuda():
+            self.place.append(paddle.CUDAPlace(0))
+
+    def init_input_output(self):
+        self.x = [-0.1, 1.1, 2]
+        self.expect_out = [np.nan, np.nan, np.nan]
+        self.expect_x_grad = [np.nan, np.nan, np.nan]
+
+    def test_nan_val(self):
+        def _test_nan_val_with_place(place):
+            with paddle.base.dygraph.guard():
+                x = paddle.to_tensor(self.x, stop_gradient=False, place=place)
+                y = paddle.logit(x)
+                loss = y.sum()
+                loss.backward()
+                np.testing.assert_allclose(
+                    y.numpy(), self.expect_out, rtol=1e-05
+                )
+                np.testing.assert_allclose(
+                    x.grad.numpy(), self.expect_x_grad, rtol=1e-05
+                )
+
+        for place in self.place:
+            _test_nan_val_with_place(place)
+
+
 if __name__ == "__main__":
     unittest.main()

From 6ccac7c24d0926cce662953bdc55cd623aefc2b4 Mon Sep 17 00:00:00 2001
From: DanielSun11
Date: Mon, 2 Jun 2025 23:50:57 +0800
Subject: [PATCH 3/4] rm rhol

---
 test/legacy_test/test_logit_op.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/test/legacy_test/test_logit_op.py b/test/legacy_test/test_logit_op.py
index 9fca33352797e5..65faf6b1ea69cd 100644
--- a/test/legacy_test/test_logit_op.py
+++ b/test/legacy_test/test_logit_op.py
@@ -68,7 +68,6 @@ def test_check_grad(self):
             ['X'],
             ['Out'],
             user_defined_grads=[self.x_grad],
-            max_relative_error=0.01,
             check_pir=True,
         )
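
Patch 2 narrows patch 1: NaN is produced only for the default eps == 0, while
any positive eps keeps the old behavior of zeroing the gradient outside
[eps, 1 - eps]. The same semantics in a NumPy sketch, for reference only
(logit_grad_reference is a hypothetical helper, not part of Paddle):

    import numpy as np

    def logit_grad_reference(x, dout, eps=0.0):
        # eps == 0: out-of-domain points get NaN; eps > 0: they get 0.
        x = np.asarray(x, dtype=np.float64)
        out_of_range = (x < eps) | (x > 1.0 - eps)
        fill = np.nan if eps == 0.0 else 0.0
        with np.errstate(divide="ignore"):  # x == 0 or x == 1 divides by zero
            grad = dout / (x * (1.0 - x))
        return np.where(out_of_range, fill, grad)

    print(logit_grad_reference([-0.1, 0.5, 1.1], 1.0))           # [nan 4. nan]
    print(logit_grad_reference([-0.1, 0.5, 1.1], 1.0, eps=0.2))  # [0. 4. 0.]
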
From 0af02065ae29757e2b7a27ba2854dbd8957220a0 Mon Sep 17 00:00:00 2001
From: DanielSun11
Date: Wed, 4 Jun 2025 17:41:37 +0800
Subject: [PATCH 4/4] fix codestyle error

---
 test/legacy_test/test_logit_op.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/test/legacy_test/test_logit_op.py b/test/legacy_test/test_logit_op.py
index 6da307bf8adf9f..e9552e2f7e9307 100644
--- a/test/legacy_test/test_logit_op.py
+++ b/test/legacy_test/test_logit_op.py
@@ -265,7 +265,8 @@ def _test_nan_val_with_place(place):
         for place in self.place:
             _test_nan_val_with_place(place)
 
-    class TestLogitAPICase1(unittest.TestCase):
+
+class TestLogitAPICase1(unittest.TestCase):
     def init_data(self):
         self.x_shape = [120]
         self.x_dtype = "float64"
@@ -276,5 +277,6 @@ def init_data(self):
         self.x_shape = [120]
         self.x_dtype = "float16"
 
+
 if __name__ == "__main__":
     unittest.main()
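
The derivative identity both kernels rely on, logit'(x) = 1/(x*(1-x)), is easy
to sanity-check numerically; a standalone NumPy sketch using a central finite
difference on points inside (0, 1), independent of Paddle:

    import numpy as np

    def logit(x):
        return np.log(x / (1.0 - x))

    x = np.array([0.1, 0.3, 0.5, 0.7, 0.9])
    h = 1e-6
    numeric = (logit(x + h) - logit(x - h)) / (2.0 * h)  # central difference
    analytic = 1.0 / (x * (1.0 - x))
    np.testing.assert_allclose(numeric, analytic, rtol=1e-5)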