diff --git a/python/scripts/test_index.py b/python/scripts/test_index.py index cee1f307..5ac4f8a9 100644 --- a/python/scripts/test_index.py +++ b/python/scripts/test_index.py @@ -29,7 +29,7 @@ ScalarKind.BF16, ScalarKind.I8, ] -dtypes = [np.float32, np.float64, np.float16] +dtypes = [np.float32, np.float64, np.float16, np.int8, np.uint8] threads = 2 connectivity_options = [3, 13, 50, DEFAULT_CONNECTIVITY] @@ -49,8 +49,10 @@ def reset_randomness(): @pytest.mark.parametrize("metric", [MetricKind.Cos, MetricKind.L2sq]) @pytest.mark.parametrize("batch_size", [1, 7, 1024]) @pytest.mark.parametrize("quantization", [ScalarKind.F32, ScalarKind.I8]) -@pytest.mark.parametrize("dtype", [np.float32, np.float64, np.float16]) -def test_index_initialization_and_addition(ndim, metric, quantization, dtype, batch_size): +@pytest.mark.parametrize("dtype", dtypes) +def test_index_initialization_and_addition( + ndim, metric, quantization, dtype, batch_size +): reset_randomness() index = Index(ndim=ndim, metric=metric, dtype=quantization, multi=False) @@ -63,7 +65,9 @@ def test_index_initialization_and_addition(ndim, metric, quantization, dtype, ba @pytest.mark.parametrize("ndim", [3, 97, 256]) @pytest.mark.parametrize("metric", [MetricKind.Cos, MetricKind.L2sq]) @pytest.mark.parametrize("batch_size", [1, 7, 1024]) -@pytest.mark.parametrize("quantization", [ScalarKind.F32, ScalarKind.F16, ScalarKind.I8]) +@pytest.mark.parametrize( + "quantization", [ScalarKind.F32, ScalarKind.F16, ScalarKind.I8] +) @pytest.mark.parametrize("dtype", [np.float32, np.float64, np.float16]) def test_index_retrieval(ndim, metric, quantization, dtype, batch_size): reset_randomness() @@ -103,7 +107,9 @@ def test_index_retrieval(ndim, metric, quantization, dtype, batch_size): # Try a transposed version of the same vectors, that is not C-contiguous # and should raise an exception! index = Index(ndim=ndim, metric=metric, dtype=quantization, multi=False) - vectors = random_vectors(count=ndim, ndim=batch_size, dtype=dtype) #! reversed dims + vectors = random_vectors( + count=ndim, ndim=batch_size, dtype=dtype + ) #! reversed dims assert vectors.strides == (batch_size * dtype().itemsize, dtype().itemsize) assert vectors.T.strides == (dtype().itemsize, batch_size * dtype().itemsize) with pytest.raises(Exception): @@ -220,7 +226,9 @@ def test_index_save_load_restore_copy(ndim, quantization, batch_size): copied_index = index.copy() assert len(copied_index) == len(index) if batch_size > 0: - assert np.allclose(np.vstack(copied_index.get(keys)), np.vstack(index.get(keys))) + assert np.allclose( + np.vstack(copied_index.get(keys)), np.vstack(index.get(keys)) + ) index.save("tmp.usearch") index.clear() @@ -244,7 +252,9 @@ def test_index_save_load_restore_copy(ndim, quantization, batch_size): copied_index = index.copy() assert len(copied_index) == len(index) if batch_size > 0: - assert np.allclose(np.vstack(copied_index.get(keys)), np.vstack(index.get(keys))) + assert np.allclose( + np.vstack(copied_index.get(keys)), np.vstack(index.get(keys)) + ) # Perform the same operations in RAM, without touching the filesystem serialized_index = index.save() @@ -255,7 +265,9 @@ def test_index_save_load_restore_copy(ndim, quantization, batch_size): assert len(deserialized_index) == len(index) assert set(np.array(deserialized_index.keys)) == set(np.array(index.keys)) if batch_size > 0: - assert np.allclose(np.vstack(deserialized_index.get(keys)), np.vstack(index.get(keys))) + assert np.allclose( + np.vstack(deserialized_index.get(keys)), np.vstack(index.get(keys)) + ) deserialized_index.reset() index.reset() @@ -280,7 +292,7 @@ def test_index_contains_remove_rename(batch_size): removed_keys = keys[: batch_size // 2] remaining_keys = keys[batch_size // 2 :] index.remove(removed_keys) - del index[removed_keys] # ! This will trigger the `__delitem__` dunder method + del index[removed_keys] # ! This will trigger the `__delitem__` dunder method assert len(index) == (len(keys) - len(removed_keys)) assert np.sum(index.contains(keys)) == len(remaining_keys) assert np.sum(index.count(keys)) == len(remaining_keys) @@ -313,7 +325,9 @@ def test_index_oversubscribed_search(batch_size: int, threads: int): assert np.all(index.contains(keys)) assert np.all(index.count(keys) == np.ones(batch_size)) - batch_matches: BatchMatches = index.search(vectors, batch_size * 10, threads=threads) + batch_matches: BatchMatches = index.search( + vectors, batch_size * 10, threads=threads + ) for i, match in enumerate(batch_matches): assert i == match.keys[0] assert len(match.keys) == batch_size