Skip to content
15 changes: 9 additions & 6 deletions supabase/_async/client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import copy
import re
from typing import Any, Dict, List, Optional, Union

Expand Down Expand Up @@ -69,8 +70,9 @@ def __init__(

self.supabase_url = supabase_url
self.supabase_key = supabase_key
self.options = options
options.headers.update(self._get_auth_headers())
self.options = copy.deepcopy(options)
self.options.headers.update(self._get_auth_headers())

self.rest_url = f"{supabase_url}/rest/v1"
self.realtime_url = f"{supabase_url}/realtime/v1".replace("http", "ws")
self.auth_url = f"{supabase_url}/auth/v1"
Expand All @@ -80,12 +82,12 @@ def __init__(
# Instantiate clients.
self.auth = self._init_supabase_auth_client(
auth_url=self.auth_url,
client_options=options,
client_options=self.options,
)
self.realtime = self._init_realtime_client(
realtime_url=self.realtime_url,
supabase_key=self.supabase_key,
options=options.realtime if options else None,
options=self.options.realtime if self.options else None,
)
self._postgrest = None
self._storage = None
Expand Down Expand Up @@ -294,8 +296,9 @@ def _listen_to_auth_events(
self._storage = None
self._functions = None
access_token = session.access_token if session else self.supabase_key

self.options.headers["Authorization"] = self._create_auth_header(access_token)
auth_header = copy.deepcopy(self._create_auth_header(access_token))
self.options.headers["Authorization"] = auth_header
self.auth._headers["Authorization"] = auth_header
asyncio.create_task(self.realtime.set_auth(access_token))


Expand Down
13 changes: 8 additions & 5 deletions supabase/_sync/client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import re
from typing import Any, Dict, List, Optional, Union

Expand Down Expand Up @@ -68,8 +69,9 @@ def __init__(

self.supabase_url = supabase_url
self.supabase_key = supabase_key
self.options = options
options.headers.update(self._get_auth_headers())
self.options = copy.deepcopy(options)
self.options.headers.update(self._get_auth_headers())

self.rest_url = f"{supabase_url}/rest/v1"
self.realtime_url = f"{supabase_url}/realtime/v1".replace("http", "ws")
self.auth_url = f"{supabase_url}/auth/v1"
Expand All @@ -79,12 +81,12 @@ def __init__(
# Instantiate clients.
self.auth = self._init_supabase_auth_client(
auth_url=self.auth_url,
client_options=options,
client_options=self.options,
)
self.realtime = self._init_realtime_client(
realtime_url=self.realtime_url,
supabase_key=self.supabase_key,
options=options.realtime if options else None,
options=self.options.realtime if self.options else None,
)
self._postgrest = None
self._storage = None
Expand Down Expand Up @@ -293,8 +295,9 @@ def _listen_to_auth_events(
self._storage = None
self._functions = None
access_token = session.access_token if session else self.supabase_key
auth_header = copy.deepcopy(self._create_auth_header(access_token))

self.options.headers["Authorization"] = self._create_auth_header(access_token)
self.options.headers["Authorization"] = auth_header


def create_client(
Expand Down
2 changes: 1 addition & 1 deletion tests/_async/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
async def test_incorrect_values_dont_instantiate_client() -> None:
"""Ensure we can't instantiate client with invalid values."""
try:
client: AClient = create_async_client(None, None)
client: AClient = await create_async_client(None, None)
except ASupabaseException:
pass

Expand Down
34 changes: 31 additions & 3 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from unittest.mock import MagicMock

import pytest
from gotrue import SyncMemoryStorage

from supabase import Client, ClientOptions, SupabaseException, create_client

Expand Down Expand Up @@ -70,7 +71,7 @@ def test_supports_setting_a_global_authorization_header() -> None:
url = os.environ.get("SUPABASE_TEST_URL")
key = os.environ.get("SUPABASE_TEST_KEY")

authorization = f"Bearer secretuserjwt"
authorization = "Bearer secretuserjwt"

options = ClientOptions(headers={"Authorization": authorization})

Expand Down Expand Up @@ -101,7 +102,6 @@ def test_updates_the_authorization_header_on_auth_events() -> None:
mock_session = MagicMock(access_token="secretuserjwt")
realtime_mock = MagicMock()
client.realtime = realtime_mock

client._listen_to_auth_events("SIGNED_IN", mock_session)

updated_authorization = f"Bearer {mock_session.access_token}"
Expand All @@ -113,9 +113,37 @@ def test_updates_the_authorization_header_on_auth_events() -> None:
assert (
client.postgrest.session.headers.get("Authorization") == updated_authorization
)

assert client.auth._headers.get("apiKey") == key
assert client.auth._headers.get("Authorization") == updated_authorization

assert client.storage.session.headers.get("apiKey") == key
assert client.storage.session.headers.get("Authorization") == updated_authorization


def test_mutable_headers_issue():
url = os.environ.get("SUPABASE_TEST_URL")
key = os.environ.get("SUPABASE_TEST_KEY")

shared_options = ClientOptions(
storage=SyncMemoryStorage(), headers={"Authorization": "Bearer initial-token"}
)

client1 = create_client(url, key, shared_options)
client2 = create_client(url, key, shared_options)

client1.options.headers["Authorization"] = "Bearer modified-token"

assert client2.options.headers["Authorization"] == "Bearer initial-token"
assert client1.options.headers["Authorization"] == "Bearer modified-token"


def test_global_authorization_header_issue():
url = os.environ.get("SUPABASE_TEST_URL")
key = os.environ.get("SUPABASE_TEST_KEY")

authorization = "Bearer secretuserjwt"
options = ClientOptions(headers={"Authorization": authorization})

client = create_client(url, key, options)

assert client.options.headers.get("apiKey") == key