24
24
from typing import Callable , Any , Optional
25
25
from ._connections import _ConnectionPool , _WSConnection , RetryStrategy , IntervalStrategy
26
26
from ._downlinks ._downlinks import _ValueDownlinkView , _EventDownlinkView , _DownlinkView , _MapDownlinkView
27
+ from ._downlinks ._utils import validate_callback
27
28
from ._utils import _URI , after_started , exception_warn
28
29
from swimos .structures import RecordConverter
29
- from swimos .warp ._warp import _CommandMessage
30
+ from swimos .warp ._warp import _CommandMessage , _AuthRequest , _Envelope
30
31
31
32
32
33
class SwimClient :
@@ -41,7 +42,31 @@ def __init__(self, retry_strategy: RetryStrategy = IntervalStrategy(), terminate
41
42
self ._loop = None
42
43
self ._loop_thread = None
43
44
self ._has_started = False
44
- self .__connection_pool = _ConnectionPool (retry_strategy )
45
+ self ._did_auth_callback = None
46
+ self ._did_deauth_callback = None
47
+ self .authed_hosts = dict ()
48
+
49
+ self .__connection_pool = _ConnectionPool (self , retry_strategy )
50
+
51
+ def did_auth (self , function : Callable ) -> 'SwimClient' :
52
+ """
53
+ Set the `did_auth` callback of the current client to a given function.
54
+
55
+ :param function: - Function to be called when a remote host is authenticated.
56
+ :return: - The current Swim client.
57
+ """
58
+ self ._did_auth_callback = validate_callback (function )
59
+ return self
60
+
61
+ def did_deauth (self , function : Callable ) -> 'SwimClient' :
62
+ """
63
+ Set the `did_deauth` callback of the current client to a given function.
64
+
65
+ :param function: - Function to be called when a remote host is deauthenticated.
66
+ :return: - The current Swim client.
67
+ """
68
+ self ._did_deauth_callback = validate_callback (function )
69
+ return self
45
70
46
71
def __enter__ (self ) -> 'SwimClient' :
47
72
self .start ()
@@ -95,7 +120,7 @@ def stop(self) -> 'SwimClient':
95
120
96
121
return self
97
122
98
- def command (self , host_uri : str , node_uri : str , lane_uri : str , body : Any ) -> 'Future' :
123
+ def command (self , host_uri : str , node_uri : str , lane_uri : str , body : Any ):
99
124
"""
100
125
Send a command message to a command lane on a remote Swim agent.
101
126
@@ -107,6 +132,17 @@ def command(self, host_uri: str, node_uri: str, lane_uri: str, body: Any) -> 'Fu
107
132
108
133
return self ._schedule_task (self .__send_command , host_uri , node_uri , lane_uri , body )
109
134
135
+ def authenticate (self , host_uri : str , body : Any ):
136
+ """
137
+ Send an authentication request to a remote Swim server.
138
+
139
+ :param host_uri: - Host URI of the remote server.
140
+ :param body: - The authentication message body.
141
+ """
142
+
143
+ self .authed_hosts [host_uri ] = asyncio .Event ()
144
+ return self ._schedule_task (self .__authenticate , host_uri , body )
145
+
110
146
def downlink_event (self ) -> '_EventDownlinkView' :
111
147
"""
112
148
Create an Event Downlink.
@@ -156,6 +192,26 @@ async def _get_connection(self, host_uri: str, scheme: str, keep_linked: bool,
156
192
connection = await self .__connection_pool ._get_connection (host_uri , scheme , keep_linked , keep_synced )
157
193
return connection
158
194
195
+ async def _execute_did_auth (self , host_uri : str , message : '_Envelope' ) -> None :
196
+ """
197
+ Execute the custom `did_auth` callback of the current Swim client.
198
+
199
+ :param host_uri: - Uri of the remote host.
200
+ :param message: - Message received from the remote host.
201
+ """
202
+ if self ._did_auth_callback :
203
+ self ._schedule_task (self ._did_auth_callback , host_uri , message )
204
+
205
+ async def _execute_did_deauth (self , host_uri : str , message : '_Envelope' ) -> None :
206
+ """
207
+ Execute the custom `did_deauth` callback of the current Swim client.
208
+
209
+ :param host_uri: - Uri of the remote host.
210
+ :param message: - Message received from the remote host.
211
+ """
212
+ if self ._did_deauth_callback :
213
+ self ._schedule_task (self ._did_deauth_callback , host_uri , message )
214
+
159
215
@after_started
160
216
def _schedule_task (self , task : Callable , * args : Any ) -> 'Future' :
161
217
"""
@@ -222,6 +278,22 @@ async def __send_command(self, host_uri: str, node_uri: str, lane_uri: str, body
222
278
connection = await self ._get_connection (host_uri , scheme , True , False )
223
279
await connection ._send_message (message ._to_recon ())
224
280
281
+ async def __authenticate (self , host_uri : str , body : Any ) -> None :
282
+ """
283
+ Send an authentication request to a given host.
284
+
285
+ :param host_uri: - Host URI of the remote host.
286
+ :param body: - The authentication message body.
287
+ """
288
+ record = RecordConverter .get_converter ().object_to_record (body )
289
+ host_uri , scheme = _URI ._parse_uri (host_uri )
290
+ message = _AuthRequest (body = record )
291
+ connection = await self ._get_connection (host_uri , scheme , True , False )
292
+ connection ._set_auth_message (message ._to_recon ())
293
+ await connection ._open ()
294
+ self ._schedule_task (connection ._wait_for_messages )
295
+ await connection ._send_message (message ._to_recon ())
296
+
225
297
def __start_event_loop (self ) -> None :
226
298
asyncio .set_event_loop (self ._loop )
227
299
asyncio .get_event_loop ().run_forever ()
0 commit comments