Skip to content

Commit f6a64e7

Browse files
authored
Merge pull request #389 from kzys/dial-context
Define vsock.Dial and vsock.DialContext as like net and grpc
2 parents 12bd5d8 + 43ce2e9 commit f6a64e7

File tree

2 files changed

+78
-25
lines changed

2 files changed

+78
-25
lines changed

vsock/dial.go

Lines changed: 73 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import (
1717
"bufio"
1818
"context"
1919
"fmt"
20+
"io/ioutil"
2021
"net"
2122
"strings"
2223
"time"
@@ -25,21 +26,64 @@ import (
2526
"github.com/sirupsen/logrus"
2627
)
2728

28-
type Timeout struct {
29+
type config struct {
30+
logger logrus.FieldLogger
2931
DialTimeout time.Duration
3032
RetryTimeout time.Duration
3133
RetryInterval time.Duration
3234
ConnectMsgTimeout time.Duration
3335
AckMsgTimeout time.Duration
3436
}
3537

36-
func DefaultTimeouts() Timeout {
37-
return Timeout{
38+
func defaultConfig() config {
39+
noop := logrus.New()
40+
noop.Out = ioutil.Discard
41+
42+
return config{
3843
DialTimeout: 100 * time.Millisecond,
3944
RetryTimeout: 20 * time.Second,
4045
RetryInterval: 100 * time.Millisecond,
4146
ConnectMsgTimeout: 100 * time.Millisecond,
4247
AckMsgTimeout: 1 * time.Second,
48+
logger: noop,
49+
}
50+
}
51+
52+
type DialOption func(c *config)
53+
54+
func WithDialTimeout(d time.Duration) DialOption {
55+
return func(c *config) {
56+
c.DialTimeout = d
57+
}
58+
}
59+
60+
func WithRetryTimeout(d time.Duration) DialOption {
61+
return func(c *config) {
62+
c.RetryTimeout = d
63+
}
64+
}
65+
66+
func WithRetryInterval(d time.Duration) DialOption {
67+
return func(c *config) {
68+
c.RetryInterval = d
69+
}
70+
}
71+
72+
func WithConnectionMsgTimeout(d time.Duration) DialOption {
73+
return func(c *config) {
74+
c.ConnectMsgTimeout = d
75+
}
76+
}
77+
78+
func WithAckMsgTimeout(d time.Duration) DialOption {
79+
return func(c *config) {
80+
c.AckMsgTimeout = d
81+
}
82+
}
83+
84+
func WithLogger(logger logrus.FieldLogger) DialOption {
85+
return func(c *config) {
86+
c.logger = logger
4387
}
4488
}
4589

@@ -48,20 +92,31 @@ func DefaultTimeouts() Timeout {
4892
// It will retry connect attempts if a temporary error is encountered up to a fixed
4993
// timeout or the provided request is canceled.
5094
//
51-
// udsPath specifies the file system path of the UNIX domain socket.
95+
// path specifies the file system path of the UNIX domain socket.
5296
//
5397
// port will be used in the connect message to the firecracker vsock.
54-
func Dial(ctx context.Context, logger *logrus.Entry, udsPath string, port uint32) (net.Conn, error) {
55-
return DialTimeout(ctx, logger, udsPath, port, DefaultTimeouts())
98+
func DialContext(ctx context.Context, path string, port uint32, opts ...DialOption) (net.Conn, error) {
99+
t := defaultConfig()
100+
for _, o := range opts {
101+
o(&t)
102+
}
103+
104+
return dial(ctx, path, port, t)
56105
}
57106

58-
// DialTimeout acts like Dial but takes a timeout.
107+
// Dial connects to the Firecracker host-side vsock at the provided unix path and port.
59108
//
60-
// See func Dial for a description of the udsPath and port parameters.
61-
func DialTimeout(ctx context.Context, logger *logrus.Entry, udsPath string, port uint32, timeout Timeout) (net.Conn, error) {
62-
ticker := time.NewTicker(timeout.RetryInterval)
109+
// See func Dial for a description of the path and port parameters.
110+
func Dial(path string, port uint32, opts ...DialOption) (net.Conn, error) {
111+
return DialContext(context.Background(), path, port, opts...)
112+
}
113+
114+
func dial(ctx context.Context, udsPath string, port uint32, c config) (net.Conn, error) {
115+
ticker := time.NewTicker(c.RetryInterval)
63116
defer ticker.Stop()
64117

118+
logger := c.logger
119+
65120
tickerCh := ticker.C
66121
var attemptCount int
67122
for {
@@ -72,7 +127,7 @@ func DialTimeout(ctx context.Context, logger *logrus.Entry, udsPath string, port
72127
case <-ctx.Done():
73128
return nil, ctx.Err()
74129
case <-tickerCh:
75-
conn, err := tryConnect(logger, udsPath, port, timeout)
130+
conn, err := tryConnect(logger, udsPath, port, c)
76131
if isTemporaryNetErr(err) {
77132
err = errors.Wrap(err, "temporary vsock dial failure")
78133
logger.WithError(err).Debug()
@@ -98,10 +153,10 @@ func connectMsg(port uint32) string {
98153

99154
// tryConnect attempts to dial a guest vsock listener at the provided host-side
100155
// unix socket and provided guest-listener port.
101-
func tryConnect(logger *logrus.Entry, udsPath string, port uint32, timeout Timeout) (net.Conn, error) {
102-
conn, err := net.DialTimeout("unix", udsPath, timeout.DialTimeout)
156+
func tryConnect(logger *logrus.Entry, udsPath string, port uint32, c config) (net.Conn, error) {
157+
conn, err := net.DialTimeout("unix", udsPath, c.DialTimeout)
103158
if err != nil {
104-
return nil, errors.Wrapf(err, "failed to dial %q within %s", udsPath, timeout.DialTimeout)
159+
return nil, errors.Wrapf(err, "failed to dial %q within %s", udsPath, c.DialTimeout)
105160
}
106161

107162
defer func() {
@@ -115,17 +170,17 @@ func tryConnect(logger *logrus.Entry, udsPath string, port uint32, timeout Timeo
115170
}()
116171

117172
msg := connectMsg(port)
118-
err = tryConnWrite(conn, msg, timeout.ConnectMsgTimeout)
173+
err = tryConnWrite(conn, msg, c.ConnectMsgTimeout)
119174
if err != nil {
120175
return nil, connectMsgError{
121-
cause: errors.Wrapf(err, `failed to write %q within %s`, msg, timeout.ConnectMsgTimeout),
176+
cause: errors.Wrapf(err, `failed to write %q within %s`, msg, c.ConnectMsgTimeout),
122177
}
123178
}
124179

125-
line, err := tryConnReadUntil(conn, '\n', timeout.AckMsgTimeout)
180+
line, err := tryConnReadUntil(conn, '\n', c.AckMsgTimeout)
126181
if err != nil {
127182
return nil, ackError{
128-
cause: errors.Wrapf(err, `failed to read "OK <port>" within %s`, timeout.AckMsgTimeout),
183+
cause: errors.Wrapf(err, `failed to read "OK <port>" within %s`, c.AckMsgTimeout),
129184
}
130185
}
131186

vsock/listener.go

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,8 @@ import (
2626
type listener struct {
2727
listener net.Listener
2828
port uint32
29-
timeout Timeout
3029
ctx context.Context
31-
logger *logrus.Entry
30+
config config
3231
}
3332

3433
// Listener returns a net.Listener implementation for guest-side Firecracker
@@ -42,23 +41,22 @@ func Listener(ctx context.Context, logger *logrus.Entry, port uint32) (net.Liste
4241
return listener{
4342
listener: l,
4443
port: port,
45-
timeout: DefaultTimeouts(),
44+
config: defaultConfig(),
4645
ctx: ctx,
47-
logger: logger,
4846
}, nil
4947
}
5048

5149
func (l listener) Accept() (net.Conn, error) {
52-
ctx, cancel := context.WithTimeout(l.ctx, l.timeout.RetryTimeout)
50+
ctx, cancel := context.WithTimeout(l.ctx, l.config.RetryTimeout)
5351
defer cancel()
5452

5553
var attemptCount int
56-
ticker := time.NewTicker(l.timeout.RetryInterval)
54+
ticker := time.NewTicker(l.config.RetryInterval)
5755
defer ticker.Stop()
5856
tickerCh := ticker.C
5957
for {
6058
attemptCount++
61-
logger := l.logger.WithField("attempt", attemptCount)
59+
logger := l.config.logger.WithField("attempt", attemptCount)
6260

6361
select {
6462
case <-ctx.Done():

0 commit comments

Comments
 (0)