Skip to content

Commit 56a1f03

Browse files
committed
Add SocketId() and PeerSocketId() to ConnRequest (#104)
1 parent f4b464a commit 56a1f03

File tree

2 files changed

+52
-25
lines changed

2 files changed

+52
-25
lines changed

conn_request.go

Lines changed: 48 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@ type ConnRequest interface {
2626
// to decide what to do with the connection.
2727
StreamId() string
2828

29+
// SocketId return the socketid of the connection.
30+
SocketId() uint32
31+
32+
// PeerSocketId returns the socketid of the peer of the connection.
33+
PeerSocketId() uint32
34+
2935
// IsEncrypted returns whether the connection is encrypted. If it is
3036
// encrypted, use SetPassphrase to set the passphrase for decrypting.
3137
IsEncrypted() bool
@@ -54,6 +60,7 @@ type connRequest struct {
5460
addr net.Addr
5561
start time.Time
5662
socketId uint32
63+
peerSocketId uint32
5764
timestamp uint32
5865
config Config
5966
handshake *packet.CIFHandshake
@@ -231,13 +238,13 @@ func newConnRequest(ln *listener, p packet.Packet) *connRequest {
231238
}
232239

233240
req := &connRequest{
234-
ln: ln,
235-
addr: p.Header().Addr,
236-
start: time.Now(),
237-
socketId: cif.SRTSocketId,
238-
timestamp: p.Header().Timestamp,
239-
config: config,
240-
handshake: cif,
241+
ln: ln,
242+
addr: p.Header().Addr,
243+
start: time.Now(),
244+
peerSocketId: cif.SRTSocketId,
245+
timestamp: p.Header().Timestamp,
246+
config: config,
247+
handshake: cif,
241248
}
242249

243250
if cif.SRTKM != nil {
@@ -263,11 +270,25 @@ func newConnRequest(ln *listener, p packet.Packet) *connRequest {
263270
}
264271
ln.lock.Unlock()
265272

266-
// we received a duplicate request: reject silently
273+
// We received a duplicate request: reject silently
267274
if exists {
268275
return nil
269276
}
270277

278+
// Already reserve a socketId for this connection
279+
ln.lock.Lock()
280+
socketId, err := req.generateSocketId()
281+
if err == nil {
282+
ln.conns[socketId] = nil
283+
req.socketId = socketId
284+
}
285+
ln.lock.Unlock()
286+
287+
// We couldn't create a socketId: reject silently
288+
if err != nil {
289+
return nil
290+
}
291+
271292
return req
272293
} else {
273294
if cif.HandshakeType.IsRejection() {
@@ -293,6 +314,14 @@ func (req *connRequest) StreamId() string {
293314
return req.handshake.StreamId
294315
}
295316

317+
func (req *connRequest) SocketId() uint32 {
318+
return req.socketId
319+
}
320+
321+
func (req *connRequest) PeerSocketId() uint32 {
322+
return req.peerSocketId
323+
}
324+
296325
func (req *connRequest) IsEncrypted() bool {
297326
return req.crypto != nil
298327
}
@@ -321,7 +350,7 @@ func (req *connRequest) Reject(reason RejectionReason) {
321350
req.ln.lock.Lock()
322351
defer req.ln.lock.Unlock()
323352

324-
if _, hasReq := req.ln.connReqs[req.socketId]; !hasReq {
353+
if _, hasReq := req.ln.connReqs[req.peerSocketId]; !hasReq {
325354
return
326355
}
327356

@@ -331,14 +360,15 @@ func (req *connRequest) Reject(reason RejectionReason) {
331360
p.Header().SubType = 0
332361
p.Header().TypeSpecific = 0
333362
p.Header().Timestamp = uint32(time.Since(req.ln.start).Microseconds())
334-
p.Header().DestinationSocketId = req.socketId
363+
p.Header().DestinationSocketId = req.peerSocketId
335364
req.handshake.HandshakeType = packet.HandshakeType(reason)
336365
p.MarshalCIF(req.handshake)
337366
req.ln.log("handshake:send:dump", func() string { return p.Dump() })
338367
req.ln.log("handshake:send:cif", func() string { return req.handshake.String() })
339368
req.ln.send(p)
340369

341-
delete(req.ln.connReqs, req.socketId)
370+
delete(req.ln.connReqs, req.peerSocketId)
371+
delete(req.ln.conns, req.socketId)
342372
}
343373

344374
// generateSocketId generates an SRT SocketID that can be used for this connection
@@ -367,16 +397,10 @@ func (req *connRequest) Accept() (Conn, error) {
367397
req.ln.lock.Lock()
368398
defer req.ln.lock.Unlock()
369399

370-
if _, hasReq := req.ln.connReqs[req.socketId]; !hasReq {
400+
if _, hasReq := req.ln.connReqs[req.peerSocketId]; !hasReq {
371401
return nil, fmt.Errorf("connection already accepted")
372402
}
373403

374-
// Create a new socket ID
375-
socketId, err := req.generateSocketId()
376-
if err != nil {
377-
return nil, fmt.Errorf("could not generate socket id: %w", err)
378-
}
379-
380404
// Select the largest TSBPD delay advertised by the caller, but at least 120ms
381405
recvTsbpdDelay := uint16(req.config.ReceiverLatency.Milliseconds())
382406
sendTsbpdDelay := uint16(req.config.PeerLatency.Milliseconds())
@@ -402,8 +426,8 @@ func (req *connRequest) Accept() (Conn, error) {
402426
remoteAddr: req.addr,
403427
config: req.config,
404428
start: req.start,
405-
socketId: socketId,
406-
peerSocketId: req.handshake.SRTSocketId,
429+
socketId: req.socketId,
430+
peerSocketId: req.peerSocketId,
407431
tsbpdTimeBase: uint64(req.timestamp),
408432
tsbpdDelay: uint64(recvTsbpdDelay) * 1000,
409433
peerTsbpdDelay: uint64(sendTsbpdDelay) * 1000,
@@ -417,7 +441,7 @@ func (req *connRequest) Accept() (Conn, error) {
417441

418442
req.ln.log("connection:new", func() string { return fmt.Sprintf("%#08x (%s)", conn.SocketId(), conn.StreamId()) })
419443

420-
req.handshake.SRTSocketId = socketId
444+
req.handshake.SRTSocketId = req.socketId
421445
req.handshake.SynCookie = 0
422446

423447
if req.handshake.Version == 5 {
@@ -441,14 +465,14 @@ func (req *connRequest) Accept() (Conn, error) {
441465
p.Header().SubType = 0
442466
p.Header().TypeSpecific = 0
443467
p.Header().Timestamp = uint32(time.Since(req.start).Microseconds())
444-
p.Header().DestinationSocketId = req.socketId
468+
p.Header().DestinationSocketId = req.peerSocketId
445469
p.MarshalCIF(req.handshake)
446470
req.ln.log("handshake:send:dump", func() string { return p.Dump() })
447471
req.ln.log("handshake:send:cif", func() string { return req.handshake.String() })
448472
req.ln.send(p)
449473

450-
req.ln.conns[socketId] = conn
451-
delete(req.ln.connReqs, req.socketId)
474+
req.ln.conns[req.socketId] = conn
475+
delete(req.ln.connReqs, req.peerSocketId)
452476

453477
return conn, nil
454478
}

listen.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,9 @@ func (ln *listener) Close() {
347347

348348
ln.lock.RLock()
349349
for _, conn := range ln.conns {
350+
if conn == nil {
351+
continue
352+
}
350353
conn.close()
351354
}
352355
ln.lock.RUnlock()
@@ -402,7 +405,7 @@ func (ln *listener) reader(ctx context.Context) {
402405
conn, ok := ln.conns[p.Header().DestinationSocketId]
403406
ln.lock.RUnlock()
404407

405-
if !ok {
408+
if !ok || conn == nil {
406409
// ignore the packet, we don't know the destination
407410
break
408411
}

0 commit comments

Comments
 (0)