Skip to content
This repository was archived by the owner on Jan 28, 2021. It is now read-only.

Commit 62d616e

Browse files
committed
auth: add tests for audit trails
Signed-off-by: Javi Fontan <jfontan@gmail.com>
1 parent 5af73e8 commit 62d616e

File tree

4 files changed

+309
-30
lines changed

4 files changed

+309
-30
lines changed

auth/audit_test.go

Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
package auth_test
2+
3+
import (
4+
"context"
5+
"testing"
6+
"time"
7+
8+
"gopkg.in/src-d/go-mysql-server.v0/auth"
9+
"gopkg.in/src-d/go-mysql-server.v0/sql"
10+
11+
"github.com/sanity-io/litter"
12+
"github.com/sirupsen/logrus"
13+
"github.com/sirupsen/logrus/hooks/test"
14+
"github.com/stretchr/testify/require"
15+
)
16+
17+
type Authentication struct {
18+
user string
19+
address string
20+
err error
21+
}
22+
23+
type Authorization struct {
24+
ctx *sql.Context
25+
p auth.Permission
26+
err error
27+
}
28+
29+
type Query struct {
30+
ctx *sql.Context
31+
d time.Duration
32+
err error
33+
}
34+
35+
type auditTest struct {
36+
authentication Authentication
37+
authorization Authorization
38+
query Query
39+
}
40+
41+
func (a *auditTest) Authentication(user string, address string, err error) {
42+
a.authentication = Authentication{
43+
user: user,
44+
address: address,
45+
err: err,
46+
}
47+
}
48+
49+
func (a *auditTest) Authorization(ctx *sql.Context, p auth.Permission, err error) {
50+
a.authorization = Authorization{
51+
ctx: ctx,
52+
p: p,
53+
err: err,
54+
}
55+
}
56+
57+
func (a *auditTest) Query(ctx *sql.Context, d time.Duration, err error) {
58+
println("query!")
59+
a.query = Query{
60+
ctx: ctx,
61+
d: d,
62+
err: err,
63+
}
64+
}
65+
66+
func (a *auditTest) Clean() {
67+
a.authorization = Authorization{}
68+
a.authentication = Authentication{}
69+
a.query = Query{}
70+
}
71+
72+
func TestAuditAuthentication(t *testing.T) {
73+
a := auth.NewNativeSingle("user", "password", auth.AllPermissions)
74+
at := new(auditTest)
75+
audit := auth.NewAudit(a, at)
76+
77+
extra := func(t *testing.T, c authenticationTest) {
78+
a := at.authentication
79+
80+
require.Equal(t, c.user, a.user)
81+
require.NotEmpty(t, a.address)
82+
if c.success {
83+
require.NoError(t, a.err)
84+
} else {
85+
require.Error(t, a.err)
86+
require.Nil(t, at.authorization.ctx)
87+
require.Nil(t, at.query.ctx)
88+
}
89+
90+
at.Clean()
91+
}
92+
93+
testAuthentication(t, audit, nativeSingleTests, extra)
94+
}
95+
96+
func TestAuditAuthorization(t *testing.T) {
97+
a := auth.NewNativeSingle("user", "", auth.ReadPerm)
98+
at := new(auditTest)
99+
audit := auth.NewAudit(a, at)
100+
101+
tests := []authorizationTest{
102+
{"user", "invalid query", false},
103+
{"user", queries["select"], true},
104+
105+
{"user", queries["create_index"], false},
106+
{"user", queries["drop_index"], false},
107+
{"user", queries["insert"], false},
108+
{"user", queries["lock"], false},
109+
{"user", queries["unlock"], false},
110+
}
111+
112+
extra := func(t *testing.T, c authorizationTest) {
113+
a := at.authorization
114+
q := at.query
115+
116+
litter.Dump(q)
117+
require.NotNil(t, q.ctx)
118+
require.Equal(t, c.user, q.ctx.Client().User)
119+
require.NotEmpty(t, q.ctx.Client().Address)
120+
require.NotZero(t, q.d)
121+
require.Equal(t, c.user, at.authentication.user)
122+
123+
if c.success {
124+
require.Equal(t, c.user, a.ctx.Client().User)
125+
require.NotEmpty(t, a.ctx.Client().Address)
126+
require.NoError(t, a.err)
127+
require.NoError(t, q.err)
128+
} else {
129+
require.Error(t, q.err)
130+
131+
// if there's a syntax error authorization is not triggered
132+
if auth.ErrNotAuthorized.Is(q.err) {
133+
require.Equal(t, q.err, a.err)
134+
require.NotNil(t, a.ctx)
135+
require.Equal(t, c.user, a.ctx.Client().User)
136+
require.NotEmpty(t, a.ctx.Client().Address)
137+
} else {
138+
require.NoError(t, a.err)
139+
require.Nil(t, a.ctx)
140+
}
141+
}
142+
143+
at.Clean()
144+
}
145+
146+
testAudit(t, audit, tests, extra)
147+
}
148+
149+
func TestAuditLog(t *testing.T) {
150+
require := require.New(t)
151+
152+
logger, hook := test.NewNullLogger()
153+
l := auth.NewAuditLog(logger)
154+
155+
pid := uint64(303)
156+
id := uint32(42)
157+
158+
l.Authentication("user", "client", nil)
159+
e := hook.LastEntry()
160+
require.NotNil(e)
161+
require.Equal(logrus.InfoLevel, e.Level)
162+
m := logrus.Fields{
163+
"system": "audit",
164+
"action": "authentication",
165+
"user": "user",
166+
"address": "client",
167+
"success": true,
168+
}
169+
require.Equal(m, e.Data)
170+
171+
err := auth.ErrNoPermission.New(auth.ReadPerm)
172+
l.Authentication("user", "client", err)
173+
e = hook.LastEntry()
174+
m["success"] = false
175+
m["err"] = err
176+
require.Equal(m, e.Data)
177+
178+
s := sql.NewSession("server", "client", "user", id)
179+
ctx := sql.NewContext(context.TODO(),
180+
sql.WithSession(s),
181+
sql.WithPid(pid),
182+
sql.WithQuery("query"),
183+
)
184+
185+
l.Authorization(ctx, auth.ReadPerm, nil)
186+
e = hook.LastEntry()
187+
require.NotNil(e)
188+
require.Equal(logrus.InfoLevel, e.Level)
189+
m = logrus.Fields{
190+
"system": "audit",
191+
"action": "authorization",
192+
"permission": auth.ReadPerm.String(),
193+
"user": "user",
194+
"query": "query",
195+
"address": "client",
196+
"connection_id": id,
197+
"pid": pid,
198+
"success": true,
199+
}
200+
require.Equal(m, e.Data)
201+
202+
l.Authorization(ctx, auth.ReadPerm, err)
203+
e = hook.LastEntry()
204+
m["success"] = false
205+
m["err"] = err
206+
require.Equal(m, e.Data)
207+
208+
l.Query(ctx, 808*time.Second, nil)
209+
e = hook.LastEntry()
210+
require.NotNil(e)
211+
require.Equal(logrus.InfoLevel, e.Level)
212+
m = logrus.Fields{
213+
"system": "audit",
214+
"action": "query",
215+
"duration": 808 * time.Second,
216+
"user": "user",
217+
"query": "query",
218+
"address": "client",
219+
"connection_id": id,
220+
"pid": pid,
221+
"success": true,
222+
}
223+
require.Equal(m, e.Data)
224+
225+
l.Query(ctx, 808*time.Second, err)
226+
e = hook.LastEntry()
227+
m["success"] = false
228+
m["err"] = err
229+
require.Equal(m, e.Data)
230+
}

auth/common_test.go

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ func connString(user, password string) string {
8080
return fmt.Sprintf("%s:%s@tcp(127.0.0.1:%d)/test", user, password, port)
8181
}
8282

83-
type authenticationTests []struct {
83+
type authenticationTest struct {
8484
user string
8585
password string
8686
success bool
@@ -89,7 +89,8 @@ type authenticationTests []struct {
8989
func testAuthentication(
9090
t *testing.T,
9191
a auth.Auth,
92-
tests authenticationTests,
92+
tests []authenticationTest,
93+
extra func(t *testing.T, c authenticationTest),
9394
) {
9495
t.Helper()
9596
req := require.New(t)
@@ -115,6 +116,10 @@ func testAuthentication(
115116

116117
err = db.Close()
117118
req.NoError(err)
119+
120+
if extra != nil {
121+
extra(t, c)
122+
}
118123
})
119124
}
120125

@@ -131,7 +136,7 @@ var queries = map[string]string{
131136
"unlock": "unlock tables",
132137
}
133138

134-
type authorizationTests []struct {
139+
type authorizationTest struct {
135140
user string
136141
query string
137142
success bool
@@ -140,7 +145,8 @@ type authorizationTests []struct {
140145
func testAuthorization(
141146
t *testing.T,
142147
a auth.Auth,
143-
tests authorizationTests,
148+
tests []authorizationTest,
149+
extra func(t *testing.T, c authorizationTest),
144150
) {
145151
t.Helper()
146152
req := require.New(t)
@@ -166,7 +172,51 @@ func testAuthorization(
166172
}
167173

168174
req.Error(err)
169-
req.True(auth.ErrNotAuthorized.Is(err))
175+
if extra != nil {
176+
extra(t, c)
177+
} else {
178+
req.True(auth.ErrNotAuthorized.Is(err))
179+
}
180+
})
181+
}
182+
}
183+
184+
func testAudit(
185+
t *testing.T,
186+
a auth.Auth,
187+
tests []authorizationTest,
188+
extra func(t *testing.T, c authorizationTest),
189+
) {
190+
t.Helper()
191+
req := require.New(t)
192+
193+
tmpDir, s, err := authServer(a)
194+
req.NoError(err)
195+
defer os.RemoveAll(tmpDir)
196+
197+
for _, c := range tests {
198+
t.Run(fmt.Sprintf("%s", c.user), func(t *testing.T) {
199+
req := require.New(t)
200+
201+
db, err := dsql.Open("mysql", connString(c.user, ""))
202+
req.NoError(err)
203+
_, err = db.Query(c.query)
204+
205+
if c.success {
206+
req.NoError(err)
207+
} else {
208+
req.Error(err)
209+
}
210+
211+
err = db.Close()
212+
req.NoError(err)
213+
214+
if extra != nil {
215+
extra(t, c)
216+
}
170217
})
171218
}
219+
220+
err = s.Close()
221+
req.NoError(err)
172222
}

0 commit comments

Comments
 (0)