Skip to content

Commit 7d7a6bb

Browse files
authored
Fixed missing axis annotation in tf map_coordinates (#21304)
1 parent 37a0920 commit 7d7a6bb

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

keras/src/backend/tensorflow/image.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -707,7 +707,7 @@ def process_coordinates(coords, size):
707707
gathered = tf.transpose(tf.gather_nd(input_arr, indices))
708708

709709
if fill_mode == "constant":
710-
all_valid = tf.reduce_all(validities)
710+
all_valid = tf.reduce_all(validities, axis=0)
711711
gathered = tf.where(all_valid, gathered, fill_value)
712712

713713
contribution = gathered

keras/src/ops/image_test.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1865,6 +1865,30 @@ def test_elastic_transform(self):
18651865
)
18661866
self.assertAllClose(np.var(ref_out), np.var(out), atol=1e-2, rtol=1e-2)
18671867

1868+
def test_map_coordinates_constant_padding(self):
1869+
input_img = tf.ones((2, 2), dtype=tf.uint8)
1870+
# one pixel outside of the input space around the edges
1871+
grid = tf.stack(
1872+
tf.meshgrid(
1873+
tf.range(-1, 3, dtype=tf.float32),
1874+
tf.range(-1, 3, dtype=tf.float32),
1875+
indexing="ij",
1876+
),
1877+
axis=0,
1878+
)
1879+
out = backend.convert_to_numpy(
1880+
kimage.map_coordinates(
1881+
input_img, grid, order=0, fill_mode="constant", fill_value=0
1882+
)
1883+
)
1884+
1885+
# check for ones in the middle and zeros around the edges
1886+
self.assertTrue(np.all(out[:1] == 0))
1887+
self.assertTrue(np.all(out[-1:] == 0))
1888+
self.assertTrue(np.all(out[:, :1] == 0))
1889+
self.assertTrue(np.all(out[:, -1:] == 0))
1890+
self.assertTrue(np.all(out[1:3, 1:3] == 1))
1891+
18681892

18691893
class ImageOpsBehaviorTests(testing.TestCase):
18701894
def setUp(self):

0 commit comments

Comments
 (0)