diff --git a/README.md b/README.md index ccd59cc..08f3e71 100644 --- a/README.md +++ b/README.md @@ -35,9 +35,17 @@ all config via environment variables: | HAUK_REDIS_ADDR | localhost:6379 | redis address (host:port or redis:// url) | | HAUK_AUTH_METHOD | password | auth method (password, htpasswd, ldap) | | HAUK_PASSWORD_HASH | | bcrypt hash for password auth | +| HAUK_RATE_LIMIT_AUTH | 10 | max auth requests per minute per ip | +| HAUK_RATE_LIMIT_ADOPT | 10 | max adopt requests per minute per ip | +| HAUK_TRUST_PROXY | true | trust X-Forwarded-For (set false if not behind proxy) | see `config/config.go` for full list. +## security improvements over upstream + +- adopt authorization: only share owner can adopt into groups (fixes CVE-like auth bypass in upstream) +- built-in rate limiting on auth and adopt endpoints (configurable, default 10 req/min/ip) + ## compatibility drop-in replacement for the php backend. works with the existing android app and web frontend. diff --git a/api/adopt.go b/api/adopt.go index 9b97373..31790e6 100644 --- a/api/adopt.go +++ b/api/adopt.go @@ -54,13 +54,14 @@ func (s *Server) handleAdopt(w http.ResponseWriter, r *http.Request) { return } - hostSession, err := model.LoadSession(ctx, s.store, share.Host(), s.cfg.MaxCachedPts) - if err != nil { - fmt.Fprintln(w, "Session expired!") + // verify caller owns the share being adopted + // after this check, session IS the host session + if sid != share.Host() { + fmt.Fprintln(w, "Not authorized!") return } - if hostSession.Encrypted() { + if session.Encrypted() { fmt.Fprintln(w, "End-to-end encrypted shares cannot be adopted!") return } @@ -78,13 +79,11 @@ func (s *Server) handleAdopt(w http.ResponseWriter, r *http.Request) { return } - hostSession.AddTarget(target.ID()) - if err := hostSession.Save(ctx); err != nil { + session.AddTarget(target.ID()) + if err := session.Save(ctx); err != nil { http.Error(w, "internal error", http.StatusInternalServerError) return } - _ = session - fmt.Fprintln(w, "OK") } diff --git a/api/api_test.go b/api/api_test.go index 69a9f4b..6ce1080 100644 --- a/api/api_test.go +++ b/api/api_test.go @@ -15,15 +15,18 @@ import ( func testServer() (*Server, *store.Memory) { mem := store.NewMemory() cfg := &config.Config{ - PublicURL: "https://example.com/", - MaxDuration: 86400, - MinInterval: 1, - MaxCachedPts: 3, - MaxShownPts: 100, - LinkStyle: 0, - AllowLinkReq: true, - PasswordHash: "$2a$10$LerNFYkUU3ZZrNHhamISZeDK8afdExOwDKbyTaUECDOLa1rV4iN.O", // "test" - AuthMethod: config.AuthPassword, + PublicURL: "https://example.com/", + MaxDuration: 86400, + MinInterval: 1, + MaxCachedPts: 3, + MaxShownPts: 100, + LinkStyle: 0, + AllowLinkReq: true, + PasswordHash: "$2a$10$LerNFYkUU3ZZrNHhamISZeDK8afdExOwDKbyTaUECDOLa1rV4iN.O", // "test" + AuthMethod: config.AuthPassword, + RateLimitAuth: 10000, + RateLimitAdopt: 10000, + TrustProxy: true, } return NewServer(cfg, mem), mem } @@ -587,6 +590,7 @@ func TestAdopt(t *testing.T) { "ado": {"1"}, }) soloLines := strings.Split(strings.TrimSpace(w.Body.String()), "\n") + soloSid := soloLines[1] soloShareID := soloLines[3] t.Run("missing fields", func(t *testing.T) { @@ -622,7 +626,7 @@ func TestAdopt(t *testing.T) { t.Run("successful adopt", func(t *testing.T) { w := postForm(srv, "/api/adopt.php", url.Values{ - "sid": {ownerSid}, + "sid": {soloSid}, "nic": {"adopted-user"}, "aid": {soloShareID}, "pin": {groupPin}, @@ -632,6 +636,30 @@ func TestAdopt(t *testing.T) { } }) + t.Run("unauthorized adopt", func(t *testing.T) { + // create another adoptable share + w = postForm(srv, "/api/create.php", url.Values{ + "dur": {"3600"}, + "int": {"5"}, + "pwd": {"test"}, + "mod": {"0"}, + "ado": {"1"}, + }) + lines := strings.Split(strings.TrimSpace(w.Body.String()), "\n") + anotherShareID := lines[3] + + // try to adopt with wrong session + w = postForm(srv, "/api/adopt.php", url.Values{ + "sid": {ownerSid}, + "nic": {"attacker"}, + "aid": {anotherShareID}, + "pin": {groupPin}, + }) + if !strings.Contains(w.Body.String(), "Not authorized!") { + t.Errorf("expected not authorized, got: %s", w.Body.String()) + } + }) + t.Run("non-adoptable share", func(t *testing.T) { // create non-adoptable share w = postForm(srv, "/api/create.php", url.Values{ @@ -642,10 +670,11 @@ func TestAdopt(t *testing.T) { "ado": {"0"}, }) lines := strings.Split(strings.TrimSpace(w.Body.String()), "\n") + nonAdoptableSid := lines[1] nonAdoptableID := lines[3] w = postForm(srv, "/api/adopt.php", url.Values{ - "sid": {ownerSid}, + "sid": {nonAdoptableSid}, "nic": {"adopter"}, "aid": {nonAdoptableID}, "pin": {groupPin}, diff --git a/api/server.go b/api/server.go index 818a433..e339acd 100644 --- a/api/server.go +++ b/api/server.go @@ -3,22 +3,26 @@ package api import ( "io/fs" "net/http" + "time" "github.com/parkan/go-hauk/auth" "github.com/parkan/go-hauk/config" "github.com/parkan/go-hauk/frontend" "github.com/parkan/go-hauk/linkgen" + "github.com/parkan/go-hauk/ratelimit" "github.com/parkan/go-hauk/store" ) const backendVersion = "1.6.2-go" type Server struct { - mux *http.ServeMux - cfg *config.Config - store store.Store - auth auth.Authenticator - linkgen *linkgen.Generator + mux *http.ServeMux + cfg *config.Config + store store.Store + auth auth.Authenticator + linkgen *linkgen.Generator + rlAuth *ratelimit.Limiter + rlAdopt *ratelimit.Limiter } func NewServer(cfg *config.Config, s store.Store) *Server { @@ -27,6 +31,8 @@ func NewServer(cfg *config.Config, s store.Store) *Server { cfg: cfg, store: s, linkgen: linkgen.New(s, cfg.LinkStyle), + rlAuth: ratelimit.New(cfg.RateLimitAuth, time.Minute, cfg.TrustProxy), + rlAdopt: ratelimit.New(cfg.RateLimitAdopt, time.Minute, cfg.TrustProxy), } switch cfg.AuthMethod { @@ -41,11 +47,11 @@ func NewServer(cfg *config.Config, s store.Store) *Server { srv.auth = auth.NewPasswordAuth(cfg.PasswordHash) } - srv.mux.HandleFunc("POST /api/create.php", srv.handleCreate) + srv.mux.HandleFunc("POST /api/create.php", srv.rlAuth.WrapFunc(srv.handleCreate)) srv.mux.HandleFunc("POST /api/post.php", srv.handlePost) srv.mux.HandleFunc("GET /api/fetch.php", srv.handleFetch) srv.mux.HandleFunc("POST /api/stop.php", srv.handleStop) - srv.mux.HandleFunc("POST /api/adopt.php", srv.handleAdopt) + srv.mux.HandleFunc("POST /api/adopt.php", srv.rlAdopt.WrapFunc(srv.handleAdopt)) srv.mux.HandleFunc("POST /api/new-link.php", srv.handleNewLink) srv.mux.HandleFunc("GET /dynamic.js.php", srv.handleDynamic) diff --git a/config/config.go b/config/config.go index 9486719..b20060a 100644 --- a/config/config.go +++ b/config/config.go @@ -84,6 +84,11 @@ type Config struct { TrailColor string OfflineTimeout int RequestTimeout int + + // rate limiting + RateLimitAuth int // requests per minute for auth endpoints + RateLimitAdopt int // requests per minute for adopt/join endpoints + TrustProxy bool // trust X-Forwarded-For headers } func envStr(key, def string) string { @@ -165,7 +170,10 @@ func Load() *Config { VelocityUnit: velUnit, VelocityDataPts: envInt("HAUK_VELOCITY_DATA_PTS", 2), TrailColor: envStr("HAUK_TRAIL_COLOR", "#d80037"), - OfflineTimeout: envInt("HAUK_OFFLINE_TIMEOUT", 30), - RequestTimeout: envInt("HAUK_REQUEST_TIMEOUT", 10), + OfflineTimeout: envInt("HAUK_OFFLINE_TIMEOUT", 30), + RequestTimeout: envInt("HAUK_REQUEST_TIMEOUT", 10), + RateLimitAuth: envInt("HAUK_RATE_LIMIT_AUTH", 10), + RateLimitAdopt: envInt("HAUK_RATE_LIMIT_ADOPT", 10), + TrustProxy: envBool("HAUK_TRUST_PROXY", true), } } diff --git a/ratelimit/ratelimit.go b/ratelimit/ratelimit.go new file mode 100644 index 0000000..cb2769f --- /dev/null +++ b/ratelimit/ratelimit.go @@ -0,0 +1,122 @@ +package ratelimit + +import ( + "net/http" + "strings" + "sync" + "time" +) + +const maxEntries = 10000 + +type entry struct { + count int + resetAt time.Time +} + +type Limiter struct { + mu sync.Mutex + entries map[string]*entry + limit int + window time.Duration + lastSweep time.Time + trustProxy bool +} + +func New(limit int, window time.Duration, trustProxy bool) *Limiter { + return &Limiter{ + entries: make(map[string]*entry), + limit: limit, + window: window, + lastSweep: time.Now(), + trustProxy: trustProxy, + } +} + +// Allow checks if request from key should be allowed +func (l *Limiter) Allow(key string) bool { + // limit <= 0 means disabled + if l.limit <= 0 { + return true + } + + l.mu.Lock() + defer l.mu.Unlock() + + now := time.Now() + + // cleanup stale entries periodically or when approaching cap + if time.Since(l.lastSweep) > l.window*2 || len(l.entries) >= maxEntries { + for k, e := range l.entries { + if now.After(e.resetAt) { + delete(l.entries, k) + } + } + l.lastSweep = now + } + + e, ok := l.entries[key] + if !ok || now.After(e.resetAt) { + // new entry - check cap first + if len(l.entries) >= maxEntries { + return false + } + l.entries[key] = &entry{ + count: 1, + resetAt: now.Add(l.window), + } + return true + } + + // existing entry - always allow through rate limit check + if e.count >= l.limit { + return false + } + e.count++ + return true +} + +// Middleware wraps an http.Handler with rate limiting +func (l *Limiter) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + key := l.clientIP(r) + if !l.Allow(key) { + http.Error(w, "rate limit exceeded", http.StatusTooManyRequests) + return + } + next.ServeHTTP(w, r) + }) +} + +// WrapFunc wraps an http.HandlerFunc with rate limiting +func (l *Limiter) WrapFunc(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + key := l.clientIP(r) + if !l.Allow(key) { + http.Error(w, "rate limit exceeded", http.StatusTooManyRequests) + return + } + next(w, r) + } +} + +func (l *Limiter) clientIP(r *http.Request) string { + if l.trustProxy { + // check X-Forwarded-For (railway, nginx, etc) + if xff := r.Header.Get("X-Forwarded-For"); xff != "" { + if idx := strings.Index(xff, ","); idx != -1 { + return strings.TrimSpace(xff[:idx]) + } + return strings.TrimSpace(xff) + } + // check X-Real-IP + if xri := r.Header.Get("X-Real-IP"); xri != "" { + return xri + } + } + // use remote addr directly + if idx := strings.LastIndex(r.RemoteAddr, ":"); idx != -1 { + return r.RemoteAddr[:idx] + } + return r.RemoteAddr +}