From 563bb499fdffa2ccfa185dd279e15f7e9a702156 Mon Sep 17 00:00:00 2001 From: Kyle Brown Date: Fri, 7 Mar 2025 16:11:38 -0700 Subject: [PATCH 1/6] added retryable client Signed-off-by: Kyle Brown --- authzed/api/v1/__init__.py | 6 + authzed/api/v1/retryable_client.py | 275 ++++++++++++++ examples/v1/retryable_import_relationships.py | 111 ++++++ .../retryable_import_relationships_async.py | 113 ++++++ tests/retryable_client_test.py | 359 ++++++++++++++++++ 5 files changed, 864 insertions(+) create mode 100644 authzed/api/v1/retryable_client.py create mode 100644 examples/v1/retryable_import_relationships.py create mode 100644 examples/v1/retryable_import_relationships_async.py create mode 100644 tests/retryable_client_test.py diff --git a/authzed/api/v1/__init__.py b/authzed/api/v1/__init__.py index fa2dd8f..ee9a862 100644 --- a/authzed/api/v1/__init__.py +++ b/authzed/api/v1/__init__.py @@ -164,8 +164,14 @@ def __init__( self.init_stubs(channel) +# Import after defining Client to avoid circular imports +from authzed.api.v1.retryable_client import RetryableClient, ConflictStrategy + + __all__ = [ "Client", + "RetryableClient", + "ConflictStrategy", # Core "AlgebraicSubjectSet", "ContextualizedCaveat", diff --git a/authzed/api/v1/retryable_client.py b/authzed/api/v1/retryable_client.py new file mode 100644 index 0000000..9316e22 --- /dev/null +++ b/authzed/api/v1/retryable_client.py @@ -0,0 +1,275 @@ +import asyncio +import enum +import time +from typing import List, Optional + +import grpc +from google.rpc import code_pb2 +from grpc import StatusCode + +from authzed.api.v1 import Client, Relationship, RelationshipUpdate +from authzed.api.v1.experimental_service_pb2 import BulkImportRelationshipsRequest +from authzed.api.v1.permission_service_pb2 import WriteRelationshipsRequest + +# Default configuration +DEFAULT_BACKOFF_MS = 50 +DEFAULT_MAX_RETRIES = 10 +DEFAULT_MAX_BACKOFF_MS = 2000 +DEFAULT_TIMEOUT_SECONDS = 30 + + +class ConflictStrategy(enum.Enum): + """Strategy to handle conflicts during bulk relationship import.""" + FAIL = 0 # The operation will fail if any duplicate relationships are found + SKIP = 1 # The operation will ignore duplicates and continue with the import + TOUCH = 2 # The operation will retry with TOUCH semantics for duplicates + + +# Datastore error strings for older versions of SpiceDB +TX_CONFLICT_STRINGS = [ + "SQLSTATE 23505", # CockroachDB + "Error 1062 (23000)", # MySQL +] + +RETRYABLE_ERROR_STRINGS = [ + "retryable error", # CockroachDB, PostgreSQL + "try restarting transaction", "Error 1205", # MySQL +] + + +class RetryableClient(Client): + """ + A client for SpiceDB that adds retryable operations with support for + different conflict strategies. This client extends the base Client with + additional functionality for handling transaction conflicts. + """ + + def __init__(self, target, credentials, options=None, compression=None): + super().__init__(target, credentials, options, compression) + + def retryable_bulk_import_relationships( + self, + relationships: List[Relationship], + conflict_strategy: ConflictStrategy, + timeout_seconds: Optional[int] = None + ): + """ + Import relationships with configurable retry behavior based on conflict strategy. + + Args: + relationships: List of relationships to import + conflict_strategy: Strategy to use when conflicts are detected + timeout_seconds: Optional timeout in seconds for the operation + + Returns: + The response from the successful import operation + + Raises: + Exception: If the import fails and cannot be retried + """ + if asyncio.iscoroutinefunction(self.BulkImportRelationships): + return self._retryable_bulk_import_relationships_async( + relationships, conflict_strategy, timeout_seconds + ) + else: + return self._retryable_bulk_import_relationships_sync( + relationships, conflict_strategy, timeout_seconds + ) + + def _retryable_bulk_import_relationships_sync( + self, + relationships: List[Relationship], + conflict_strategy: ConflictStrategy, + timeout_seconds: Optional[int] = None + ): + """Synchronous implementation of retryable bulk import.""" + timeout = timeout_seconds or DEFAULT_TIMEOUT_SECONDS + + # Try bulk import first + writer = self.BulkImportRelationships(timeout=timeout) + request = BulkImportRelationshipsRequest(relationships=relationships) + + writer.send(request) + + try: + response = writer.done() + return response # Success on first try + except Exception as err: + # Handle errors based on type and conflict strategy + if self._is_canceled_error(err): + raise err + + if self._is_already_exists_error(err) and conflict_strategy == ConflictStrategy.SKIP: + return None # Skip conflicts + + if self._is_retryable_error(err) or ( + self._is_already_exists_error(err) and + conflict_strategy == ConflictStrategy.TOUCH + ): + # Retry with write_relationships_with_retry + return self._write_batches_with_retry_sync(relationships, timeout) + + if self._is_already_exists_error(err) and conflict_strategy == ConflictStrategy.FAIL: + raise ValueError("Duplicate relationships found") + + # Default case - propagate the error + raise ValueError(f"Error finalizing write of {len(relationships)} relationships: {err}") + + async def _retryable_bulk_import_relationships_async( + self, + relationships: List[Relationship], + conflict_strategy: ConflictStrategy, + timeout_seconds: Optional[int] = None + ): + """Asynchronous implementation of retryable bulk import.""" + timeout = timeout_seconds or DEFAULT_TIMEOUT_SECONDS + + # Try bulk import first + writer = await self.BulkImportRelationships(timeout=timeout) + request = BulkImportRelationshipsRequest(relationships=relationships) + + await writer.write(request) + + try: + response = await writer.done_writing() + return response # Success on first try + except Exception as err: + # Handle errors based on type and conflict strategy + if self._is_canceled_error(err): + raise err + + if self._is_already_exists_error(err) and conflict_strategy == ConflictStrategy.SKIP: + return None # Skip conflicts + + if self._is_retryable_error(err) or ( + self._is_already_exists_error(err) and + conflict_strategy == ConflictStrategy.TOUCH + ): + # Retry with write_relationships_with_retry + return await self._write_batches_with_retry_async(relationships, timeout) + + if self._is_already_exists_error(err) and conflict_strategy == ConflictStrategy.FAIL: + raise ValueError("Duplicate relationships found") + + # Default case - propagate the error + raise ValueError(f"Error finalizing write of {len(relationships)} relationships: {err}") + + def _write_batches_with_retry_sync(self, relationships: List[Relationship], timeout_seconds: int): + """ + Retry writing relationships in batches with exponential backoff. + This is a synchronous implementation. + """ + updates = [ + RelationshipUpdate( + relationship=rel, + operation=RelationshipUpdate.OPERATION_TOUCH + ) + for rel in relationships + ] + + backoff_ms = DEFAULT_BACKOFF_MS + current_retries = 0 + + while True: + try: + request = WriteRelationshipsRequest(updates=updates) + response = self.WriteRelationships(request, timeout=timeout_seconds) + return response + except Exception as err: + if self._is_retryable_error(err) and current_retries < DEFAULT_MAX_RETRIES: + # Throttle writes with exponential backoff + time.sleep(backoff_ms / 1000) + backoff_ms = min(backoff_ms * 2, DEFAULT_MAX_BACKOFF_MS) + current_retries += 1 + continue + + # Non-retryable error or max retries exceeded + raise ValueError(f"Failed to write relationships after retry: {err}") + + async def _write_batches_with_retry_async(self, relationships: List[Relationship], timeout_seconds: int): + """ + Retry writing relationships in batches with exponential backoff. + This is an asynchronous implementation. + """ + updates = [ + RelationshipUpdate( + relationship=rel, + operation=RelationshipUpdate.OPERATION_TOUCH + ) + for rel in relationships + ] + + backoff_ms = DEFAULT_BACKOFF_MS + current_retries = 0 + + while True: + try: + request = WriteRelationshipsRequest(updates=updates) + response = await self.WriteRelationships(request, timeout=timeout_seconds) + return response + except Exception as err: + if self._is_retryable_error(err) and current_retries < DEFAULT_MAX_RETRIES: + # Throttle writes with exponential backoff + await asyncio.sleep(backoff_ms / 1000) + backoff_ms = min(backoff_ms * 2, DEFAULT_MAX_BACKOFF_MS) + current_retries += 1 + continue + + # Non-retryable error or max retries exceeded + raise ValueError(f"Failed to write relationships after retry: {err}") + + def _is_already_exists_error(self, err): + """Check if the error is an 'already exists' error.""" + if err is None: + return False + + if self._is_grpc_code(err, StatusCode.ALREADY_EXISTS): + return True + + return self._contains_error_string(err, TX_CONFLICT_STRINGS) + + def _is_retryable_error(self, err): + """Check if the error is retryable.""" + if err is None: + return False + + if self._is_grpc_code(err, StatusCode.UNAVAILABLE, StatusCode.DEADLINE_EXCEEDED): + return True + + if self._contains_error_string(err, RETRYABLE_ERROR_STRINGS): + return True + + return isinstance(err, asyncio.TimeoutError) or isinstance(getattr(err, "__cause__", None), asyncio.TimeoutError) + + def _is_canceled_error(self, err): + """Check if the error is a cancellation error.""" + if err is None: + return False + + if isinstance(err, asyncio.CancelledError): + return True + + if self._is_grpc_code(err, StatusCode.CANCELLED): + return True + + return False + + def _contains_error_string(self, err, error_strings): + """Check if the error message contains any of the given strings.""" + if err is None: + return False + + err_str = str(err) + return any(es in err_str for es in error_strings) + + def _is_grpc_code(self, err, *codes): + """Check if the error is a gRPC error with one of the given status codes.""" + if err is None: + return False + + try: + status = grpc.StatusCode(err.code()) + return status in codes + except (ValueError, AttributeError): + # If we can't extract a gRPC status code, it's not a gRPC error + return False \ No newline at end of file diff --git a/examples/v1/retryable_import_relationships.py b/examples/v1/retryable_import_relationships.py new file mode 100644 index 0000000..368ff0a --- /dev/null +++ b/examples/v1/retryable_import_relationships.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 + +import os +import sys +from typing import List + +import grpc + +from authzed.api.v1 import ( + ConflictStrategy, + ObjectReference, + Relationship, + RetryableClient, + SubjectReference, +) + +# Environment variables for configuration +AUTHZED_ENDPOINT = os.getenv("AUTHZED_ENDPOINT", "grpc.authzed.com:443") +AUTHZED_TOKEN = os.getenv("AUTHZED_TOKEN", "") + + +def create_sample_relationships() -> List[Relationship]: + """Create a list of sample relationships for import.""" + relationships = [] + + # Create 5 documents, each with view permissions for 2 users + for i in range(5): + doc_id = f"doc{i}" + for j in range(2): + user_id = f"user{j}" + + # Create a relationship where user can view document + rel = Relationship( + resource=ObjectReference( + object_type="document", + object_id=doc_id, + ), + relation="viewer", + subject=SubjectReference( + object=ObjectReference( + object_type="user", + object_id=user_id, + ), + ), + ) + relationships.append(rel) + + return relationships + + +def main(): + """ + Demonstrate usage of the RetryableClient for importing relationships + with different conflict strategies. + """ + if not AUTHZED_TOKEN: + print("Error: AUTHZED_TOKEN environment variable is required") + sys.exit(1) + + # Create channel credentials + channel_creds = grpc.ssl_channel_credentials() + + # Create RetryableClient + client = RetryableClient( + AUTHZED_ENDPOINT, + grpc.composite_channel_credentials( + channel_creds, + grpc.access_token_call_credentials(AUTHZED_TOKEN), + ), + ) + + # Create sample relationships + relationships = create_sample_relationships() + print(f"Created {len(relationships)} sample relationships") + + # Import relationships with TOUCH conflict strategy + print("Importing relationships with TOUCH conflict strategy...") + try: + client.retryable_bulk_import_relationships( + relationships=relationships, + conflict_strategy=ConflictStrategy.TOUCH, + ) + print("Import successful!") + except Exception as e: + print(f"Import failed: {e}") + + # Try to import the same relationships again, but with SKIP strategy + print("\nImporting the same relationships again with SKIP conflict strategy...") + try: + client.retryable_bulk_import_relationships( + relationships=relationships, + conflict_strategy=ConflictStrategy.SKIP, + ) + print("Import successful (skipped duplicates)!") + except Exception as e: + print(f"Import failed: {e}") + + # Try to import the same relationships again, but with FAIL strategy + print("\nImporting the same relationships again with FAIL conflict strategy...") + try: + client.retryable_bulk_import_relationships( + relationships=relationships, + conflict_strategy=ConflictStrategy.FAIL, + ) + print("Import successful!") + except Exception as e: + print(f"Import failed as expected: {e}") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/v1/retryable_import_relationships_async.py b/examples/v1/retryable_import_relationships_async.py new file mode 100644 index 0000000..1c7ad4b --- /dev/null +++ b/examples/v1/retryable_import_relationships_async.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 + +import asyncio +import os +import sys +from typing import List + +import grpc +import grpc.aio + +from authzed.api.v1 import ( + ConflictStrategy, + ObjectReference, + Relationship, + RetryableClient, + SubjectReference, +) + +# Environment variables for configuration +AUTHZED_ENDPOINT = os.getenv("AUTHZED_ENDPOINT", "grpc.authzed.com:443") +AUTHZED_TOKEN = os.getenv("AUTHZED_TOKEN", "") + + +def create_sample_relationships() -> List[Relationship]: + """Create a list of sample relationships for import.""" + relationships = [] + + # Create 5 documents, each with view permissions for 2 users + for i in range(5): + doc_id = f"doc{i}" + for j in range(2): + user_id = f"user{j}" + + # Create a relationship where user can view document + rel = Relationship( + resource=ObjectReference( + object_type="document", + object_id=doc_id, + ), + relation="viewer", + subject=SubjectReference( + object=ObjectReference( + object_type="user", + object_id=user_id, + ), + ), + ) + relationships.append(rel) + + return relationships + + +async def main(): + """ + Demonstrate usage of the RetryableClient for importing relationships + with different conflict strategies using async/await. + """ + if not AUTHZED_TOKEN: + print("Error: AUTHZED_TOKEN environment variable is required") + sys.exit(1) + + # Create channel credentials + channel_creds = grpc.ssl_channel_credentials() + + # Create RetryableClient + client = RetryableClient( + AUTHZED_ENDPOINT, + grpc.composite_channel_credentials( + channel_creds, + grpc.access_token_call_credentials(AUTHZED_TOKEN), + ), + ) + + # Create sample relationships + relationships = create_sample_relationships() + print(f"Created {len(relationships)} sample relationships") + + # Import relationships with TOUCH conflict strategy + print("Importing relationships with TOUCH conflict strategy...") + try: + await client.retryable_bulk_import_relationships( + relationships=relationships, + conflict_strategy=ConflictStrategy.TOUCH, + ) + print("Import successful!") + except Exception as e: + print(f"Import failed: {e}") + + # Try to import the same relationships again, but with SKIP strategy + print("\nImporting the same relationships again with SKIP conflict strategy...") + try: + await client.retryable_bulk_import_relationships( + relationships=relationships, + conflict_strategy=ConflictStrategy.SKIP, + ) + print("Import successful (skipped duplicates)!") + except Exception as e: + print(f"Import failed: {e}") + + # Try to import the same relationships again, but with FAIL strategy + print("\nImporting the same relationships again with FAIL conflict strategy...") + try: + await client.retryable_bulk_import_relationships( + relationships=relationships, + conflict_strategy=ConflictStrategy.FAIL, + ) + print("Import successful!") + except Exception as e: + print(f"Import failed as expected: {e}") + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/tests/retryable_client_test.py b/tests/retryable_client_test.py new file mode 100644 index 0000000..6c2c858 --- /dev/null +++ b/tests/retryable_client_test.py @@ -0,0 +1,359 @@ +import asyncio +import uuid +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import grpc +import pytest +from google.protobuf.empty_pb2 import Empty + +from authzed.api.v1 import ( + ConflictStrategy, + ObjectReference, + Relationship, + RelationshipUpdate, + RetryableClient, + SubjectReference, + WriteRelationshipsResponse, +) +from authzed.api.v1.experimental_service_pb2 import BulkImportRelationshipsResponse +from grpcutil import insecure_bearer_token_credentials +from inspect import isawaitable + + +async def maybe_await(resp): + """Helper function to handle both sync and async responses.""" + if isawaitable(resp): + resp = await resp + return resp + + +# Create fixtures for mocked sync and async clients +@pytest.fixture() +def sync_retryable_client(token) -> RetryableClient: + with patch.object(RetryableClient, 'create_channel') as mock_create_channel: + mock_channel = Mock() + mock_create_channel.return_value = mock_channel + client = RetryableClient("localhost:50051", insecure_bearer_token_credentials(token)) + + # Mock all the key methods we'll use in testing + client.BulkImportRelationships = Mock() + client.WriteRelationships = Mock(return_value=WriteRelationshipsResponse()) + + return client + + +@pytest.fixture() +async def async_retryable_client(token) -> RetryableClient: + with patch.object(RetryableClient, 'create_channel') as mock_create_channel: + mock_channel = Mock() + mock_create_channel.return_value = mock_channel + client = RetryableClient("localhost:50051", insecure_bearer_token_credentials(token)) + + # Mock all the key methods we'll use in testing + client.BulkImportRelationships = AsyncMock() + client.WriteRelationships = AsyncMock(return_value=WriteRelationshipsResponse()) + + # Force async mode + client._is_async = True + + return client + + +@pytest.fixture(params=["sync", "async"]) +def retryable_client( + request, + sync_retryable_client: RetryableClient, + async_retryable_client: RetryableClient, +): + clients = { + "sync": sync_retryable_client, + "async": async_retryable_client, + } + return clients[request.param] + + +@pytest.fixture +def sample_relationships(): + """Return sample relationships for testing.""" + return [ + Relationship( + resource=ObjectReference(object_type="document", object_id="doc1"), + relation="viewer", + subject=SubjectReference( + object=ObjectReference(object_type="user", object_id="user1") + ), + ), + ] + + +@patch("asyncio.iscoroutinefunction") +async def test_successful_bulk_import(mock_is_coro, retryable_client, sample_relationships): + """Test that bulk import works without errors using mocks.""" + # Configure mocks + mock_is_coro.return_value = isinstance(retryable_client.WriteRelationships, AsyncMock) + + if isinstance(retryable_client.BulkImportRelationships, AsyncMock): + # For async client + mock_writer = AsyncMock() + mock_writer.write = AsyncMock() + mock_writer.done_writing = AsyncMock(return_value=BulkImportRelationshipsResponse()) + retryable_client.BulkImportRelationships.return_value = mock_writer + else: + # For sync client + mock_writer = Mock() + mock_writer.send = Mock() + mock_writer.done = Mock(return_value=BulkImportRelationshipsResponse()) + retryable_client.BulkImportRelationships.return_value = mock_writer + + # Import with TOUCH conflict strategy + result = await maybe_await( + retryable_client.retryable_bulk_import_relationships( + relationships=sample_relationships, + conflict_strategy=ConflictStrategy.TOUCH, + ) + ) + + # Verify the expected methods were called + assert retryable_client.BulkImportRelationships.called + + # If we get here without errors, the test passes + if isinstance(retryable_client.BulkImportRelationships, AsyncMock): + assert mock_writer.write.called + assert mock_writer.done_writing.called + else: + assert mock_writer.send.called + assert mock_writer.done.called + + +@patch("authzed.api.v1.retryable_client.RetryableClient._is_already_exists_error") +@patch("asyncio.iscoroutinefunction") +async def test_skip_conflict_strategy(mock_is_coro, mock_already_exists, retryable_client, sample_relationships): + """Test that SKIP strategy works as expected.""" + # Configure mock behaviors + mock_already_exists.return_value = True + mock_is_coro.return_value = isinstance(retryable_client.WriteRelationships, AsyncMock) + + # Create a mock writer that raises an exception + mock_bulk_writer = Mock() + if isinstance(retryable_client.BulkImportRelationships, AsyncMock): + mock_bulk_writer = AsyncMock() + mock_bulk_writer.write = AsyncMock() + mock_bulk_writer.done_writing = AsyncMock(side_effect=grpc.RpcError("Already exists")) + else: + mock_bulk_writer.send = Mock() + mock_bulk_writer.done = Mock(side_effect=grpc.RpcError("Already exists")) + + # Set up the mock + retryable_client.BulkImportRelationships.return_value = mock_bulk_writer + + # Test import with SKIP conflict strategy + result = await maybe_await( + retryable_client.retryable_bulk_import_relationships( + relationships=sample_relationships, + conflict_strategy=ConflictStrategy.SKIP, + ) + ) + + # Should return None (indicating skipped) + assert result is None + + +@patch("authzed.api.v1.retryable_client.RetryableClient._is_already_exists_error") +@patch("authzed.api.v1.retryable_client.RetryableClient._write_batches_with_retry_sync") +@patch("authzed.api.v1.retryable_client.RetryableClient._write_batches_with_retry_async") +@patch("asyncio.iscoroutinefunction") +async def test_touch_conflict_strategy( + mock_is_coro, + mock_write_async, + mock_write_sync, + mock_already_exists, + retryable_client, + sample_relationships +): + """Test that TOUCH strategy calls the correct retry method.""" + # Configure mock behaviors + mock_already_exists.return_value = True + mock_write_sync.return_value = WriteRelationshipsResponse() + mock_write_async.return_value = WriteRelationshipsResponse() + mock_is_coro.return_value = isinstance(retryable_client.WriteRelationships, AsyncMock) + + # Create a mock writer that raises an exception + mock_bulk_writer = Mock() + if isinstance(retryable_client.BulkImportRelationships, AsyncMock): + mock_bulk_writer = AsyncMock() + mock_bulk_writer.write = AsyncMock() + mock_bulk_writer.done_writing = AsyncMock(side_effect=grpc.RpcError("Already exists")) + else: + mock_bulk_writer.send = Mock() + mock_bulk_writer.done = Mock(side_effect=grpc.RpcError("Already exists")) + + # Set up the mock + retryable_client.BulkImportRelationships.return_value = mock_bulk_writer + + # Test import with TOUCH conflict strategy + await maybe_await( + retryable_client.retryable_bulk_import_relationships( + relationships=sample_relationships, + conflict_strategy=ConflictStrategy.TOUCH, + ) + ) + + # Verify the correct retry method was called + if isinstance(retryable_client.BulkImportRelationships, AsyncMock): + assert mock_write_async.called + assert not mock_write_sync.called + else: + assert not mock_write_async.called + assert mock_write_sync.called + + +@patch("authzed.api.v1.retryable_client.RetryableClient._is_already_exists_error") +@patch("asyncio.iscoroutinefunction") +async def test_fail_conflict_strategy(mock_is_coro, mock_already_exists, retryable_client, sample_relationships): + """Test that FAIL strategy raises an error when conflicts occur.""" + # Configure mock behaviors + mock_already_exists.return_value = True + mock_is_coro.return_value = isinstance(retryable_client.WriteRelationships, AsyncMock) + + # Create a mock writer that raises an exception + mock_bulk_writer = Mock() + if isinstance(retryable_client.BulkImportRelationships, AsyncMock): + mock_bulk_writer = AsyncMock() + mock_bulk_writer.write = AsyncMock() + mock_bulk_writer.done_writing = AsyncMock(side_effect=grpc.RpcError("Already exists")) + else: + mock_bulk_writer.send = Mock() + mock_bulk_writer.done = Mock(side_effect=grpc.RpcError("Already exists")) + + # Set up the mock + retryable_client.BulkImportRelationships.return_value = mock_bulk_writer + + # Test import with FAIL conflict strategy + with pytest.raises(ValueError, match="Duplicate relationships found"): + await maybe_await( + retryable_client.retryable_bulk_import_relationships( + relationships=sample_relationships, + conflict_strategy=ConflictStrategy.FAIL, + ) + ) + + +@patch("authzed.api.v1.retryable_client.RetryableClient._is_retryable_error") +@patch("time.sleep") +async def test_retry_with_backoff_sync(mock_sleep, mock_retryable, retryable_client, sample_relationships): + """Test retrying with exponential backoff.""" + # Only run this test for sync client + if isinstance(retryable_client.WriteRelationships, AsyncMock): + pytest.skip("This test is for sync client only") + + # Setup for a retryable error that succeeds after 2 attempts + mock_retryable.side_effect = [True, True, False] + + # Create a mock that raises errors for the first two calls, then succeeds + mock_write = Mock() + mock_write.side_effect = [ + grpc.RpcError("Retryable error"), + grpc.RpcError("Retryable error"), + WriteRelationshipsResponse(), + ] + + # Apply the mock + retryable_client.WriteRelationships = mock_write + + # Test the retry mechanism + result = retryable_client._write_batches_with_retry_sync(sample_relationships, 30) + + # Verify the correct number of retries occurred + assert mock_write.call_count == 3 + assert mock_sleep.call_count == 2 + + # Verify backoff increased + assert mock_sleep.call_args_list[0][0][0] < mock_sleep.call_args_list[1][0][0] + + +@patch("authzed.api.v1.retryable_client.RetryableClient._is_retryable_error") +@patch("asyncio.sleep") +async def test_retry_with_backoff_async(mock_sleep, mock_retryable, retryable_client, sample_relationships): + """Test retrying with exponential backoff (async version).""" + # Only run this test for async client + if not isinstance(retryable_client.WriteRelationships, AsyncMock): + pytest.skip("This test is for async client only") + + # Setup for a retryable error that succeeds after 2 attempts + mock_retryable.side_effect = [True, True, False] + + # Create a mock that raises errors for the first two calls, then succeeds + mock_write = AsyncMock() + mock_write.side_effect = [ + grpc.RpcError("Retryable error"), + grpc.RpcError("Retryable error"), + WriteRelationshipsResponse(), + ] + + # Apply the mock + retryable_client.WriteRelationships = mock_write + + # Test the retry mechanism + result = await retryable_client._write_batches_with_retry_async(sample_relationships, 30) + + # Verify the correct number of retries occurred + assert mock_write.call_count == 3 + assert mock_sleep.call_count == 2 + + # Verify backoff increased + assert mock_sleep.call_args_list[0][0][0] < mock_sleep.call_args_list[1][0][0] + + +@patch("authzed.api.v1.retryable_client.RetryableClient._is_retryable_error") +async def test_max_retries_exceeded(mock_retryable, retryable_client, sample_relationships): + """Test that retry count is limited.""" + # Always report retryable error + mock_retryable.return_value = True + + # Create a mock that always raises errors + if isinstance(retryable_client.WriteRelationships, AsyncMock): + mock_write = AsyncMock(side_effect=grpc.RpcError("Always retryable error")) + retryable_client.WriteRelationships = mock_write + + # Should eventually fail after max retries + with pytest.raises(ValueError, match="Failed to write relationships after retry"): + await retryable_client._write_batches_with_retry_async(sample_relationships, 30) + else: + mock_write = Mock(side_effect=grpc.RpcError("Always retryable error")) + retryable_client.WriteRelationships = mock_write + + # Should eventually fail after max retries + with pytest.raises(ValueError, match="Failed to write relationships after retry"): + retryable_client._write_batches_with_retry_sync(sample_relationships, 30) + + # Verify retry count (DEFAULT_MAX_RETRIES + 1) + assert mock_write.call_count == 11 # 10 retries + 1 initial attempt + + +def test_error_detection_methods(): + """Test the error classification methods.""" + client = RetryableClient("localhost:50051", insecure_bearer_token_credentials("token")) + + # Test _is_already_exists_error + already_exists_error = grpc.RpcError() + already_exists_error.code = lambda: grpc.StatusCode.ALREADY_EXISTS + assert client._is_already_exists_error(already_exists_error) is True + + sql_error = Exception("SQLSTATE 23505 duplicate key value violates constraint") + assert client._is_already_exists_error(sql_error) is True + + # Test _is_retryable_error + unavailable_error = grpc.RpcError() + unavailable_error.code = lambda: grpc.StatusCode.UNAVAILABLE + assert client._is_retryable_error(unavailable_error) is True + + retry_error = Exception("retryable error: restart transaction") + assert client._is_retryable_error(retry_error) is True + + # Test _is_canceled_error + canceled_error = grpc.RpcError() + canceled_error.code = lambda: grpc.StatusCode.CANCELLED + assert client._is_canceled_error(canceled_error) is True + + timeout_error = asyncio.CancelledError() + assert client._is_canceled_error(timeout_error) is True \ No newline at end of file From a9a824ddfd692477cb982cf4b96bb2330662954c Mon Sep 17 00:00:00 2001 From: Kyle Brown Date: Mon, 10 Mar 2025 12:59:24 -0600 Subject: [PATCH 2/6] updated sync and async to correctly pass iterators --- authzed/api/v1/retryable_client.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/authzed/api/v1/retryable_client.py b/authzed/api/v1/retryable_client.py index 9316e22..7154e8a 100644 --- a/authzed/api/v1/retryable_client.py +++ b/authzed/api/v1/retryable_client.py @@ -85,14 +85,13 @@ def _retryable_bulk_import_relationships_sync( """Synchronous implementation of retryable bulk import.""" timeout = timeout_seconds or DEFAULT_TIMEOUT_SECONDS - # Try bulk import first - writer = self.BulkImportRelationships(timeout=timeout) - request = BulkImportRelationshipsRequest(relationships=relationships) - - writer.send(request) + # Create a generator function to yield requests + def request_iterator(): + yield BulkImportRelationshipsRequest(relationships=relationships) + # Try bulk import first - correctly passing the request iterator try: - response = writer.done() + response = self.BulkImportRelationships(request_iterator(), timeout=timeout) return response # Success on first try except Exception as err: # Handle errors based on type and conflict strategy @@ -124,14 +123,13 @@ async def _retryable_bulk_import_relationships_async( """Asynchronous implementation of retryable bulk import.""" timeout = timeout_seconds or DEFAULT_TIMEOUT_SECONDS - # Try bulk import first - writer = await self.BulkImportRelationships(timeout=timeout) - request = BulkImportRelationshipsRequest(relationships=relationships) - - await writer.write(request) + # Create an async generator function to yield requests + async def request_iterator(): + yield BulkImportRelationshipsRequest(relationships=relationships) + # Try bulk import first - correctly passing the request iterator try: - response = await writer.done_writing() + response = await self.BulkImportRelationships(request_iterator(), timeout=timeout) return response # Success on first try except Exception as err: # Handle errors based on type and conflict strategy From 09754e1000d4210c335d959b9f239af32755c183 Mon Sep 17 00:00:00 2001 From: Kyle Brown Date: Mon, 10 Mar 2025 13:18:25 -0600 Subject: [PATCH 3/6] update test file for retryable client --- tests/retryable_client_test.py | 62 +++++++++------------------------- 1 file changed, 16 insertions(+), 46 deletions(-) diff --git a/tests/retryable_client_test.py b/tests/retryable_client_test.py index 6c2c858..59a5090 100644 --- a/tests/retryable_client_test.py +++ b/tests/retryable_client_test.py @@ -92,18 +92,11 @@ async def test_successful_bulk_import(mock_is_coro, retryable_client, sample_rel # Configure mocks mock_is_coro.return_value = isinstance(retryable_client.WriteRelationships, AsyncMock) + # For both sync and async, simply return the response directly if isinstance(retryable_client.BulkImportRelationships, AsyncMock): - # For async client - mock_writer = AsyncMock() - mock_writer.write = AsyncMock() - mock_writer.done_writing = AsyncMock(return_value=BulkImportRelationshipsResponse()) - retryable_client.BulkImportRelationships.return_value = mock_writer + retryable_client.BulkImportRelationships.return_value = BulkImportRelationshipsResponse() else: - # For sync client - mock_writer = Mock() - mock_writer.send = Mock() - mock_writer.done = Mock(return_value=BulkImportRelationshipsResponse()) - retryable_client.BulkImportRelationships.return_value = mock_writer + retryable_client.BulkImportRelationships.return_value = BulkImportRelationshipsResponse() # Import with TOUCH conflict strategy result = await maybe_await( @@ -117,12 +110,7 @@ async def test_successful_bulk_import(mock_is_coro, retryable_client, sample_rel assert retryable_client.BulkImportRelationships.called # If we get here without errors, the test passes - if isinstance(retryable_client.BulkImportRelationships, AsyncMock): - assert mock_writer.write.called - assert mock_writer.done_writing.called - else: - assert mock_writer.send.called - assert mock_writer.done.called + assert result is not None @patch("authzed.api.v1.retryable_client.RetryableClient._is_already_exists_error") @@ -133,18 +121,12 @@ async def test_skip_conflict_strategy(mock_is_coro, mock_already_exists, retryab mock_already_exists.return_value = True mock_is_coro.return_value = isinstance(retryable_client.WriteRelationships, AsyncMock) - # Create a mock writer that raises an exception - mock_bulk_writer = Mock() + # Set up the mock to raise an error when called + error = grpc.RpcError("Already exists") if isinstance(retryable_client.BulkImportRelationships, AsyncMock): - mock_bulk_writer = AsyncMock() - mock_bulk_writer.write = AsyncMock() - mock_bulk_writer.done_writing = AsyncMock(side_effect=grpc.RpcError("Already exists")) + retryable_client.BulkImportRelationships.side_effect = error else: - mock_bulk_writer.send = Mock() - mock_bulk_writer.done = Mock(side_effect=grpc.RpcError("Already exists")) - - # Set up the mock - retryable_client.BulkImportRelationships.return_value = mock_bulk_writer + retryable_client.BulkImportRelationships.side_effect = error # Test import with SKIP conflict strategy result = await maybe_await( @@ -177,18 +159,12 @@ async def test_touch_conflict_strategy( mock_write_async.return_value = WriteRelationshipsResponse() mock_is_coro.return_value = isinstance(retryable_client.WriteRelationships, AsyncMock) - # Create a mock writer that raises an exception - mock_bulk_writer = Mock() + # Set up the mock to raise an error when called + error = grpc.RpcError("Already exists") if isinstance(retryable_client.BulkImportRelationships, AsyncMock): - mock_bulk_writer = AsyncMock() - mock_bulk_writer.write = AsyncMock() - mock_bulk_writer.done_writing = AsyncMock(side_effect=grpc.RpcError("Already exists")) + retryable_client.BulkImportRelationships.side_effect = error else: - mock_bulk_writer.send = Mock() - mock_bulk_writer.done = Mock(side_effect=grpc.RpcError("Already exists")) - - # Set up the mock - retryable_client.BulkImportRelationships.return_value = mock_bulk_writer + retryable_client.BulkImportRelationships.side_effect = error # Test import with TOUCH conflict strategy await maybe_await( @@ -215,18 +191,12 @@ async def test_fail_conflict_strategy(mock_is_coro, mock_already_exists, retryab mock_already_exists.return_value = True mock_is_coro.return_value = isinstance(retryable_client.WriteRelationships, AsyncMock) - # Create a mock writer that raises an exception - mock_bulk_writer = Mock() + # Set up the mock to raise an error when called + error = grpc.RpcError("Already exists") if isinstance(retryable_client.BulkImportRelationships, AsyncMock): - mock_bulk_writer = AsyncMock() - mock_bulk_writer.write = AsyncMock() - mock_bulk_writer.done_writing = AsyncMock(side_effect=grpc.RpcError("Already exists")) + retryable_client.BulkImportRelationships.side_effect = error else: - mock_bulk_writer.send = Mock() - mock_bulk_writer.done = Mock(side_effect=grpc.RpcError("Already exists")) - - # Set up the mock - retryable_client.BulkImportRelationships.return_value = mock_bulk_writer + retryable_client.BulkImportRelationships.side_effect = error # Test import with FAIL conflict strategy with pytest.raises(ValueError, match="Duplicate relationships found"): From 911d5337a0413d4beb11bfa0251e0fcdfc98da6d Mon Sep 17 00:00:00 2001 From: Kyle Brown Date: Wed, 12 Mar 2025 17:58:36 -0600 Subject: [PATCH 4/6] Update authzed/api/v1/retryable_client.py fix error Co-authored-by: Tanner Stirrat --- authzed/api/v1/retryable_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/authzed/api/v1/retryable_client.py b/authzed/api/v1/retryable_client.py index 7154e8a..70e1c80 100644 --- a/authzed/api/v1/retryable_client.py +++ b/authzed/api/v1/retryable_client.py @@ -150,7 +150,7 @@ async def request_iterator(): raise ValueError("Duplicate relationships found") # Default case - propagate the error - raise ValueError(f"Error finalizing write of {len(relationships)} relationships: {err}") + raise ValueError(f"Error finalizing write of {len(relationships)} relationships") from err def _write_batches_with_retry_sync(self, relationships: List[Relationship], timeout_seconds: int): """ From 7aeb7563c089448550ab5730e6edfe2faa4df739 Mon Sep 17 00:00:00 2001 From: Kyle Brown Date: Thu, 13 Mar 2025 10:20:54 -0600 Subject: [PATCH 5/6] refactor: move client implementations to client.py Signed-off-by: Kyle Brown --- authzed/api/v1/__init__.py | 128 +++------------------- authzed/api/v1/client.py | 164 +++++++++++++++++++++++++++++ authzed/api/v1/retryable_client.py | 3 +- 3 files changed, 181 insertions(+), 114 deletions(-) create mode 100644 authzed/api/v1/client.py diff --git a/authzed/api/v1/__init__.py b/authzed/api/v1/__init__.py index ee9a862..49bbe1c 100644 --- a/authzed/api/v1/__init__.py +++ b/authzed/api/v1/__init__.py @@ -1,10 +1,4 @@ -import asyncio -from typing import Any, Callable - -import grpc -import grpc.aio -from grpc_interceptor import ClientCallDetails, ClientInterceptor - +# Import core types from protocol buffer modules from authzed.api.v1.core_pb2 import ( AlgebraicSubjectSet, ContextualizedCaveat, @@ -29,7 +23,6 @@ BulkImportRelationshipsRequest, BulkImportRelationshipsResponse, ) -from authzed.api.v1.experimental_service_pb2_grpc import ExperimentalServiceStub from authzed.api.v1.permission_service_pb2 import ( CheckBulkPermissionsPair, CheckBulkPermissionsRequest, @@ -55,121 +48,31 @@ WriteRelationshipsRequest, WriteRelationshipsResponse, ) -from authzed.api.v1.permission_service_pb2_grpc import PermissionsServiceStub from authzed.api.v1.schema_service_pb2 import ( ReadSchemaRequest, ReadSchemaResponse, WriteSchemaRequest, WriteSchemaResponse, ) -from authzed.api.v1.schema_service_pb2_grpc import SchemaServiceStub from authzed.api.v1.watch_service_pb2 import WatchRequest, WatchResponse -from authzed.api.v1.watch_service_pb2_grpc import WatchServiceStub - - -class Client(SchemaServiceStub, PermissionsServiceStub, ExperimentalServiceStub, WatchServiceStub): - """ - v1 Authzed gRPC API client - Auto-detects sync or async depending on if initialized within an event loop - """ - - def __init__(self, target, credentials, options=None, compression=None): - channel = self.create_channel(target, credentials, options, compression) - self.init_stubs(channel) - - def init_stubs(self, channel): - SchemaServiceStub.__init__(self, channel) - PermissionsServiceStub.__init__(self, channel) - ExperimentalServiceStub.__init__(self, channel) - WatchServiceStub.__init__(self, channel) - - def create_channel(self, target, credentials, options=None, compression=None): - try: - asyncio.get_running_loop() - channelfn = grpc.aio.secure_channel - except RuntimeError: - channelfn = grpc.secure_channel - - return channelfn(target, credentials, options, compression) - - -class AsyncClient(Client): - """ - v1 Authzed gRPC API client, for use with asyncio. - """ - - def __init__(self, target, credentials, options=None, compression=None): - channel = grpc.aio.secure_channel(target, credentials, options, compression) - self.init_stubs(channel) - - -class SyncClient(Client): - """ - v1 Authzed gRPC API client, running synchronously. - """ - - def __init__(self, target, credentials, options=None, compression=None): - channel = grpc.secure_channel(target, credentials, options, compression) - self.init_stubs(channel) - -class TokenAuthorization(ClientInterceptor): - def __init__(self, token: str): - self._token = token - - def intercept( - self, - method: Callable, - request_or_iterator: Any, - call_details: grpc.ClientCallDetails, - ): - metadata: list[tuple[str, str | bytes]] = [("authorization", f"Bearer {self._token}")] - if call_details.metadata is not None: - metadata = [*metadata, *call_details.metadata] - - new_details = ClientCallDetails( - call_details.method, - call_details.timeout, - metadata, - call_details.credentials, - call_details.wait_for_ready, - call_details.compression, - ) - - return method(request_or_iterator, new_details) - - -class InsecureClient(Client): - """ - An insecure client variant for non-TLS contexts. - - The default behavior of the python gRPC client is to restrict non-TLS - calls to `localhost` only, which is frustrating in contexts like docker-compose, - so we provide this as a convenience. - """ - - def __init__( - self, - target: str, - token: str, - options=None, - compression=None, - ): - fake_credentials = grpc.local_channel_credentials() - channel = self.create_channel(target, fake_credentials, options, compression) - auth_interceptor = TokenAuthorization(token) - - insecure_channel = grpc.insecure_channel(target, options, compression) - channel = grpc.intercept_channel(insecure_channel, auth_interceptor) - - self.init_stubs(channel) - - -# Import after defining Client to avoid circular imports -from authzed.api.v1.retryable_client import RetryableClient, ConflictStrategy +# Import client implementations +from authzed.api.v1.client import ( + AsyncClient, + Client, + InsecureClient, + SyncClient, + TokenAuthorization, +) +from authzed.api.v1.retryable_client import ConflictStrategy, RetryableClient __all__ = [ "Client", + "AsyncClient", + "SyncClient", + "InsecureClient", + "TokenAuthorization", "RetryableClient", "ConflictStrategy", # Core @@ -198,7 +101,6 @@ def __init__( "DeleteRelationshipsResponse", "ExpandPermissionTreeRequest", "ExpandPermissionTreeResponse", - "InsecureClient", "LookupResourcesRequest", "LookupResourcesResponse", "LookupSubjectsRequest", @@ -228,4 +130,4 @@ def __init__( "BulkImportRelationshipsResponse", "BulkExportRelationshipsRequest", "BulkExportRelationshipsResponse", -] +] \ No newline at end of file diff --git a/authzed/api/v1/client.py b/authzed/api/v1/client.py new file mode 100644 index 0000000..9abbef3 --- /dev/null +++ b/authzed/api/v1/client.py @@ -0,0 +1,164 @@ +import asyncio +from typing import Any, Callable + +import grpc +import grpc.aio +from grpc_interceptor import ClientCallDetails, ClientInterceptor + +from authzed.api.v1.core_pb2 import ( + AlgebraicSubjectSet, + ContextualizedCaveat, + Cursor, + DirectSubjectSet, + ObjectReference, + PermissionRelationshipTree, + Relationship, + RelationshipUpdate, + SubjectReference, + ZedToken, +) +from authzed.api.v1.error_reason_pb2 import ErrorReason +from authzed.api.v1.experimental_service_pb2 import ( + BulkCheckPermissionPair, + BulkCheckPermissionRequest, + BulkCheckPermissionRequestItem, + BulkCheckPermissionResponse, + BulkCheckPermissionResponseItem, + BulkExportRelationshipsRequest, + BulkExportRelationshipsResponse, + BulkImportRelationshipsRequest, + BulkImportRelationshipsResponse, +) +from authzed.api.v1.experimental_service_pb2_grpc import ExperimentalServiceStub +from authzed.api.v1.permission_service_pb2 import ( + CheckBulkPermissionsPair, + CheckBulkPermissionsRequest, + CheckBulkPermissionsRequestItem, + CheckBulkPermissionsResponse, + CheckBulkPermissionsResponseItem, + CheckPermissionRequest, + CheckPermissionResponse, + Consistency, + DeleteRelationshipsRequest, + DeleteRelationshipsResponse, + ExpandPermissionTreeRequest, + ExpandPermissionTreeResponse, + LookupResourcesRequest, + LookupResourcesResponse, + LookupSubjectsRequest, + LookupSubjectsResponse, + Precondition, + ReadRelationshipsRequest, + ReadRelationshipsResponse, + RelationshipFilter, + SubjectFilter, + WriteRelationshipsRequest, + WriteRelationshipsResponse, +) +from authzed.api.v1.permission_service_pb2_grpc import PermissionsServiceStub +from authzed.api.v1.schema_service_pb2 import ( + ReadSchemaRequest, + ReadSchemaResponse, + WriteSchemaRequest, + WriteSchemaResponse, +) +from authzed.api.v1.schema_service_pb2_grpc import SchemaServiceStub +from authzed.api.v1.watch_service_pb2 import WatchRequest, WatchResponse +from authzed.api.v1.watch_service_pb2_grpc import WatchServiceStub + + +class Client(SchemaServiceStub, PermissionsServiceStub, ExperimentalServiceStub, WatchServiceStub): + """ + v1 Authzed gRPC API client - Auto-detects sync or async depending on if initialized within an event loop + """ + + def __init__(self, target, credentials, options=None, compression=None): + channel = self.create_channel(target, credentials, options, compression) + self.init_stubs(channel) + + def init_stubs(self, channel): + SchemaServiceStub.__init__(self, channel) + PermissionsServiceStub.__init__(self, channel) + ExperimentalServiceStub.__init__(self, channel) + WatchServiceStub.__init__(self, channel) + + def create_channel(self, target, credentials, options=None, compression=None): + try: + asyncio.get_running_loop() + channelfn = grpc.aio.secure_channel + except RuntimeError: + channelfn = grpc.secure_channel + + return channelfn(target, credentials, options, compression) + + +class AsyncClient(Client): + """ + v1 Authzed gRPC API client, for use with asyncio. + """ + + def __init__(self, target, credentials, options=None, compression=None): + channel = grpc.aio.secure_channel(target, credentials, options, compression) + self.init_stubs(channel) + + +class SyncClient(Client): + """ + v1 Authzed gRPC API client, running synchronously. + """ + + def __init__(self, target, credentials, options=None, compression=None): + channel = grpc.secure_channel(target, credentials, options, compression) + self.init_stubs(channel) + + +class TokenAuthorization(ClientInterceptor): + def __init__(self, token: str): + self._token = token + + def intercept( + self, + method: Callable, + request_or_iterator: Any, + call_details: grpc.ClientCallDetails, + ): + metadata: list[tuple[str, str | bytes]] = [("authorization", f"Bearer {self._token}")] + if call_details.metadata is not None: + metadata = [*metadata, *call_details.metadata] + + new_details = ClientCallDetails( + call_details.method, + call_details.timeout, + metadata, + call_details.credentials, + call_details.wait_for_ready, + call_details.compression, + ) + + return method(request_or_iterator, new_details) + + +class InsecureClient(Client): + """ + An insecure client variant for non-TLS contexts. + + The default behavior of the python gRPC client is to restrict non-TLS + calls to `localhost` only, which is frustrating in contexts like docker-compose, + so we provide this as a convenience. + """ + + def __init__( + self, + target: str, + token: str, + options=None, + compression=None, + ): + fake_credentials = grpc.local_channel_credentials() + channel = self.create_channel(target, fake_credentials, options, compression) + auth_interceptor = TokenAuthorization(token) + + insecure_channel = grpc.insecure_channel(target, options, compression) + channel = grpc.intercept_channel(insecure_channel, auth_interceptor) + + self.init_stubs(channel) \ No newline at end of file diff --git a/authzed/api/v1/retryable_client.py b/authzed/api/v1/retryable_client.py index 70e1c80..13a74c4 100644 --- a/authzed/api/v1/retryable_client.py +++ b/authzed/api/v1/retryable_client.py @@ -7,7 +7,8 @@ from google.rpc import code_pb2 from grpc import StatusCode -from authzed.api.v1 import Client, Relationship, RelationshipUpdate +from authzed.api.v1.client import Client +from authzed.api.v1.core_pb2 import Relationship, RelationshipUpdate from authzed.api.v1.experimental_service_pb2 import BulkImportRelationshipsRequest from authzed.api.v1.permission_service_pb2 import WriteRelationshipsRequest From b28dca47c84dcd0c369c1cb33520c13a9a272f12 Mon Sep 17 00:00:00 2001 From: Kyle Brown Date: Fri, 14 Mar 2025 15:37:48 -0600 Subject: [PATCH 6/6] lint type fix --- authzed/api/v1/retryable_client.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/authzed/api/v1/retryable_client.py b/authzed/api/v1/retryable_client.py index 13a74c4..5866524 100644 --- a/authzed/api/v1/retryable_client.py +++ b/authzed/api/v1/retryable_client.py @@ -130,8 +130,9 @@ async def request_iterator(): # Try bulk import first - correctly passing the request iterator try: - response = await self.BulkImportRelationships(request_iterator(), timeout=timeout) - return response # Success on first try + if asyncio.iscoroutinefunction(self.BulkImportRelationships): + response = await self.BulkImportRelationships(request_iterator(), timeout=timeout) + return response # Success on first try # Success on first try except Exception as err: # Handle errors based on type and conflict strategy if self._is_canceled_error(err): @@ -203,9 +204,10 @@ async def _write_batches_with_retry_async(self, relationships: List[Relationship while True: try: - request = WriteRelationshipsRequest(updates=updates) - response = await self.WriteRelationships(request, timeout=timeout_seconds) - return response + if asyncio.iscoroutinefunction(self.WriteRelationships): + request = WriteRelationshipsRequest(updates=updates) + response = await self.WriteRelationships(request, timeout=timeout_seconds) + return response except Exception as err: if self._is_retryable_error(err) and current_retries < DEFAULT_MAX_RETRIES: # Throttle writes with exponential backoff