diff --git a/.github/workflows/grid-proxy-integration.yml b/.github/workflows/grid-proxy-integration.yml index aaed5825d..ce7623070 100644 --- a/.github/workflows/grid-proxy-integration.yml +++ b/.github/workflows/grid-proxy-integration.yml @@ -52,7 +52,7 @@ jobs: pushd tools/db go run . --seed 13 --postgres-host localhost --postgres-db tfgrid-graphql --postgres-password postgres --postgres-user postgres --reset popd - go run cmds/proxy_server/main.go -no-cert -no-indexer --address :8080 --log-level debug --postgres-host localhost --postgres-db tfgrid-graphql --postgres-password postgres --postgres-user postgres --mnemonics "$MNEMONICS" & + go run cmds/proxy_server/main.go --rate-limit-rps 1000 -no-cert -no-indexer --address :8080 --log-level debug --postgres-host localhost --postgres-db tfgrid-graphql --postgres-password postgres --postgres-user postgres --mnemonics "$MNEMONICS" & sleep 10 pushd tests/queries go test -v --seed 13 -no-modify --postgres-host localhost --postgres-db tfgrid-graphql --postgres-password postgres --postgres-user postgres --endpoint http://localhost:8080 diff --git a/grid-proxy/Makefile b/grid-proxy/Makefile index f3da93429..fe2115823 100644 --- a/grid-proxy/Makefile +++ b/grid-proxy/Makefile @@ -58,6 +58,7 @@ server-start: ## Start the proxy server (Args: `m=`) --postgres-db tfgrid-graphql \ --postgres-password postgres \ --postgres-user postgres \ + --rate-limit-rps 1000 \ --mnemonics "$(m)" ; all-start: db-stop db-start sleep db-fill server-start ## Full start of the database and the server (Args: `m=`) diff --git a/grid-proxy/cmds/proxy_server/main.go b/grid-proxy/cmds/proxy_server/main.go index 7518c518f..eddf1dd0f 100644 --- a/grid-proxy/cmds/proxy_server/main.go +++ b/grid-proxy/cmds/proxy_server/main.go @@ -55,6 +55,7 @@ type flags struct { relayURL string mnemonics string maxPoolOpenConnections int + rateLimitRPS int // Rate limit requests per second per IP noIndexer bool // true to stop the indexer, useful on running for testing indexerUpserterBatchSize uint @@ -98,6 +99,7 @@ func main() { flag.StringVar(&f.relayURL, "relay-url", DefaultRelayURL, "RMB relay url") flag.StringVar(&f.mnemonics, "mnemonics", "", "Dummy user mnemonics for relay calls") flag.IntVar(&f.maxPoolOpenConnections, "max-open-conns", 80, "max number of db connection pool open connections") + flag.IntVar(&f.rateLimitRPS, "rate-limit-rps", 20, "rate limit requests per second per IP address (0 to disable)") flag.BoolVar(&f.noIndexer, "no-indexer", false, "do not start the indexer") flag.UintVar(&f.indexerUpserterBatchSize, "indexer-upserter-batch-size", 20, "results batch size which collected before upserting") @@ -182,6 +184,13 @@ func main() { log.Fatal().Err(err).Msg("failed to create mux server") } + // Log rate limiting configuration + if f.rateLimitRPS > 0 { + log.Info().Int("rate_limit_rps", f.rateLimitRPS).Msg("HTTP rate limiting enabled") + } else { + log.Info().Msg("HTTP rate limiting disabled") + } + if err := app(s, f); err != nil { log.Fatal().Msg(err.Error()) } @@ -331,7 +340,7 @@ func createServer(f flags, dbClient explorer.DBClient, gitCommit string, relayCl router := mux.NewRouter().StrictSlash(true) // setup explorer - if err := explorer.Setup(router, gitCommit, dbClient, relayClient, idxIntervals); err != nil { + if err := explorer.Setup(router, gitCommit, dbClient, relayClient, idxIntervals, f.rateLimitRPS); err != nil { return nil, err } diff --git a/grid-proxy/internal/explorer/mw/ratelimiter.go b/grid-proxy/internal/explorer/mw/ratelimiter.go new file mode 100644 index 000000000..4f22cc17f --- /dev/null +++ b/grid-proxy/internal/explorer/mw/ratelimiter.go @@ -0,0 +1,92 @@ +package mw + +import ( + "fmt" + "net/http" + "strconv" + "time" + + "github.com/rs/zerolog/log" + "github.com/threefoldtech/tfgrid-sdk-go/grid-proxy/tools/ratelimiter" +) + +// RateLimiterMiddleware wraps the rate limiter to work with the existing middleware pattern +type RateLimiterMiddleware struct { + limiter *ratelimiter.SlidingWindowRateLimiter +} + +// NewRateLimiterMiddleware creates a new rate limiter middleware +func NewRateLimiterMiddleware(ratePerSecond int) *RateLimiterMiddleware { + return &RateLimiterMiddleware{ + limiter: ratelimiter.NewSlidingWindowRateLimiter(ratePerSecond), + } +} + +// RateLimitAction wraps an Action with rate limiting +func (rlm *RateLimiterMiddleware) RateLimitAction(action Action) Action { + return func(r *http.Request) (interface{}, Response) { + clientIP := ratelimiter.GetClientIP(r) + + if !rlm.limiter.Allow(clientIP) { + log.Warn(). + Str("ip", clientIP). + Str("method", r.Method). + Str("path", r.URL.Path). + Msg("Rate limit exceeded") + + return nil, rlm.TooManyRequests(fmt.Errorf("rate limit exceeded for IP: %s", clientIP), clientIP) + } + + return action(r) + } +} + +// RateLimitProxyAction wraps a ProxyAction with rate limiting +func (rlm *RateLimiterMiddleware) RateLimitProxyAction(action ProxyAction) ProxyAction { + return func(r *http.Request) (*http.Response, Response) { + clientIP := ratelimiter.GetClientIP(r) + if !rlm.limiter.Allow(clientIP) { + log.Warn(). + Str("ip", clientIP). + Str("method", r.Method). + Str("path", r.URL.Path). + Msg("Rate limit exceeded") + + return nil, rlm.TooManyRequests(fmt.Errorf("rate limit exceeded for IP: %s", clientIP), clientIP) + } + + return action(r) + } +} + +// AsRateLimitedHandlerFunc wraps AsHandlerFunc with rate limiting +func (rlm *RateLimiterMiddleware) AsRateLimitedHandlerFunc(action Action) http.HandlerFunc { + rateLimitedAction := rlm.RateLimitAction(action) + return AsHandlerFunc(rateLimitedAction) +} + +// AsRateLimitedProxyHandlerFunc wraps AsProxyHandlerFunc with rate limiting +func (rlm *RateLimiterMiddleware) AsRateLimitedProxyHandlerFunc(action ProxyAction) http.HandlerFunc { + rateLimitedAction := rlm.RateLimitProxyAction(action) + return AsProxyHandlerFunc(rateLimitedAction) +} + +// GetStats returns rate limiter statistics +func (rlm *RateLimiterMiddleware) GetStats() map[string]interface{} { + return rlm.limiter.GetStats() +} + +// TooManyRequests returns a 429 Too Many Requests response with accurate rate limit headers +func (rlm *RateLimiterMiddleware) TooManyRequests(err error, clientIP string) Response { + rateLimit := rlm.limiter.GetRateLimit() + currentRequests := rlm.limiter.GetCurrentRequestCount(clientIP) + remaining := max(0, rateLimit-currentRequests) + resetTime := time.Now().Add(time.Second) + + return Error(err, http.StatusTooManyRequests). + WithHeader("Retry-After", "1"). + WithHeader("X-RateLimit-Limit", strconv.Itoa(rateLimit)). + WithHeader("X-RateLimit-Remaining", strconv.Itoa(remaining)). + WithHeader("X-RateLimit-Reset", strconv.FormatInt(resetTime.Unix(), 10)). + WithHeader("X-Client-IP", clientIP) +} diff --git a/grid-proxy/internal/explorer/mw/utils.go b/grid-proxy/internal/explorer/mw/utils.go new file mode 100644 index 000000000..4ff45f24e --- /dev/null +++ b/grid-proxy/internal/explorer/mw/utils.go @@ -0,0 +1,23 @@ +package mw + +import "net/http" + +// WithRateLimit wraps an Action with rate limiting if a rate limiter is provided. +// If rateLimiter is nil, it falls back to the standard AsHandlerFunc wrapper. +// This provides a clean way to conditionally apply rate limiting to endpoints. +func WithRateLimit(rateLimiter *RateLimiterMiddleware, action Action) http.HandlerFunc { + if rateLimiter != nil { + return rateLimiter.AsRateLimitedHandlerFunc(action) + } + return AsHandlerFunc(action) +} + +// WithRateLimitProxy wraps a ProxyAction with rate limiting if a rate limiter is provided. +// If rateLimiter is nil, it falls back to the standard AsProxyHandlerFunc wrapper. +// This provides a clean way to conditionally apply rate limiting to proxy endpoints. +func WithRateLimitProxy(rateLimiter *RateLimiterMiddleware, action ProxyAction) http.HandlerFunc { + if rateLimiter != nil { + return rateLimiter.AsRateLimitedProxyHandlerFunc(action) + } + return AsProxyHandlerFunc(action) +} diff --git a/grid-proxy/internal/explorer/server.go b/grid-proxy/internal/explorer/server.go index c7f910acb..5dd37e784 100644 --- a/grid-proxy/internal/explorer/server.go +++ b/grid-proxy/internal/explorer/server.go @@ -629,7 +629,7 @@ func (a *App) getContractBills(r *http.Request) (interface{}, mw.Response) { // @license.name Apache 2.0 // @license.url http://www.apache.org/licenses/LICENSE-2.0.html // @BasePath / -func Setup(router *mux.Router, gitCommit string, cl DBClient, relayClient rmb.Client, idxIntervals map[string]uint) error { +func Setup(router *mux.Router, gitCommit string, cl DBClient, relayClient rmb.Client, idxIntervals map[string]uint, rateLimitRPS int) error { a := App{ cl: cl, @@ -638,32 +638,40 @@ func Setup(router *mux.Router, gitCommit string, cl DBClient, relayClient rmb.Cl idxIntervals: idxIntervals, } - router.HandleFunc("/farms", mw.AsHandlerFunc(a.listFarms)) - router.HandleFunc("/stats", mw.AsHandlerFunc(a.getStats)) + var rateLimiter *mw.RateLimiterMiddleware + if rateLimitRPS > 0 { + rateLimiter = mw.NewRateLimiterMiddleware(rateLimitRPS) + log.Info().Int("rate_limit_rps", rateLimitRPS).Msg("Rate limiting enabled") + } else { + log.Info().Msg("Rate limiting disabled") + } + + router.HandleFunc("/farms", mw.WithRateLimit(rateLimiter, a.listFarms)) + router.HandleFunc("/stats", mw.WithRateLimit(rateLimiter, a.getStats)) - router.HandleFunc("/twins", mw.AsHandlerFunc(a.listTwins)) - router.HandleFunc("/twins/{twin_id:[0-9]+}/consumption", mw.AsHandlerFunc(a.getTwinConsumption)) + router.HandleFunc("/twins", mw.WithRateLimit(rateLimiter, a.listTwins)) + router.HandleFunc("/twins/{twin_id:[0-9]+}/consumption", mw.WithRateLimit(rateLimiter, a.getTwinConsumption)) - router.HandleFunc("/nodes", mw.AsHandlerFunc(a.getNodes)) - router.HandleFunc("/nodes/{node_id:[0-9]+}", mw.AsHandlerFunc(a.getNode)) - router.HandleFunc("/nodes/{node_id:[0-9]+}/status", mw.AsHandlerFunc(a.getNodeStatus)) - router.HandleFunc("/nodes/{node_id:[0-9]+}/statistics", mw.AsHandlerFunc(a.getNodeStatistics)) - router.HandleFunc("/nodes/{node_id:[0-9]+}/gpu", mw.AsHandlerFunc(a.getNodeGpus)) + router.HandleFunc("/nodes", mw.WithRateLimit(rateLimiter, a.getNodes)) + router.HandleFunc("/nodes/{node_id:[0-9]+}", mw.WithRateLimit(rateLimiter, a.getNode)) + router.HandleFunc("/nodes/{node_id:[0-9]+}/status", mw.WithRateLimit(rateLimiter, a.getNodeStatus)) + router.HandleFunc("/nodes/{node_id:[0-9]+}/statistics", mw.WithRateLimit(rateLimiter, a.getNodeStatistics)) + router.HandleFunc("/nodes/{node_id:[0-9]+}/gpu", mw.WithRateLimit(rateLimiter, a.getNodeGpus)) - router.HandleFunc("/gateways", mw.AsHandlerFunc(a.getGateways)) - router.HandleFunc("/gateways/{node_id:[0-9]+}", mw.AsHandlerFunc(a.getGateway)) - router.HandleFunc("/gateways/{node_id:[0-9]+}/status", mw.AsHandlerFunc(a.getNodeStatus)) + router.HandleFunc("/gateways", mw.WithRateLimit(rateLimiter, a.getGateways)) + router.HandleFunc("/gateways/{node_id:[0-9]+}", mw.WithRateLimit(rateLimiter, a.getGateway)) + router.HandleFunc("/gateways/{node_id:[0-9]+}/status", mw.WithRateLimit(rateLimiter, a.getNodeStatus)) - router.HandleFunc("/contracts", mw.AsHandlerFunc(a.listContracts)) - router.HandleFunc("/contracts/{contract_id:[0-9]+}", mw.AsHandlerFunc(a.getContract)) - router.HandleFunc("/contracts/{contract_id:[0-9]+}/bills", mw.AsHandlerFunc(a.getContractBills)) + router.HandleFunc("/contracts", mw.WithRateLimit(rateLimiter, a.listContracts)) + router.HandleFunc("/contracts/{contract_id:[0-9]+}", mw.WithRateLimit(rateLimiter, a.getContract)) + router.HandleFunc("/contracts/{contract_id:[0-9]+}/bills", mw.WithRateLimit(rateLimiter, a.getContractBills)) - router.HandleFunc("/public_ips", mw.AsHandlerFunc(a.GetPublicIps)) + router.HandleFunc("/public_ips", mw.WithRateLimit(rateLimiter, a.GetPublicIps)) - router.HandleFunc("/", mw.AsHandlerFunc(a.indexPage(router))) - router.HandleFunc("/ping", mw.AsHandlerFunc(a.ping)) - router.HandleFunc("/version", mw.AsHandlerFunc(a.version)) - router.HandleFunc("/health", mw.AsHandlerFunc(a.health)) + router.HandleFunc("/", mw.WithRateLimit(rateLimiter, a.indexPage(router))) + router.HandleFunc("/ping", mw.WithRateLimit(rateLimiter, a.ping)) + router.HandleFunc("/version", mw.WithRateLimit(rateLimiter, a.version)) + router.HandleFunc("/health", mw.WithRateLimit(rateLimiter, a.health)) router.PathPrefix("/swagger/").Handler(httpSwagger.WrapHandler) return nil diff --git a/grid-proxy/tools/ratelimiter/README.md b/grid-proxy/tools/ratelimiter/README.md new file mode 100644 index 000000000..6744c8a2a --- /dev/null +++ b/grid-proxy/tools/ratelimiter/README.md @@ -0,0 +1,83 @@ +# Rate Limiter + +This package implements a sliding window rate limiter for the TFGrid Proxy server. + +## Features + +- **IP-based Rate Limiting**: Tracks requests per IP address +- **Sliding Window Algorithm**: Uses a sliding window approach for smooth rate limiting +- **Thread-Safe**: Safe for concurrent use across multiple goroutines +- **Memory Efficient**: Automatic cleanup of old entries +- **Configurable**: Rate limit can be set via command-line flag + +## Usage + +### Command Line Flag + +Use the `--rate-limit-rps` flag to set the rate limit: + +```bash +# Enable rate limiting at 20 requests per second per IP (default) +./proxy_server --rate-limit-rps 20 + +# Set custom rate limit of 100 requests per second per IP +./proxy_server --rate-limit-rps 100 + +# Disable rate limiting +./proxy_server --rate-limit-rps 0 +``` + +### IP Address Detection + +The rate limiter automatically extracts the client IP address using the following priority: + +1. `X-Real-IP` header +2. `X-Forwarded-For` header (first IP if multiple) +3. `RemoteAddr` from the connection + +This ensures proper rate limiting even when the proxy is behind load balancers or CDNs. + +### HTTP Response + +When rate limit is exceeded, the server returns: + +- **Status Code**: 429 (Too Many Requests) +- **Headers**: + - `Retry-After: 1` - Suggests retrying after 1 second + - `X-RateLimit-Limit: 20` - Current rate limit + - `X-RateLimit-Remaining: 0` - Remaining requests (0 when exceeded) + +### Algorithm Details + +The sliding window algorithm works as follows: + +1. **Time Window**: Uses a 1-second sliding window +2. **Request Tracking**: Stores timestamps of requests within the window +3. **Cleanup**: Automatically removes requests older than the window +4. **Memory Management**: Periodically cleans up inactive IP entries + +### Performance + +- **Memory Usage**: Minimal overhead, only stores active IP addresses +- **CPU Usage**: O(n) where n is the number of requests in the current window +- **Concurrency**: Thread-safe with read-write mutexes for optimal performance + +## Configuration + +| Flag | Default | Description | +|------|---------|-------------| +| `--rate-limit-rps` | 20 | Requests per second per IP address (0 to disable) | + +## Logging + +The rate limiter provides debug logging for: + +- Rate limit violations (WARN level) +- New IP tracking (DEBUG level) +- Request allowances (DEBUG level) +- Cleanup operations (DEBUG level) + +Example log output: +``` +{"level":"warn","ip":"192.168.1.100","method":"GET","path":"/nodes","time":"2025-07-07T14:40:43Z","message":"Rate limit exceeded"} +``` diff --git a/grid-proxy/tools/ratelimiter/ip_utils.go b/grid-proxy/tools/ratelimiter/ip_utils.go new file mode 100644 index 000000000..6fff1cec6 --- /dev/null +++ b/grid-proxy/tools/ratelimiter/ip_utils.go @@ -0,0 +1,26 @@ +package ratelimiter + +import ( + "net" + "net/http" + "strings" +) + +// GetClientIP extracts the real client IP from the HTTP request +// It checks X-Forwarded-For, X-Real-IP headers, and falls back to RemoteAddr +func GetClientIP(r *http.Request) string { + if realIP := r.Header.Get("X-Real-IP"); realIP != "" { + return realIP + } + + if fwd := r.Header.Get("X-Forwarded-For"); fwd != "" { + parts := strings.Split(fwd, ",") + return strings.TrimSpace(parts[0]) + } + + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return r.RemoteAddr + } + return host +} diff --git a/grid-proxy/tools/ratelimiter/sliding_window.go b/grid-proxy/tools/ratelimiter/sliding_window.go new file mode 100644 index 000000000..6b9bdc34b --- /dev/null +++ b/grid-proxy/tools/ratelimiter/sliding_window.go @@ -0,0 +1,215 @@ +package ratelimiter + +import ( + "fmt" + "sync" + "time" + + "github.com/rs/zerolog/log" +) + +// SlidingWindow represents a sliding window rate limiter for a single IP +type SlidingWindow struct { + requests []time.Time + lastCleanup time.Time + mu sync.RWMutex +} + +// SlidingWindowRateLimiter implements IP-based rate limiting using sliding window algorithm +type SlidingWindowRateLimiter struct { + rate int + windowSize time.Duration + clients map[string]*SlidingWindow + mu sync.RWMutex + cleanupInterval time.Duration + lastGlobalCleanup time.Time +} + +// NewSlidingWindowRateLimiter creates a new sliding window rate limiter +func NewSlidingWindowRateLimiter(ratePerSecond int) *SlidingWindowRateLimiter { + return &SlidingWindowRateLimiter{ + rate: ratePerSecond, + windowSize: time.Second, + clients: make(map[string]*SlidingWindow), + cleanupInterval: time.Minute * 5, + lastGlobalCleanup: time.Now(), + } +} + +// Allow checks if a request from the given IP should be allowed +func (rl *SlidingWindowRateLimiter) Allow(ip string) bool { + now := time.Now() + rl.performGlobalCleanupIfNeeded(now) + window := rl.getOrCreateWindow(ip) + + window.mu.Lock() + defer window.mu.Unlock() + + cutoff := now.Add(-rl.windowSize) + window.requests = rl.filterRequests(window.requests, cutoff) + window.lastCleanup = now + + if len(window.requests) >= rl.rate { + log.Debug(). + Str("ip", ip). + Int("current_requests", len(window.requests)). + Int("rate_limit", rl.rate). + Msg("Rate limit exceeded") + return false + } + + window.requests = append(window.requests, now) + log.Debug(). + Str("ip", ip). + Int("current_requests", len(window.requests)). + Int("rate_limit", rl.rate). + Msg("Request allowed") + + return true +} + +// getOrCreateWindow retrieves or creates a sliding window for an IP address +func (rl *SlidingWindowRateLimiter) getOrCreateWindow(ip string) *SlidingWindow { + rl.mu.RLock() + window, exists := rl.clients[ip] + rl.mu.RUnlock() + if exists { + return window + } + + rl.mu.Lock() + defer rl.mu.Unlock() + + // Double-check after acquiring write lock + if window, exists := rl.clients[ip]; exists { + return window + } + + window = &SlidingWindow{ + requests: make([]time.Time, 0, rl.rate), + lastCleanup: time.Now(), + } + rl.clients[ip] = window + + log.Debug(). + Str("ip", ip). + Msg("Created new sliding window for IP") + + return window +} + +// filterRequests removes requests older than the cutoff time +func (rl *SlidingWindowRateLimiter) filterRequests(requests []time.Time, cutoff time.Time) []time.Time { + validIdx := 0 + for i, req := range requests { + if req.After(cutoff) { + validIdx = i + break + } + validIdx = len(requests) + } + + if validIdx >= len(requests) { + return requests[:0] + } + + return requests[validIdx:] +} + +// performGlobalCleanupIfNeeded removes old IP entries that haven't been used recently +func (rl *SlidingWindowRateLimiter) performGlobalCleanupIfNeeded(now time.Time) { + if now.Sub(rl.lastGlobalCleanup) < rl.cleanupInterval { + return + } + + rl.mu.Lock() + defer rl.mu.Unlock() + + cutoff := now.Add(-rl.cleanupInterval) + toDelete := make([]string, 0) + + for ip, window := range rl.clients { + window.mu.RLock() + shouldDelete := window.lastCleanup.Before(cutoff) && len(window.requests) == 0 + window.mu.RUnlock() + + if shouldDelete { + toDelete = append(toDelete, ip) + } + } + + for _, ip := range toDelete { + delete(rl.clients, ip) + } + + rl.lastGlobalCleanup = now + + if len(toDelete) > 0 { + log.Debug(). + Int("cleaned_ips", len(toDelete)). + Int("remaining_ips", len(rl.clients)). + Msg("Performed global cleanup of rate limiter") + } +} + +// GetStats returns current statistics about the rate limiter +func (rl *SlidingWindowRateLimiter) GetStats() map[string]interface{} { + rl.mu.RLock() + defer rl.mu.RUnlock() + + totalRequests := 0 + activeIPs := 0 + + for _, window := range rl.clients { + window.mu.RLock() + totalRequests += len(window.requests) + if len(window.requests) > 0 { + activeIPs++ + } + window.mu.RUnlock() + } + + return map[string]interface{}{ + "rate_limit_per_second": rl.rate, + "total_tracked_ips": len(rl.clients), + "active_ips": activeIPs, + "total_active_requests": totalRequests, + "window_size_seconds": rl.windowSize.Seconds(), + } +} + +// String returns a string representation of the rate limiter +func (rl *SlidingWindowRateLimiter) String() string { + return fmt.Sprintf("SlidingWindowRateLimiter(rate=%d/sec, window=%v)", rl.rate, rl.windowSize) +} + +// GetCurrentRequestCount returns the current number of requests for a specific IP +func (rl *SlidingWindowRateLimiter) GetCurrentRequestCount(ip string) int { + rl.mu.RLock() + window, exists := rl.clients[ip] + rl.mu.RUnlock() + + if !exists { + return 0 + } + + window.mu.RLock() + defer window.mu.RUnlock() + + now := time.Now() + cutoff := now.Add(-rl.windowSize) + + count := 0 + for _, req := range window.requests { + if req.After(cutoff) { + count++ + } + } + + return count +} + +// GetRateLimit returns the configured rate limit +func (rl *SlidingWindowRateLimiter) GetRateLimit() int { + return rl.rate +} diff --git a/grid-proxy/tools/ratelimiter/sliding_window_test.go b/grid-proxy/tools/ratelimiter/sliding_window_test.go new file mode 100644 index 000000000..0026a7618 --- /dev/null +++ b/grid-proxy/tools/ratelimiter/sliding_window_test.go @@ -0,0 +1,131 @@ +package ratelimiter + +import ( + "net/http" + "testing" + "time" +) + +func TestGetClientIP(t *testing.T) { + tests := []struct { + name string + headers map[string]string + remoteAddr string + expected string + }{ + { + name: "X-Forwarded-For header", + headers: map[string]string{ + "X-Forwarded-For": "192.168.1.1, 10.0.0.1", + }, + remoteAddr: "127.0.0.1:8080", + expected: "192.168.1.1", + }, + { + name: "X-Real-IP header", + headers: map[string]string{ + "X-Real-IP": "203.0.113.1", + }, + remoteAddr: "127.0.0.1:8080", + expected: "203.0.113.1", + }, + { + name: "RemoteAddr fallback", + headers: map[string]string{}, + remoteAddr: "198.51.100.1:8080", + expected: "198.51.100.1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := &http.Request{ + Header: make(http.Header), + RemoteAddr: tt.remoteAddr, + } + + for k, v := range tt.headers { + req.Header.Set(k, v) + } + + result := GetClientIP(req) + if result != tt.expected { + t.Errorf("expected %s, got %s", tt.expected, result) + } + }) + } +} + +func TestSlidingWindowRateLimiter(t *testing.T) { + limiter := NewSlidingWindowRateLimiter(2) + + if !limiter.Allow("192.168.1.1") { + t.Error("First request should be allowed") + } + if !limiter.Allow("192.168.1.1") { + t.Error("Second request should be allowed") + } + + if limiter.Allow("192.168.1.1") { + t.Error("Third request should be blocked") + } + + // Different IP should be allowed + if !limiter.Allow("192.168.1.2") { + t.Error("Request from different IP should be allowed") + } + + // After waiting, requests should be allowed again + time.Sleep(1100 * time.Millisecond) + if !limiter.Allow("192.168.1.1") { + t.Error("Request should be allowed after window slide") + } +} + +func TestSlidingWindowRateLimiterStats(t *testing.T) { + limiter := NewSlidingWindowRateLimiter(5) + + // Make some requests + limiter.Allow("192.168.1.1") + limiter.Allow("192.168.1.1") + limiter.Allow("192.168.1.2") + + stats := limiter.GetStats() + + if stats["rate_limit_per_second"] != 5 { + t.Errorf("Expected rate limit 5, got %v", stats["rate_limit_per_second"]) + } + + if stats["total_tracked_ips"] != 2 { + t.Errorf("Expected 2 tracked IPs, got %v", stats["total_tracked_ips"]) + } +} + +func TestGetCurrentRequestCountAndRateLimit(t *testing.T) { + limiter := NewSlidingWindowRateLimiter(5) + + // Test rate limit getter + if limiter.GetRateLimit() != 5 { + t.Errorf("Expected rate limit 5, got %d", limiter.GetRateLimit()) + } + + // Test initial request count + if count := limiter.GetCurrentRequestCount("192.168.1.1"); count != 0 { + t.Errorf("Expected 0 requests initially, got %d", count) + } + + // Make some requests + limiter.Allow("192.168.1.1") + limiter.Allow("192.168.1.1") + limiter.Allow("192.168.1.1") + + // Test current request count + if count := limiter.GetCurrentRequestCount("192.168.1.1"); count != 3 { + t.Errorf("Expected 3 requests, got %d", count) + } + + // Different IP should have 0 + if count := limiter.GetCurrentRequestCount("192.168.1.2"); count != 0 { + t.Errorf("Expected 0 requests for different IP, got %d", count) + } +}