22import math
33import pickle
44from random import Random
5+ from typing import Dict , List
56
67from fedscale .core .internal .client import Client
78
@@ -41,7 +42,19 @@ def __init__(self, mode, args, sample_seed=233):
4142 self .user_trace_keys = list (self .user_trace .keys ())
4243
4344 def registerClient (self , hostId , clientId , size , speed , duration = 1 ):
45+ self .register_client (hostId , clientId , size , speed , duration )
4446
47+ def register_client (self , hostId : int , clientId : int , size : int , speed : Dict [str , float ], duration : float = 1 ) -> None :
48+ """Register client information to the client manager.
49+
50+ Args:
51+ hostId (int): executor Id.
52+ clientId (int): client Id.
53+ size (int): number of samples on this client.
54+ speed (Dict[str, float]): device speed (e.g., compuutation and communication).
55+ duration (float): execution latency.
56+
57+ """
4558 uniqueId = self .getUniqueId (hostId , clientId )
4659 user_trace = None if self .user_trace is None else self .user_trace [self .user_trace_keys [int (
4760 clientId ) % len (self .user_trace )]]
@@ -60,7 +73,7 @@ def registerClient(self, hostId, clientId, size, speed, duration=1):
6073 self .ucbSampler .register_client (clientId , feedbacks = feedbacks )
6174 else :
6275 del self .Clients [uniqueId ]
63-
76+
6477 def getAllClients (self ):
6578 return self .feasibleClients
6679
@@ -80,7 +93,6 @@ def registerDuration(self, clientId, batch_size, upload_step, upload_size, downl
8093 clientId , exe_cost ['computation' ]+ exe_cost ['communication' ])
8194
8295 def getCompletionTime (self , clientId , batch_size , upload_step , upload_size , download_size ):
83-
8496 return self .Clients [self .getUniqueId (0 , clientId )].getCompletionTime (
8597 batch_size = batch_size , upload_step = upload_step ,
8698 upload_size = upload_size , download_size = download_size
@@ -91,6 +103,20 @@ def registerSpeed(self, hostId, clientId, speed):
91103 self .Clients [uniqueId ].speed = speed
92104
93105 def registerScore (self , clientId , reward , auxi = 1.0 , time_stamp = 0 , duration = 1. , success = True ):
106+ self .register_feedback (clientId , reward , auxi = auxi , time_stamp = time_stamp , duration = duration , success = success )
107+
108+ def register_feedback (self , clientId : int , reward : float , auxi : float = 1.0 , time_stamp : float = 0 , duration : float = 1. , success : bool = True ) -> None :
109+ """Collect client execution feedbacks of last round.
110+
111+ Args:
112+ clientId (int): client Id.
113+ reward (float): execution utilities (processed feedbacks).
114+ auxi (float): unprocessed feedbacks.
115+ time_stamp (float): current wall clock time.
116+ duration (float): system execution duration.
117+ success (bool): whether this client runs successfully.
118+
119+ """
94120 # currently, we only use distance as reward
95121 if self .mode == "oort" :
96122 feedbacks = {
@@ -180,27 +206,40 @@ def getFeasibleClients(self, cur_time):
180206 def isClientActive (self , clientId , cur_time ):
181207 return self .Clients [self .getUniqueId (0 , clientId )].isActive (cur_time )
182208
183- def resampleClients (self , numOfClients , cur_time = 0 ):
209+ def select_participants (self , num_of_clients : int , cur_time : float = 0 ) -> List [int ]:
210+ """Select participating clients for current execution task.
211+
212+ Args:
213+ num_of_clients (int): number of participants to select.
214+ cur_time (float): current wall clock time.
215+
216+ Returns:
217+ List[int]: indices of selected clients.
218+
219+ """
184220 self .count += 1
185221
186222 clients_online = self .getFeasibleClients (cur_time )
187223
188- if len (clients_online ) <= numOfClients :
224+ if len (clients_online ) <= num_of_clients :
189225 return clients_online
190226
191227 pickled_clients = None
192228 clients_online_set = set (clients_online )
193229
194230 if self .mode == "oort" and self .count > 1 :
195231 pickled_clients = self .ucbSampler .select_participant (
196- numOfClients , feasible_clients = clients_online_set )
232+ num_of_clients , feasible_clients = clients_online_set )
197233 else :
198234 self .rng .shuffle (clients_online )
199- client_len = min (numOfClients , len (clients_online ) - 1 )
235+ client_len = min (num_of_clients , len (clients_online ) - 1 )
200236 pickled_clients = clients_online [:client_len ]
201237
202238 return pickled_clients
203239
240+ def resampleClients (self , numOfClients , cur_time = 0 ):
241+ return self .select_participants (numOfClients , cur_time )
242+
204243 def getAllMetrics (self ):
205244 if self .mode == "oort" :
206245 return self .ucbSampler .getAllMetrics ()
0 commit comments