Skip to content

Commit e9ad101

Browse files
authored
[0-size Tensor No.20、22、54、56、242、301、186、212、178、226] Add 0-size Tensor support for polygamma/bitwise_not (#72977)
1 parent 446298c commit e9ad101

File tree

3 files changed

+37
-1
lines changed

3 files changed

+37
-1
lines changed

test/legacy_test/test_activation_op_zero_size.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,15 @@
2525
TestCeil,
2626
TestCos,
2727
TestCosh,
28+
TestExpFp32_Prim,
29+
TestExpm1,
2830
TestFloor,
31+
TestHardSwish,
32+
TestLeakyRelu,
2933
TestLogSigmoid,
3034
TestReciprocal,
3135
TestRelu,
36+
TestRelu6,
3237
TestRsqrt,
3338
TestSigmoid,
3439
TestSilu,
@@ -37,6 +42,7 @@
3742
TestSoftsign,
3843
TestSqrt,
3944
TestSquare,
45+
TestSwish,
4046
TestTan,
4147
TestTanh,
4248
TestTanhshrink,
@@ -96,7 +102,12 @@ def test_check_grad(self):
96102
create_test_zero_size_class(TestLogSigmoid)
97103
create_test_zero_size_class(TestFloor)
98104
create_test_zero_size_class(TestCeil)
99-
105+
create_test_zero_size_class(TestExpFp32_Prim)
106+
create_test_zero_size_class(TestExpm1)
107+
create_test_zero_size_class(TestLeakyRelu)
108+
create_test_zero_size_class(TestRelu6)
109+
create_test_zero_size_class(TestHardSwish)
110+
create_test_zero_size_class(TestSwish)
100111

101112
if __name__ == "__main__":
102113
unittest.main()

test/legacy_test/test_bitwise_op.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,11 @@ def init_shape(self):
392392
self.x_shape = []
393393

394394

395+
class TestBitwiseNot_ZeroSize(TestBitwiseNot):
396+
def init_shape(self):
397+
self.x_shape = [0, 3, 4, 5]
398+
399+
395400
class TestBitwiseNotUInt8(TestBitwiseNot):
396401
def init_dtype(self):
397402
self.dtype = np.uint8

test/legacy_test/test_polygamma_op.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,5 +217,25 @@ def test_check_grad(self):
217217
)
218218

219219

220+
class TestPolygammaOp_ZeroSize(TestPolygammaOp):
221+
222+
def init_config(self):
223+
self.dtype = np.float64
224+
self.order = 1
225+
rand_case = np.random.randn(0).astype(self.dtype)
226+
int_case = np.random.randint(low=1, high=100, size=0).astype(self.dtype)
227+
self.case = np.concatenate([rand_case, int_case])
228+
self.inputs = {'x': self.case}
229+
self.attrs = {'n': self.order}
230+
self.target = ref_polygamma(self.inputs['x'], self.order)
231+
232+
def test_check_grad(self):
233+
self.check_grad(
234+
['x'],
235+
'out',
236+
check_pir=True,
237+
)
238+
239+
220240
if __name__ == "__main__":
221241
unittest.main()

0 commit comments

Comments
 (0)