Skip to content

Commit b643ebe

Browse files
committed
* Added did_auth and did_deauth callbacks.
* Fixed bug in the parsing of host addressed envelopes. * Added automatic re-authentication when retrying is enabled.
1 parent 4cbf0dc commit b643ebe

File tree

4 files changed

+256
-17
lines changed

4 files changed

+256
-17
lines changed

swimos/client/_connections.py

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
if TYPE_CHECKING:
2525
from ._downlinks._downlinks import _DownlinkModel
2626
from ._downlinks._downlinks import _DownlinkView
27+
from .. import SwimClient
2728

2829

2930
class RetryStrategy:
@@ -79,7 +80,8 @@ def reset(self):
7980

8081
class _ConnectionPool:
8182

82-
def __init__(self, retry_strategy: RetryStrategy = RetryStrategy()) -> None:
83+
def __init__(self, client: 'SwimClient', retry_strategy: RetryStrategy = RetryStrategy()) -> None:
84+
self.__client = client
8385
self.__connections = dict()
8486
self.retry_strategy = retry_strategy
8587

@@ -102,7 +104,7 @@ async def _get_connection(self, host_uri: str, scheme: str, keep_linked: bool,
102104
connection = self.__connections.get(host_uri)
103105

104106
if connection is None or connection.status == _ConnectionStatus.CLOSED:
105-
connection = _WSConnection(host_uri, scheme, keep_linked, keep_synced, self.retry_strategy)
107+
connection = _WSConnection(self.__client, host_uri, scheme, keep_linked, keep_synced, self.retry_strategy)
106108
self.__connections[host_uri] = connection
107109

108110
return connection
@@ -154,7 +156,7 @@ async def _remove_downlink_view(self, downlink_view: '_DownlinkView') -> None:
154156

155157
class _WSConnection:
156158

157-
def __init__(self, host_uri: str, scheme: str, keep_linked, keep_synced,
159+
def __init__(self, client: 'SwimClient', host_uri: str, scheme: str, keep_linked, keep_synced,
158160
retry_strategy: RetryStrategy = RetryStrategy()) -> None:
159161
self.host_uri = host_uri
160162
self.scheme = scheme
@@ -163,12 +165,15 @@ def __init__(self, host_uri: str, scheme: str, keep_linked, keep_synced,
163165
self.connected = asyncio.Event()
164166
self.websocket = None
165167
self.status = _ConnectionStatus.CLOSED
168+
self.auth_message = None
166169
self.init_message = None
167170

168171
self.keep_linked = keep_linked
169172
self.keep_synced = keep_synced
170173

171174
self.__subscribers = _DownlinkManagerPool()
175+
self.__authenticated = asyncio.Event()
176+
self.__client = client
172177

173178
async def _open(self) -> None:
174179
if self.status == _ConnectionStatus.CLOSED:
@@ -211,6 +216,20 @@ def should_reconnect(self) -> bool:
211216
"""
212217
return self.keep_linked or self.keep_synced
213218

219+
def _set_auth_message(self, message: str) -> None:
220+
"""
221+
Set the initial auth message that gets sent when the underlying downlink is established.
222+
"""
223+
224+
self.auth_message = message
225+
226+
async def _send_auth_message(self) -> None:
227+
"""
228+
Send the initial auth message for the underlying downlink if it is set.
229+
"""
230+
if self.auth_message is not None:
231+
await self._send_message(self.auth_message)
232+
214233
def _set_init_message(self, message: str) -> None:
215234
"""
216235
Set the initial message that gets sent when the underlying downlink is established.
@@ -283,15 +302,53 @@ async def _wait_for_messages(self) -> None:
283302
while self.status == _ConnectionStatus.RUNNING:
284303
message = await self.websocket.recv()
285304
response = _Envelope._parse_recon(message)
286-
await self.__subscribers._receive_message(response)
305+
306+
if response._route:
307+
await self.__subscribers._receive_message(response)
308+
else:
309+
await self._receive_message(self.host_uri, response)
287310
except ConnectionClosed as error:
288311
exception_warn(error)
289312
await self._close()
290313
if self.should_reconnect() and await self.retry_strategy.retry():
291314
await self._open()
315+
await self._send_auth_message()
292316
await self._send_init_message()
293317
continue
294318

319+
async def _receive_message(self, host_uri: str, message: '_Envelope') -> None:
320+
"""
321+
Receive a host addressed message from the remote host.
322+
323+
:param host_uri: - Uri of the remote host.
324+
:param message: - Message received from the remote host.
325+
"""
326+
327+
if message._tag == 'authed':
328+
await self._receive_authed(host_uri, message)
329+
elif message._tag == 'deauthed':
330+
await self._receive_deauthed(host_uri, message)
331+
332+
async def _receive_authed(self, host_uri: str, message: '_Envelope') -> None:
333+
"""
334+
Handle an `authed` response message from the remote agent.
335+
336+
:param host_uri: - Uri of the remote host.
337+
:param message: - Message received from the remote host.
338+
"""
339+
self.__authenticated.set()
340+
await self.__client._execute_did_auth(host_uri, message)
341+
342+
async def _receive_deauthed(self, host_uri: str, message: '_Envelope') -> None:
343+
"""
344+
Handle a `deauthed` response message from the remote agent.
345+
346+
:param host_uri: - Uri of the remote host.
347+
:param message: - Message received from the remote host.
348+
"""
349+
self.__authenticated.clear()
350+
await self.__client._execute_did_deauth(host_uri, message)
351+
295352

296353
class _ConnectionStatus(Enum):
297354
CLOSED = 0

swimos/client/_swim_client.py

Lines changed: 75 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@
2424
from typing import Callable, Any, Optional
2525
from ._connections import _ConnectionPool, _WSConnection, RetryStrategy, IntervalStrategy
2626
from ._downlinks._downlinks import _ValueDownlinkView, _EventDownlinkView, _DownlinkView, _MapDownlinkView
27+
from ._downlinks._utils import validate_callback
2728
from ._utils import _URI, after_started, exception_warn
2829
from swimos.structures import RecordConverter
29-
from swimos.warp._warp import _CommandMessage
30+
from swimos.warp._warp import _CommandMessage, _AuthRequest, _Envelope
3031

3132

3233
class SwimClient:
@@ -41,7 +42,31 @@ def __init__(self, retry_strategy: RetryStrategy = IntervalStrategy(), terminate
4142
self._loop = None
4243
self._loop_thread = None
4344
self._has_started = False
44-
self.__connection_pool = _ConnectionPool(retry_strategy)
45+
self._did_auth_callback = None
46+
self._did_deauth_callback = None
47+
self.authed_hosts = dict()
48+
49+
self.__connection_pool = _ConnectionPool(self, retry_strategy)
50+
51+
def did_auth(self, function: Callable) -> 'SwimClient':
52+
"""
53+
Set the `did_auth` callback of the current client to a given function.
54+
55+
:param function: - Function to be called when a remote host is authenticated.
56+
:return: - The current Swim client.
57+
"""
58+
self._did_auth_callback = validate_callback(function)
59+
return self
60+
61+
def did_deauth(self, function: Callable) -> 'SwimClient':
62+
"""
63+
Set the `did_deauth` callback of the current client to a given function.
64+
65+
:param function: - Function to be called when a remote host is deauthenticated.
66+
:return: - The current Swim client.
67+
"""
68+
self._did_deauth_callback = validate_callback(function)
69+
return self
4570

4671
def __enter__(self) -> 'SwimClient':
4772
self.start()
@@ -95,7 +120,7 @@ def stop(self) -> 'SwimClient':
95120

96121
return self
97122

98-
def command(self, host_uri: str, node_uri: str, lane_uri: str, body: Any) -> 'Future':
123+
def command(self, host_uri: str, node_uri: str, lane_uri: str, body: Any):
99124
"""
100125
Send a command message to a command lane on a remote Swim agent.
101126
@@ -107,6 +132,17 @@ def command(self, host_uri: str, node_uri: str, lane_uri: str, body: Any) -> 'Fu
107132

108133
return self._schedule_task(self.__send_command, host_uri, node_uri, lane_uri, body)
109134

135+
def authenticate(self, host_uri: str, body: Any):
136+
"""
137+
Send an authentication request to a remote Swim server.
138+
139+
:param host_uri: - Host URI of the remote server.
140+
:param body: - The authentication message body.
141+
"""
142+
143+
self.authed_hosts[host_uri] = asyncio.Event()
144+
return self._schedule_task(self.__authenticate, host_uri, body)
145+
110146
def downlink_event(self) -> '_EventDownlinkView':
111147
"""
112148
Create an Event Downlink.
@@ -156,6 +192,26 @@ async def _get_connection(self, host_uri: str, scheme: str, keep_linked: bool,
156192
connection = await self.__connection_pool._get_connection(host_uri, scheme, keep_linked, keep_synced)
157193
return connection
158194

195+
async def _execute_did_auth(self, host_uri: str, message: '_Envelope') -> None:
196+
"""
197+
Execute the custom `did_auth` callback of the current Swim client.
198+
199+
:param host_uri: - Uri of the remote host.
200+
:param message: - Message received from the remote host.
201+
"""
202+
if self._did_auth_callback:
203+
self._schedule_task(self._did_auth_callback, host_uri, message)
204+
205+
async def _execute_did_deauth(self, host_uri: str, message: '_Envelope') -> None:
206+
"""
207+
Execute the custom `did_deauth` callback of the current Swim client.
208+
209+
:param host_uri: - Uri of the remote host.
210+
:param message: - Message received from the remote host.
211+
"""
212+
if self._did_deauth_callback:
213+
self._schedule_task(self._did_deauth_callback, host_uri, message)
214+
159215
@after_started
160216
def _schedule_task(self, task: Callable, *args: Any) -> 'Future':
161217
"""
@@ -222,6 +278,22 @@ async def __send_command(self, host_uri: str, node_uri: str, lane_uri: str, body
222278
connection = await self._get_connection(host_uri, scheme, True, False)
223279
await connection._send_message(message._to_recon())
224280

281+
async def __authenticate(self, host_uri: str, body: Any) -> None:
282+
"""
283+
Send an authentication request to a given host.
284+
285+
:param host_uri: - Host URI of the remote host.
286+
:param body: - The authentication message body.
287+
"""
288+
record = RecordConverter.get_converter().object_to_record(body)
289+
host_uri, scheme = _URI._parse_uri(host_uri)
290+
message = _AuthRequest(body=record)
291+
connection = await self._get_connection(host_uri, scheme, True, False)
292+
connection._set_auth_message(message._to_recon())
293+
await connection._open()
294+
self._schedule_task(connection._wait_for_messages)
295+
await connection._send_message(message._to_recon())
296+
225297
def __start_event_loop(self) -> None:
226298
asyncio.set_event_loop(self._loop)
227299
asyncio.get_event_loop().run_forever()

swimos/recon/_writers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def _write(key: 'Value' = None, writer: '_ReconWriter' = None, value: 'Value' =
135135
if key_text:
136136
output._append(key_text)
137137

138-
if value != _Extant._get_extant() and value is not None:
138+
if value != _Extant._get_extant() and value != _Absent._get_absent() and value is not None:
139139
output._append('(')
140140
value_text = writer._write_value(value)
141141
output._append(value_text)

0 commit comments

Comments
 (0)