Skip to content

Commit 7968dc3

Browse files
fix test
1 parent b6db941 commit 7968dc3

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

test/Tensorflow.UnitTest/PythonTest.cs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,13 +133,23 @@ public void assertTrue(bool cond)
133133

134134
public void assertAllClose(NDArray array1, NDArray array2, double eps = 1e-5)
135135
{
136-
Assert.IsTrue(np.allclose(array1, array2, rtol: eps));
136+
CollectionAssert.AreEqual(array1.ToArray(), array2.ToArray(), new CollectionComparer(eps));
137+
138+
//TODO: Assert.IsTrue(np.allclose(array1, array2, rtol: eps));
137139
}
138140

139141
public void assertAllClose(double value, NDArray array2, double eps = 1e-5)
140142
{
143+
if (array2.shape.IsScalar)
144+
{
145+
double value2 = array2;
146+
Assert.AreEqual(value, value2, eps);
147+
return;
148+
}
141149
var array1 = np.ones_like(array2) * value;
142-
Assert.IsTrue(np.allclose(array1, array2, rtol: eps));
150+
CollectionAssert.AreEqual(array1.ToArray(), array2.ToArray(), new CollectionComparer(eps));
151+
152+
//TODO: Assert.IsTrue(np.allclose(array1, array2, rtol: eps));
143153
}
144154

145155
private class CollectionComparer : IComparer
@@ -158,7 +168,7 @@ public int Compare(object? x, object? y)
158168
}
159169
else if (x == null)
160170
{
161-
return -1;
171+
return -1;
162172
}
163173
else if (y == null)
164174
{

0 commit comments

Comments
 (0)