Skip to content

Commit 803d284

Browse files
committed
feat: lex-sort query args and arg transformer hook
This commit introduces a new feature that enables the specification of specific query parameters to be used as the cache index key when creating a URI via a arg transformer hook. For instance, it allows caching requests like GET /foo?mode=bar. The assumption is that users are aware of the query parameters that can be sent and that these parameters are utilized by the handler to modify the request. Consequently, even if the user includes additional irrelevant parameters like GET /foo?mode=bar&junk=xyz, the cache will still be utilized since the specified parameters are the key. Also, we now lexicographically sort the query params so that cache cannot be busted by reordering the params.
1 parent 1d2547a commit 803d284

File tree

3 files changed

+213
-29
lines changed

3 files changed

+213
-29
lines changed

fastcache.go

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ type Options struct {
4747

4848
// Cache based on uri+querystring.
4949
IncludeQueryString bool
50+
51+
QueryArgsTransformerHook func(*fasthttp.Args)
5052
}
5153

5254
// Item represents the cache entry for a single endpoint with the actual cache
@@ -94,14 +96,8 @@ func (f *FastCache) Cached(h fastglue.FastRequestHandler, o *Options, group stri
9496
}
9597
return h(r)
9698
}
97-
var hash [16]byte
98-
// If IncludeQueryString option is set then cache based on uri + md5(query_string)
99-
if o.IncludeQueryString {
100-
hash = md5.Sum(r.RequestCtx.URI().FullURI())
101-
} else {
102-
hash = md5.Sum(r.RequestCtx.URI().Path())
103-
}
104-
uri := hex.EncodeToString(hash[:])
99+
100+
uri := f.makeURI(r, o)
105101

106102
// Fetch etag + cached bytes from the store.
107103
blob, err := f.s.Get(namespace, group, uri)
@@ -193,6 +189,44 @@ func (f *FastCache) DelGroup(namespace string, group ...string) error {
193189
return f.s.DelGroup(namespace, group...)
194190
}
195191

192+
func (f *FastCache) makeURI(r *fastglue.Request, o *Options) string {
193+
var hash [16]byte
194+
195+
// lexicographically sort the query string.
196+
r.RequestCtx.QueryArgs().Sort(func(x, y []byte) int {
197+
return bytes.Compare(x, y)
198+
})
199+
200+
// If IncludeQueryString option is set then cache based on uri + md5(query_string)
201+
if o.IncludeQueryString {
202+
id := r.RequestCtx.URI().FullURI()
203+
204+
// Check if we need to include only specific query params.
205+
if o.QueryArgsTransformerHook != nil {
206+
// Acquire a copy so as to not modify the request.
207+
uriRaw := fasthttp.AcquireURI()
208+
r.RequestCtx.URI().CopyTo(uriRaw)
209+
210+
q := uriRaw.QueryArgs()
211+
212+
// Call the hook to transform the query string.
213+
o.QueryArgsTransformerHook(q)
214+
215+
// Get the new URI.
216+
id = uriRaw.FullURI()
217+
218+
// Release the borrowed URI.
219+
fasthttp.ReleaseURI(uriRaw)
220+
}
221+
222+
hash = md5.Sum(id)
223+
} else {
224+
hash = md5.Sum(r.RequestCtx.URI().Path())
225+
}
226+
227+
return hex.EncodeToString(hash[:])
228+
}
229+
196230
// cache caches a response body.
197231
func (f *FastCache) cache(r *fastglue.Request, namespace, group string, o *Options) error {
198232
// ETag?.
@@ -206,14 +240,7 @@ func (f *FastCache) cache(r *fastglue.Request, namespace, group string, o *Optio
206240
}
207241

208242
// Write cache to the store (etag, content type, response body).
209-
var hash [16]byte
210-
// If IncludeQueryString option is set then cache based on uri + md5(query_string)
211-
if o.IncludeQueryString {
212-
hash = md5.Sum(r.RequestCtx.URI().FullURI())
213-
} else {
214-
hash = md5.Sum(r.RequestCtx.URI().Path())
215-
}
216-
uri := hex.EncodeToString(hash[:])
243+
uri := f.makeURI(r, o)
217244

218245
var blob []byte
219246
if !o.NoBlob {

fastcache_test.go

Lines changed: 159 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
package fastcache_test
22

33
import (
4-
"io/ioutil"
4+
"fmt"
5+
"io"
56
"log"
67
"net/http"
78
"os"
@@ -50,6 +51,49 @@ func init() {
5051
Logger: log.New(os.Stdout, "", log.Ldate|log.Ltime|log.Lshortfile),
5152
}
5253

54+
includeQS = &fastcache.Options{
55+
NamespaceKey: namespaceKey,
56+
ETag: true,
57+
TTL: time.Second * 5,
58+
Logger: log.New(os.Stdout, "", log.Ldate|log.Ltime|log.Lshortfile),
59+
IncludeQueryString: true,
60+
}
61+
62+
includeQSNoEtag = &fastcache.Options{
63+
NamespaceKey: namespaceKey,
64+
ETag: false,
65+
TTL: time.Second * 5,
66+
Logger: log.New(os.Stdout, "", log.Ldate|log.Ltime|log.Lshortfile),
67+
IncludeQueryString: true,
68+
}
69+
70+
includeQSSpecific = &fastcache.Options{
71+
NamespaceKey: namespaceKey,
72+
ETag: true,
73+
TTL: time.Second * 5,
74+
Logger: log.New(os.Stdout, "", log.Ldate|log.Ltime|log.Lshortfile),
75+
IncludeQueryString: true,
76+
QueryArgsTransformerHook: func(args *fasthttp.Args) {
77+
// Copy the keys to delete, and delete them later. This is to
78+
// avoid borking the VisitAll() iterator.
79+
mp := map[string]struct{}{
80+
"foo": {},
81+
}
82+
83+
delKeys := [][]byte{}
84+
args.VisitAll(func(k, v []byte) {
85+
if _, ok := mp[string(k)]; !ok {
86+
delKeys = append(delKeys, k)
87+
}
88+
})
89+
90+
// Delete the keys.
91+
for _, k := range delKeys {
92+
args.DelBytes(k)
93+
}
94+
},
95+
}
96+
5397
fc = fastcache.New(cachestore.New("CACHE:", redis.NewClient(&redis.Options{
5498
Addr: rd.Addr(),
5599
})))
@@ -78,6 +122,19 @@ func init() {
78122
return r.SendBytes(200, "text/plain", []byte("ok"))
79123
}, ttlShort, group))
80124

125+
srv.GET("/include-qs", fc.Cached(func(r *fastglue.Request) error {
126+
return r.SendBytes(200, "text/plain", []byte("ok"))
127+
}, includeQS, group))
128+
129+
srv.GET("/include-qs-no-etag", fc.Cached(func(r *fastglue.Request) error {
130+
out := time.Now()
131+
return r.SendBytes(200, "text/plain", []byte(fmt.Sprintf("%v", out)))
132+
}, includeQSNoEtag, group))
133+
134+
srv.GET("/include-qs-specific", fc.Cached(func(r *fastglue.Request) error {
135+
return r.SendBytes(200, "text/plain", []byte("ok"))
136+
}, includeQSSpecific, group))
137+
81138
// Start the server
82139
go func() {
83140
s := &fasthttp.Server{
@@ -111,7 +168,7 @@ func getReq(url, etag string, t *testing.T) (*http.Response, string) {
111168
t.Fatal(err)
112169
}
113170

114-
b, err := ioutil.ReadAll(resp.Body)
171+
b, err := io.ReadAll(resp.Body)
115172
if err != nil {
116173
t.Fatal(b)
117174
}
@@ -139,22 +196,119 @@ func TestCache(t *testing.T) {
139196
}
140197

141198
// Wrong etag.
142-
r, b = getReq(srvRoot+"/cached", "wrong", t)
199+
r, _ = getReq(srvRoot+"/cached", "wrong", t)
143200
if r.StatusCode != 200 {
144201
t.Fatalf("expected 200 but got '%v'", r.StatusCode)
145202
}
146203

147204
// Clear cache.
148-
r, b = getReq(srvRoot+"/clear-group", "", t)
205+
r, _ = getReq(srvRoot+"/clear-group", "", t)
149206
if r.StatusCode != 200 {
150207
t.Fatalf("expected 200 but got %v", r.StatusCode)
151208
}
152-
r, b = getReq(srvRoot+"/cached", r.Header.Get("Etag"), t)
209+
r, _ = getReq(srvRoot+"/cached", r.Header.Get("Etag"), t)
153210
if r.StatusCode != 200 {
154211
t.Fatalf("expected 200 but got '%v'", r.StatusCode)
155212
}
156213
}
157214

215+
func TestQueryString(t *testing.T) {
216+
// First request should be 200.
217+
r, b := getReq(srvRoot+"/include-qs?foo=bar", "", t)
218+
if r.StatusCode != 200 {
219+
t.Fatalf("expected 200 but got %v", r.StatusCode)
220+
}
221+
222+
if b != "ok" {
223+
t.Fatalf("expected 'ok' in body but got %v", b)
224+
}
225+
226+
// Second should be 304.
227+
r, _ = getReq(srvRoot+"/include-qs?foo=bar", r.Header.Get("Etag"), t)
228+
if r.StatusCode != 304 {
229+
t.Fatalf("expected 304 but got '%v'", r.StatusCode)
230+
}
231+
}
232+
233+
func TestQueryStringLexicographical(t *testing.T) {
234+
// First request should be 200.
235+
r, b := getReq(srvRoot+"/include-qs?foo=bar&baz=qux", "", t)
236+
if r.StatusCode != 200 {
237+
t.Fatalf("expected 200 but got %v", r.StatusCode)
238+
}
239+
240+
if b != "ok" {
241+
t.Fatalf("expected 'ok' in body but got %v", b)
242+
}
243+
244+
// Second should be 304.
245+
r, _ = getReq(srvRoot+"/include-qs?baz=qux&foo=bar", r.Header.Get("Etag"), t)
246+
if r.StatusCode != 304 {
247+
t.Fatalf("expected 304 but got '%v'", r.StatusCode)
248+
}
249+
}
250+
251+
func TestQueryStringWithoutEtag(t *testing.T) {
252+
// First request should be 200.
253+
r, b := getReq(srvRoot+"/include-qs-no-etag?foo=bar", "", t)
254+
if r.StatusCode != 200 {
255+
t.Fatalf("expected 200 but got %v", r.StatusCode)
256+
}
257+
258+
// Second should be 200 but with same response.
259+
r2, b2 := getReq(srvRoot+"/include-qs-no-etag?foo=bar", "", t)
260+
if r2.StatusCode != 200 {
261+
t.Fatalf("expected 200 but got '%v'", r2.StatusCode)
262+
}
263+
264+
if b2 != b {
265+
t.Fatalf("expected '%v' in body but got %v", b, b2)
266+
}
267+
268+
// Third should be 200 but with different response.
269+
r3, b3 := getReq(srvRoot+"/include-qs-no-etag?foo=baz", "", t)
270+
if r3.StatusCode != 200 {
271+
t.Fatalf("expected 200 but got '%v'", r3.StatusCode)
272+
}
273+
274+
// time should be different
275+
if b3 == b {
276+
t.Fatalf("expected both to be different (should not be %v), but got %v", b, b3)
277+
}
278+
}
279+
280+
func TestQueryStringSpecific(t *testing.T) {
281+
// First request should be 200.
282+
r1, b := getReq(srvRoot+"/include-qs-specific?foo=bar&baz=qux", "", t)
283+
if r1.StatusCode != 200 {
284+
t.Fatalf("expected 200 but got %v", r1.StatusCode)
285+
}
286+
if b != "ok" {
287+
t.Fatalf("expected 'ok' in body but got %v", b)
288+
}
289+
290+
// Second should be 304.
291+
r, _ := getReq(srvRoot+"/include-qs-specific?foo=bar&baz=qux", r1.Header.Get("Etag"), t)
292+
if r.StatusCode != 304 {
293+
t.Fatalf("expected 304 but got '%v'", r.StatusCode)
294+
}
295+
296+
// Third should be 304 as foo=bar
297+
r, _ = getReq(srvRoot+"/include-qs-specific?loo=mar&foo=bar&baz=qux&quux=quuz", r1.Header.Get("Etag"), t)
298+
if r.StatusCode != 304 {
299+
t.Fatalf("expected 304 but got '%v'", r.StatusCode)
300+
}
301+
302+
// Fourth should be 200 as foo=rab
303+
r, b = getReq(srvRoot+"/include-qs-specific?foo=rab&baz=qux&quux=quuz", r1.Header.Get("Etag"), t)
304+
if r.StatusCode != 200 {
305+
t.Fatalf("expected 200 but got '%v'", r.StatusCode)
306+
}
307+
if b != "ok" {
308+
t.Fatalf("expected 'ok' in body but got %v", b)
309+
}
310+
}
311+
158312
func TestNoCache(t *testing.T) {
159313
// All requests should return 200.
160314
for n := 0; n < 3; n++ {

stores/goredis/redis.go

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,16 @@
22
// The internal structure looks like this where
33
// XX1234 = namespace, marketwach = group
44
// ```
5-
// CACHE:XX1234:marketwatch {
6-
// "/user/marketwatch_ctype" -> []byte
7-
// "/user/marketwatch_etag" -> []byte
8-
// "/user/marketwatch_blob" -> []byte
9-
// "/user/marketwatch/123_ctype" -> []byte
10-
// "/user/marketwatch/123_etag" -> []byte
11-
// "/user/marketwatch/123_blob" -> []byte
12-
// }
5+
//
6+
// CACHE:XX1234:marketwatch {
7+
// "/user/marketwatch_ctype" -> []byte
8+
// "/user/marketwatch_etag" -> []byte
9+
// "/user/marketwatch_blob" -> []byte
10+
// "/user/marketwatch/123_ctype" -> []byte
11+
// "/user/marketwatch/123_etag" -> []byte
12+
// "/user/marketwatch/123_blob" -> []byte
13+
// }
14+
//
1315
// ```
1416
package goredis
1517

@@ -53,6 +55,7 @@ func (s *Store) Get(namespace, group, uri string) (fastcache.Item, error) {
5355
var (
5456
out fastcache.Item
5557
)
58+
5659
// Get content_type, etag, blob in that order.
5760
cmd := s.cn.HMGet(s.ctx, s.key(namespace, group), s.field(keyCtype, uri), s.field(keyEtag, uri), s.field(keyBlob, uri))
5861
if err := cmd.Err(); err != nil {

0 commit comments

Comments
 (0)