@@ -169,6 +169,172 @@ def test_convert_to_tensor(self):
169
169
170
170
171
171
class CoreOpsCorrectnessTest (testing .TestCase ):
172
+ def test_getitem (self ):
173
+ self .np_tensor = np .arange (24 ).reshape (2 , 3 , 4 )
174
+ self .tensor = ops .convert_to_tensor (self .np_tensor )
175
+
176
+ t = self .tensor [1 ]
177
+ n = self .np_tensor [1 ]
178
+ self .assertEqual (t .shape , n .shape )
179
+ self .assertTrue (np .array_equal (ops .convert_to_numpy (t ), n ))
180
+
181
+ t = self .tensor [1 , 2 , 3 ]
182
+ n = self .np_tensor [1 , 2 , 3 ]
183
+ self .assertEqual (t .shape , n .shape )
184
+ self .assertTrue (np .array_equal (ops .convert_to_numpy (t ), n ))
185
+
186
+ t = self .tensor [1 :2 ]
187
+ n = self .np_tensor [1 :2 ]
188
+ self .assertEqual (t .shape , n .shape )
189
+ self .assertTrue (np .array_equal (ops .convert_to_numpy (t ), n ))
190
+
191
+ t = self .tensor [1 :2 , 2 :3 , 3 :4 ]
192
+ n = self .np_tensor [1 :2 , 2 :3 , 3 :4 ]
193
+ self .assertEqual (t .shape , n .shape )
194
+ self .assertTrue (np .array_equal (ops .convert_to_numpy (t ), n ))
195
+
196
+ t = self .tensor [1 :2 , None ]
197
+ n = self .np_tensor [1 :2 , None ]
198
+ self .assertEqual (t .shape , n .shape )
199
+ self .assertTrue (np .array_equal (ops .convert_to_numpy (t ), n ))
200
+
201
+ t = self .tensor [1 :2 , 2 :3 , ...]
202
+ n = self .np_tensor [1 :2 , 2 :3 , ...]
203
+ self .assertEqual (t .shape , n .shape )
204
+ self .assertTrue (np .array_equal (ops .convert_to_numpy (t ), n ))
205
+
206
+ t = self .tensor [1 :2 , ..., 3 :4 ]
207
+ n = self .np_tensor [1 :2 , ..., 3 :4 ]
208
+ self .assertEqual (t .shape , n .shape )
209
+ self .assertTrue (np .array_equal (ops .convert_to_numpy (t ), n ))
210
+
211
+ t = self .tensor [None , ..., 3 :4 , None ]
212
+ n = self .np_tensor [None , ..., 3 :4 , None ]
213
+ self .assertEqual (t .shape , n .shape )
214
+ self .assertTrue (np .array_equal (ops .convert_to_numpy (t ), n ))
215
+
216
+ t = self .tensor [1 :2 :None ]
217
+ n = self .np_tensor [1 :2 :None ]
218
+ self .assertEqual (t .shape , n .shape )
219
+ self .assertTrue (np .array_equal (ops .convert_to_numpy (t ), n ))
220
+
221
+ t = self .tensor [:, 2 ]
222
+ n = self .np_tensor [:, 2 ]
223
+ self .assertEqual (t .shape , n .shape )
224
+ self .assertTrue (np .array_equal (ops .convert_to_numpy (t ), n ))
225
+
226
+ t = self .tensor [None ]
227
+ n = self .np_tensor [None ]
228
+ self .assertEqual (t .shape , n .shape )
229
+ self .assertTrue (np .array_equal (ops .convert_to_numpy (t ), n ))
230
+
231
+ t = self .tensor [None , None ]
232
+ n = self .np_tensor [None , None ]
233
+ self .assertEqual (t .shape , n .shape )
234
+ self .assertTrue (np .array_equal (ops .convert_to_numpy (t ), n ))
235
+
236
+ t = self .tensor [...]
237
+ n = self .np_tensor [...]
238
+ self .assertEqual (t .shape , n .shape )
239
+ self .assertTrue (np .array_equal (ops .convert_to_numpy (t ), n ))
240
+
241
+ t = self .tensor [..., 1 ]
242
+ n = self .np_tensor [..., 1 ]
243
+ self .assertEqual (t .shape , n .shape )
244
+ self .assertTrue (np .array_equal (ops .convert_to_numpy (t ), n ))
245
+
246
+ t = self .tensor [..., 1 , 2 ]
247
+ n = self .np_tensor [..., 1 , 2 ]
248
+ self .assertEqual (t .shape , n .shape )
249
+ self .assertTrue (np .array_equal (ops .convert_to_numpy (t ), n ))
250
+
251
+ t = self .tensor [..., - 1 , 2 ]
252
+ n = self .np_tensor [..., - 1 , 2 ]
253
+ self .assertEqual (t .shape , n .shape )
254
+ self .assertTrue (np .array_equal (ops .convert_to_numpy (t ), n ))
255
+
256
+ t = self .tensor [..., - 1 :- 2 , 2 ]
257
+ n = self .np_tensor [..., - 1 :- 2 , 2 ]
258
+ self .assertEqual (t .shape , n .shape )
259
+ self .assertTrue (np .array_equal (ops .convert_to_numpy (t ), n ))
260
+
261
+ t = self .tensor [..., None , None ]
262
+ n = self .np_tensor [..., None , None ]
263
+ self .assertEqual (t .shape , n .shape )
264
+ self .assertTrue (np .array_equal (ops .convert_to_numpy (t ), n ))
265
+
266
+ t = self .tensor [None , ..., None ]
267
+ n = self .np_tensor [None , ..., None ]
268
+ self .assertEqual (t .shape , n .shape )
269
+ self .assertTrue (np .array_equal (ops .convert_to_numpy (t ), n ))
270
+
271
+ t = self .tensor [1 , 2 , None , ..., None ]
272
+ n = self .np_tensor [1 , 2 , None , ..., None ]
273
+ self .assertEqual (t .shape , n .shape )
274
+ self .assertTrue (np .array_equal (ops .convert_to_numpy (t ), n ))
275
+
276
+ t = self .tensor [None , ..., 1 , 2 ]
277
+ n = self .np_tensor [None , ..., 1 , 2 ]
278
+ self .assertEqual (t .shape , n .shape )
279
+ self .assertTrue (np .array_equal (ops .convert_to_numpy (t ), n ))
280
+
281
+ t = self .tensor [1 , None , 2 ]
282
+ n = self .np_tensor [1 , None , 2 ]
283
+ self .assertEqual (t .shape , n .shape )
284
+ self .assertTrue (np .array_equal (ops .convert_to_numpy (t ), n ))
285
+
286
+ index_tensor = ops .convert_to_tensor (np .array (1 , dtype = np .int32 ))
287
+ t = self .tensor [index_tensor ]
288
+ n = self .np_tensor [ops .convert_to_numpy (index_tensor )]
289
+ self .assertEqual (t .shape , n .shape )
290
+ self .assertTrue (np .array_equal (ops .convert_to_numpy (t ), n ))
291
+
292
+ index_tensor = ops .convert_to_tensor (np .array (1 , dtype = np .int32 ))
293
+ t = self .tensor [index_tensor , 2 , None ]
294
+ n = self .np_tensor [ops .convert_to_numpy (index_tensor ), 2 , None ]
295
+ self .assertEqual (t .shape , n .shape )
296
+ self .assertTrue (np .array_equal (ops .convert_to_numpy (t ), n ))
297
+
298
+ index_tensor = ops .convert_to_tensor (np .array (- 2 , dtype = np .int32 ))
299
+ t = self .tensor [index_tensor , 1 ]
300
+ n = self .np_tensor [ops .convert_to_numpy (index_tensor ), 1 ]
301
+ self .assertEqual (t .shape , n .shape )
302
+ self .assertTrue (np .array_equal (ops .convert_to_numpy (t ), n ))
303
+
304
+ index_tensor = ops .convert_to_tensor (np .array (- 1 , dtype = np .int32 ))
305
+ t = self .tensor [- 2 , index_tensor ]
306
+ n = self .np_tensor [- 2 , ops .convert_to_numpy (index_tensor )]
307
+ self .assertEqual (t .shape , n .shape )
308
+ self .assertTrue (np .array_equal (ops .convert_to_numpy (t ), n ))
309
+
310
+ # Negative indexing
311
+ t = self .tensor [- 1 ]
312
+ n = self .np_tensor [- 1 ]
313
+ self .assertEqual (t .shape , n .shape )
314
+ self .assertTrue (np .array_equal (ops .convert_to_numpy (t ), n ))
315
+
316
+ t = self .tensor [1 , - 1 , - 2 ]
317
+ n = self .np_tensor [1 , - 1 , - 2 ]
318
+ self .assertEqual (t .shape , n .shape )
319
+ self .assertTrue (np .array_equal (ops .convert_to_numpy (t ), n ))
320
+
321
+ # Slicing with step
322
+ t = self .tensor [::2 ]
323
+ n = self .np_tensor [::2 ]
324
+ self .assertEqual (t .shape , n .shape )
325
+ self .assertTrue (np .array_equal (ops .convert_to_numpy (t ), n ))
326
+
327
+ # Mixed slices and integers
328
+ t = self .tensor [1 , :, 1 :4 ]
329
+ n = self .np_tensor [1 , :, 1 :4 ]
330
+ self .assertEqual (t .shape , n .shape )
331
+ self .assertTrue (np .array_equal (ops .convert_to_numpy (t ), n ))
332
+
333
+ t = self .tensor [:, 1 :2 , 3 ]
334
+ n = self .np_tensor [:, 1 :2 , 3 ]
335
+ self .assertEqual (t .shape , n .shape )
336
+ self .assertTrue (np .array_equal (ops .convert_to_numpy (t ), n ))
337
+
172
338
def test_map (self ):
173
339
def f (x ):
174
340
return x ** 2
0 commit comments