limit.go (1228B) - raw
1 package main 2 3 import ( 4 "log" 5 "net" 6 "net/http" 7 "sync" 8 9 "golang.org/x/time/rate" 10 ) 11 12 // Create a map to hold the rate limiters for each visitor and a mutex. 13 var visitors = make(map[string]*rate.Limiter) 14 var mu sync.Mutex 15 16 // Retrieve and return the rate limiter for the current visitor if it 17 // already exists. Otherwise create a new rate limiter and add it to 18 // the visitors map, using the IP address as the key. 19 func getVisitor(ip string) *rate.Limiter { 20 mu.Lock() 21 defer mu.Unlock() 22 23 limiter, exists := visitors[ip] 24 if !exists { 25 limiter = rate.NewLimiter(.5, 2) 26 visitors[ip] = limiter 27 } 28 29 return limiter 30 } 31 32 func limit(next http.Handler) http.Handler { 33 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 34 // Get the IP address for the current user. 35 ip, _, err := net.SplitHostPort(r.RemoteAddr) 36 if err != nil { 37 log.Println(err.Error()) 38 http.Error(w, "Internal Server Error", http.StatusInternalServerError) 39 return 40 } 41 42 // Call the getVisitor function to retreive the rate limiter for 43 // the current user. 44 limiter := getVisitor(ip) 45 if limiter.Allow() == false { 46 http.Error(w, http.StatusText(429), http.StatusTooManyRequests) 47 return 48 } 49 50 next.ServeHTTP(w, r) 51 }) 52 }