Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

Commit 94eaa89

Browse files
authored
Merge pull request #517 from jfontan/fix/client-address-in-session
sql: process list now shows client address
2 parents c5a2c61 + 2eaa2da commit 94eaa89

File tree

12 files changed

+45
-24
lines changed

12 files changed

+45
-24
lines changed

auth/common_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ func testAuthorization(
153153
t.Run(fmt.Sprintf("%s-%s", c.user, c.query), func(t *testing.T) {
154154
req := require.New(t)
155155

156-
session := sql.NewSession("localhost", c.user, uint32(i))
156+
session := sql.NewSession("localhost", "client", c.user, uint32(i))
157157
ctx := sql.NewContext(context.TODO(),
158158
sql.WithSession(session),
159159
sql.WithPid(uint64(i)))

engine_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1688,7 +1688,7 @@ func insertRows(t *testing.T, table sql.Inserter, rows ...sql.Row) {
16881688
var pid uint64
16891689

16901690
func newCtx() *sql.Context {
1691-
session := sql.NewSession("address", "user", 1)
1691+
session := sql.NewSession("address", "client", "user", 1)
16921692
return sql.NewContext(
16931693
context.Background(),
16941694
sql.WithPid(atomic.AddUint64(&pid, 1)),

server/context.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ type DoneFunc func()
1818

1919
// DefaultSessionBuilder is a SessionBuilder that returns a base session.
2020
func DefaultSessionBuilder(c *mysql.Conn, addr string) sql.Session {
21-
return sql.NewSession(addr, c.User, c.ConnectionID)
21+
client := c.RemoteAddr().String()
22+
return sql.NewSession(addr, client, c.User, c.ConnectionID)
2223
}
2324

2425
// SessionManager is in charge of creating new sessions for the given

server/handler_test.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,16 @@ func setupMemDB(require *require.Assertions) *sqle.Engine {
3636
}
3737

3838
func TestHandlerOutput(t *testing.T) {
39+
// This session builder is used as dummy mysql Conn is not complete and
40+
// causes panic when accessing remote address.
41+
testSessionBuilder := func(c *mysql.Conn, addr string) sql.Session {
42+
client := "127.0.0.1:34567"
43+
return sql.NewSession(addr, client, c.User, c.ConnectionID)
44+
}
45+
3946
e := setupMemDB(require.New(t))
4047
dummyConn := &mysql.Conn{ConnectionID: 1}
41-
handler := NewHandler(e, NewSessionManager(DefaultSessionBuilder, opentracing.NoopTracer{}, "foo"))
48+
handler := NewHandler(e, NewSessionManager(testSessionBuilder, opentracing.NoopTracer{}, "foo"))
4249
handler.NewConnection(dummyConn)
4350

4451
type exptectedValues struct {

sql/analyzer/check_auth.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ func CheckAuthorization(au auth.Auth) RuleFunc {
2121
perm = auth.ReadPerm
2222
}
2323

24-
err := au.Allowed(ctx.User(), perm)
24+
err := au.Allowed(ctx.Client().User, perm)
2525
if err != nil {
2626
return nil, err
2727
}

sql/expression/function/connection_id_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import (
1111
func TestConnectionID(t *testing.T) {
1212
require := require.New(t)
1313

14-
session := sql.NewSession("", "", 2)
14+
session := sql.NewSession("", "", "", 2)
1515
ctx := sql.NewContext(context.Background(), sql.WithSession(session))
1616

1717
f := NewConnectionID()

sql/plan/processlist.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ func (p *ShowProcessList) RowIter(ctx *sql.Context) (sql.RowIter, error) {
9494
time: int64(proc.Seconds()),
9595
state: strings.Join(status, ", "),
9696
command: proc.Type.String(),
97-
host: ctx.Session.Address(),
97+
host: ctx.Session.Client().Address,
9898
info: proc.Query,
9999
db: p.Database,
100100
}.toRow()

sql/plan/processlist_test.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@ import (
1111
func TestShowProcessList(t *testing.T) {
1212
require := require.New(t)
1313

14+
addr := "127.0.0.1:34567"
15+
1416
n := NewShowProcessList()
1517
p := sql.NewProcessList()
16-
sess := sql.NewSession("0.0.0.0:1234", "foo", 1)
18+
sess := sql.NewSession("0.0.0.0:3306", addr, "foo", 1)
1719
ctx := sql.NewContext(context.Background(), sql.WithPid(1), sql.WithSession(sess))
1820

1921
ctx, err := p.AddProcess(ctx, sql.QueryProcess, "SELECT foo")
@@ -42,8 +44,8 @@ func TestShowProcessList(t *testing.T) {
4244
require.NoError(err)
4345

4446
expected := []sql.Row{
45-
{int64(1), "foo", "0.0.0.0:1234", "foo", "query", int64(0), "a(4/5), b(2/6)", "SELECT foo"},
46-
{int64(2), "foo", "0.0.0.0:1234", "foo", "create_index", int64(0), "foo(1/2)", "SELECT bar"},
47+
{int64(1), "foo", addr, "foo", "query", int64(0), "a(4/5), b(2/6)", "SELECT foo"},
48+
{int64(2), "foo", addr, "foo", "create_index", int64(0), "foo(1/2)", "SELECT bar"},
4749
}
4850

4951
require.ElementsMatch(expected, rows)

sql/processlist.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ func (pl *ProcessList) AddProcess(
108108
Type: typ,
109109
Query: query,
110110
Progress: make(map[string]Progress),
111-
User: ctx.Session.User(),
111+
User: ctx.Session.Client().User,
112112
StartedAt: time.Now(),
113113
Kill: cancel,
114114
}
@@ -136,7 +136,7 @@ func (pl *ProcessList) UpdateProgress(pid uint64, name string, delta int64) {
136136
p.Progress[name] = progress
137137
}
138138

139-
// AddProgressItem adds a new item to track progress from to the proces with
139+
// AddProgressItem adds a new item to track progress from to the process with
140140
// the given pid. If the pid does not exist, it will do nothing.
141141
func (pl *ProcessList) AddProgressItem(pid uint64, name string, total int64) {
142142
pl.mu.Lock()

sql/processlist_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ func TestProcessList(t *testing.T) {
1212
require := require.New(t)
1313

1414
p := NewProcessList()
15-
sess := NewSession("0.0.0.0:1234", "foo", 1)
15+
sess := NewSession("0.0.0.0:3306", "127.0.0.1:34567", "foo", 1)
1616
ctx := NewContext(context.Background(), WithPid(1), WithSession(sess))
1717
ctx, err := p.AddProcess(ctx, QueryProcess, "SELECT foo")
1818
require.NoError(err)

sql/session.go

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,20 @@ const (
1919
QueryKey key = iota
2020
)
2121

22+
// Client holds session user information.
23+
type Client struct {
24+
// User of the session.
25+
User string
26+
// Address of the client.
27+
Address string
28+
}
29+
2230
// Session holds the session data.
2331
type Session interface {
2432
// Address of the server.
2533
Address() string
2634
// User of the session.
27-
User() string
35+
Client() Client
2836
// Set session configuration.
2937
Set(key string, typ Type, value interface{})
3038
// Get session configuration.
@@ -47,18 +55,18 @@ type Session interface {
4755
type BaseSession struct {
4856
id uint32
4957
addr string
50-
user string
58+
client Client
5159
mu sync.RWMutex
5260
config map[string]TypedValue
5361
warnings []*Warning
5462
}
5563

56-
// User returns the current user of the session.
57-
func (s *BaseSession) User() string { return s.user }
58-
5964
// Address returns the server address.
6065
func (s *BaseSession) Address() string { return s.addr }
6166

67+
// User returns session's client information.
68+
func (s *BaseSession) Client() Client { return s.client }
69+
6270
// Set implements the Session interface.
6371
func (s *BaseSession) Set(key string, typ Type, value interface{}) {
6472
s.mu.Lock()
@@ -171,11 +179,14 @@ func HasDefaultValue(s Session, key string) (bool, interface{}) {
171179
}
172180

173181
// NewSession creates a new session with data.
174-
func NewSession(address string, user string, id uint32) Session {
182+
func NewSession(server, client, user string, id uint32) Session {
175183
return &BaseSession{
176-
id: id,
177-
addr: address,
178-
user: user,
184+
id: id,
185+
addr: server,
186+
client: Client{
187+
Address: client,
188+
User: user,
189+
},
179190
config: DefaultSessionConfig(),
180191
}
181192
}

sql/session_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import (
1111
func TestSessionConfig(t *testing.T) {
1212
require := require.New(t)
1313

14-
sess := NewSession("foo", "bar", 1)
14+
sess := NewSession("foo", "baz", "bar", 1)
1515
typ, v := sess.Get("foo")
1616
require.Equal(Null, typ)
1717
require.Equal(nil, v)
@@ -37,7 +37,7 @@ func TestSessionConfig(t *testing.T) {
3737

3838
func TestHasDefaultValue(t *testing.T) {
3939
require := require.New(t)
40-
sess := NewSession("foo", "bar", 1)
40+
sess := NewSession("foo", "baz", "bar", 1)
4141

4242
for key := range DefaultSessionConfig() {
4343
require.True(HasDefaultValue(sess, key))

0 commit comments

Comments
 (0)