From 70bb30a8ec75d713bec445952807a143b0c15383 Mon Sep 17 00:00:00 2001 From: Arkadiy Kukarkin Date: Tue, 26 May 2026 22:41:24 +0200 Subject: [PATCH] ratelimit: derive client IP from rightmost proxy hop --- ratelimit/ratelimit.go | 10 +++--- ratelimit/ratelimit_test.go | 67 +++++++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+), 5 deletions(-) create mode 100644 ratelimit/ratelimit_test.go diff --git a/ratelimit/ratelimit.go b/ratelimit/ratelimit.go index cb2769f..cef088f 100644 --- a/ratelimit/ratelimit.go +++ b/ratelimit/ratelimit.go @@ -102,16 +102,16 @@ func (l *Limiter) WrapFunc(next http.HandlerFunc) http.HandlerFunc { func (l *Limiter) clientIP(r *http.Request) string { if l.trustProxy { - // check X-Forwarded-For (railway, nginx, etc) + // rightmost hop is the one our proxy appended; leftmost is spoofable. + // assumes a single trusted proxy in front (railway, nginx) if xff := r.Header.Get("X-Forwarded-For"); xff != "" { - if idx := strings.Index(xff, ","); idx != -1 { - return strings.TrimSpace(xff[:idx]) + if idx := strings.LastIndex(xff, ","); idx != -1 { + return strings.TrimSpace(xff[idx+1:]) } return strings.TrimSpace(xff) } - // check X-Real-IP if xri := r.Header.Get("X-Real-IP"); xri != "" { - return xri + return strings.TrimSpace(xri) } } // use remote addr directly diff --git a/ratelimit/ratelimit_test.go b/ratelimit/ratelimit_test.go new file mode 100644 index 0000000..62dabf7 --- /dev/null +++ b/ratelimit/ratelimit_test.go @@ -0,0 +1,67 @@ +package ratelimit + +import ( + "net/http" + "testing" + "time" +) + +func TestClientIP(t *testing.T) { + req := func(remote, xff, xri string) *http.Request { + r := &http.Request{RemoteAddr: remote, Header: http.Header{}} + if xff != "" { + r.Header.Set("X-Forwarded-For", xff) + } + if xri != "" { + r.Header.Set("X-Real-IP", xri) + } + return r + } + + tests := []struct { + name string + trustProxy bool + r *http.Request + want string + }{ + // leftmost entries are attacker-controlled; we must land on the rightmost + {"spoofed leftmost", true, req("10.0.0.1:1", "1.2.3.4, 9.9.9.9, 203.0.113.7", ""), "203.0.113.7"}, + {"single xff", true, req("10.0.0.1:1", "203.0.113.7", ""), "203.0.113.7"}, + {"xff over xri", true, req("10.0.0.1:1", "203.0.113.7", "8.8.8.8"), "203.0.113.7"}, + {"xri fallback", true, req("10.0.0.1:1", "", "8.8.8.8"), "8.8.8.8"}, + {"no headers", true, req("10.0.0.1:1234", "", ""), "10.0.0.1"}, + // untrusted: headers ignored entirely, even if present + {"distrust ignores xff", false, req("10.0.0.1:1234", "1.2.3.4", "8.8.8.8"), "10.0.0.1"}, + } + + l := New(10, time.Minute, true) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + l.trustProxy = tt.trustProxy + if got := l.clientIP(tt.r); got != tt.want { + t.Errorf("got %q, want %q", got, tt.want) + } + }) + } +} + +func TestAllow(t *testing.T) { + l := New(2, time.Minute, false) + if !l.Allow("a") || !l.Allow("a") { + t.Fatal("first two requests should pass") + } + if l.Allow("a") { + t.Error("third request should be denied") + } + if !l.Allow("b") { + t.Error("a different key should have its own bucket") + } + + // limit <= 0 disables + off := New(0, time.Minute, false) + for i := 0; i < 100; i++ { + if !off.Allow("x") { + t.Fatal("disabled limiter should always allow") + } + } +}