@@ -26,6 +26,12 @@ type ConnRequest interface {
26
26
// to decide what to do with the connection.
27
27
StreamId () string
28
28
29
+ // SocketId return the socketid of the connection.
30
+ SocketId () uint32
31
+
32
+ // PeerSocketId returns the socketid of the peer of the connection.
33
+ PeerSocketId () uint32
34
+
29
35
// IsEncrypted returns whether the connection is encrypted. If it is
30
36
// encrypted, use SetPassphrase to set the passphrase for decrypting.
31
37
IsEncrypted () bool
@@ -54,6 +60,7 @@ type connRequest struct {
54
60
addr net.Addr
55
61
start time.Time
56
62
socketId uint32
63
+ peerSocketId uint32
57
64
timestamp uint32
58
65
config Config
59
66
handshake * packet.CIFHandshake
@@ -231,13 +238,13 @@ func newConnRequest(ln *listener, p packet.Packet) *connRequest {
231
238
}
232
239
233
240
req := & connRequest {
234
- ln : ln ,
235
- addr : p .Header ().Addr ,
236
- start : time .Now (),
237
- socketId : cif .SRTSocketId ,
238
- timestamp : p .Header ().Timestamp ,
239
- config : config ,
240
- handshake : cif ,
241
+ ln : ln ,
242
+ addr : p .Header ().Addr ,
243
+ start : time .Now (),
244
+ peerSocketId : cif .SRTSocketId ,
245
+ timestamp : p .Header ().Timestamp ,
246
+ config : config ,
247
+ handshake : cif ,
241
248
}
242
249
243
250
if cif .SRTKM != nil {
@@ -263,11 +270,25 @@ func newConnRequest(ln *listener, p packet.Packet) *connRequest {
263
270
}
264
271
ln .lock .Unlock ()
265
272
266
- // we received a duplicate request: reject silently
273
+ // We received a duplicate request: reject silently
267
274
if exists {
268
275
return nil
269
276
}
270
277
278
+ // Already reserve a socketId for this connection
279
+ ln .lock .Lock ()
280
+ socketId , err := req .generateSocketId ()
281
+ if err == nil {
282
+ ln .conns [socketId ] = nil
283
+ req .socketId = socketId
284
+ }
285
+ ln .lock .Unlock ()
286
+
287
+ // We couldn't create a socketId: reject silently
288
+ if err != nil {
289
+ return nil
290
+ }
291
+
271
292
return req
272
293
} else {
273
294
if cif .HandshakeType .IsRejection () {
@@ -293,6 +314,14 @@ func (req *connRequest) StreamId() string {
293
314
return req .handshake .StreamId
294
315
}
295
316
317
+ func (req * connRequest ) SocketId () uint32 {
318
+ return req .socketId
319
+ }
320
+
321
+ func (req * connRequest ) PeerSocketId () uint32 {
322
+ return req .peerSocketId
323
+ }
324
+
296
325
func (req * connRequest ) IsEncrypted () bool {
297
326
return req .crypto != nil
298
327
}
@@ -321,7 +350,7 @@ func (req *connRequest) Reject(reason RejectionReason) {
321
350
req .ln .lock .Lock ()
322
351
defer req .ln .lock .Unlock ()
323
352
324
- if _ , hasReq := req .ln .connReqs [req .socketId ]; ! hasReq {
353
+ if _ , hasReq := req .ln .connReqs [req .peerSocketId ]; ! hasReq {
325
354
return
326
355
}
327
356
@@ -331,14 +360,15 @@ func (req *connRequest) Reject(reason RejectionReason) {
331
360
p .Header ().SubType = 0
332
361
p .Header ().TypeSpecific = 0
333
362
p .Header ().Timestamp = uint32 (time .Since (req .ln .start ).Microseconds ())
334
- p .Header ().DestinationSocketId = req .socketId
363
+ p .Header ().DestinationSocketId = req .peerSocketId
335
364
req .handshake .HandshakeType = packet .HandshakeType (reason )
336
365
p .MarshalCIF (req .handshake )
337
366
req .ln .log ("handshake:send:dump" , func () string { return p .Dump () })
338
367
req .ln .log ("handshake:send:cif" , func () string { return req .handshake .String () })
339
368
req .ln .send (p )
340
369
341
- delete (req .ln .connReqs , req .socketId )
370
+ delete (req .ln .connReqs , req .peerSocketId )
371
+ delete (req .ln .conns , req .socketId )
342
372
}
343
373
344
374
// generateSocketId generates an SRT SocketID that can be used for this connection
@@ -367,16 +397,10 @@ func (req *connRequest) Accept() (Conn, error) {
367
397
req .ln .lock .Lock ()
368
398
defer req .ln .lock .Unlock ()
369
399
370
- if _ , hasReq := req .ln .connReqs [req .socketId ]; ! hasReq {
400
+ if _ , hasReq := req .ln .connReqs [req .peerSocketId ]; ! hasReq {
371
401
return nil , fmt .Errorf ("connection already accepted" )
372
402
}
373
403
374
- // Create a new socket ID
375
- socketId , err := req .generateSocketId ()
376
- if err != nil {
377
- return nil , fmt .Errorf ("could not generate socket id: %w" , err )
378
- }
379
-
380
404
// Select the largest TSBPD delay advertised by the caller, but at least 120ms
381
405
recvTsbpdDelay := uint16 (req .config .ReceiverLatency .Milliseconds ())
382
406
sendTsbpdDelay := uint16 (req .config .PeerLatency .Milliseconds ())
@@ -402,8 +426,8 @@ func (req *connRequest) Accept() (Conn, error) {
402
426
remoteAddr : req .addr ,
403
427
config : req .config ,
404
428
start : req .start ,
405
- socketId : socketId ,
406
- peerSocketId : req .handshake . SRTSocketId ,
429
+ socketId : req . socketId ,
430
+ peerSocketId : req .peerSocketId ,
407
431
tsbpdTimeBase : uint64 (req .timestamp ),
408
432
tsbpdDelay : uint64 (recvTsbpdDelay ) * 1000 ,
409
433
peerTsbpdDelay : uint64 (sendTsbpdDelay ) * 1000 ,
@@ -417,7 +441,7 @@ func (req *connRequest) Accept() (Conn, error) {
417
441
418
442
req .ln .log ("connection:new" , func () string { return fmt .Sprintf ("%#08x (%s)" , conn .SocketId (), conn .StreamId ()) })
419
443
420
- req .handshake .SRTSocketId = socketId
444
+ req .handshake .SRTSocketId = req . socketId
421
445
req .handshake .SynCookie = 0
422
446
423
447
if req .handshake .Version == 5 {
@@ -441,14 +465,14 @@ func (req *connRequest) Accept() (Conn, error) {
441
465
p .Header ().SubType = 0
442
466
p .Header ().TypeSpecific = 0
443
467
p .Header ().Timestamp = uint32 (time .Since (req .start ).Microseconds ())
444
- p .Header ().DestinationSocketId = req .socketId
468
+ p .Header ().DestinationSocketId = req .peerSocketId
445
469
p .MarshalCIF (req .handshake )
446
470
req .ln .log ("handshake:send:dump" , func () string { return p .Dump () })
447
471
req .ln .log ("handshake:send:cif" , func () string { return req .handshake .String () })
448
472
req .ln .send (p )
449
473
450
- req .ln .conns [socketId ] = conn
451
- delete (req .ln .connReqs , req .socketId )
474
+ req .ln .conns [req . socketId ] = conn
475
+ delete (req .ln .connReqs , req .peerSocketId )
452
476
453
477
return conn , nil
454
478
}
0 commit comments