/shield

Shield is a simple net/http-compatible middleware that blocks or allows requests based on a predicate.

Primary language: Go · License: MIT

Build Status codecov GoDoc Go Report Card

Shield

Shield is a net/http-compatible middleware that blocks or allows requests based on a predicate. When a request is blocked, Shield replies with a user-defined response.

Usage

Below is an example of how to configure the Shield middleware to allow only GET requests and to reply with 405 Method Not Allowed in every other case.

package main

import (
	"net/http"

	"github.com/psampaz/shield"
)

func main() {

	shieldMiddleware := shield.New(shield.Options{
		Block: func(r *http.Request) bool {
			return r.Method != "GET"
		},
		Code:    http.StatusMethodNotAllowed,
		Headers: http.Header{"Content-Type": {"text/plain"}},
		Body:    []byte(http.StatusText(http.StatusMethodNotAllowed)),
	})
    
	helloWorldHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("hello world"))
	})
    
	http.ListenAndServe(":8080", shieldMiddleware.Handler(helloWorldHandler))
}
$ curl -i -X GET localhost:8080

HTTP/1.1 200 OK
Date: Sat, 22 Feb 2020 10:03:35 GMT
Content-Length: 11
Content-Type: text/plain; charset=utf-8

hello world
$ curl -i -X POST localhost:8080

HTTP/1.1 405 Method Not Allowed
Content-Type: text/plain
Date: Sat, 22 Feb 2020 10:02:31 GMT
Content-Length: 18

Method Not Allowed

Passing a func as the Block option gives you access only to the current request. If you need data or functionality that is not tied to the request, you can use a struct method with the same signature instead.

package main

import (
	"net/http"

	"github.com/psampaz/shield"
)

type BlockLogic struct {
	ShouldBLock bool
}

func (b *BlockLogic) Block(r *http.Request) bool {
	return b.ShouldBLock
}

func main() {
	blockLogic := BlockLogic{true}
	shieldMiddleware := shield.New(shield.Options{
		Block:   blockLogic.Block,
		Code:    http.StatusBadRequest,
		Headers: http.Header{"Content-Type": {"text/plain"}},
		Body:    []byte(http.StatusText(http.StatusBadRequest)),
	})

	helloWorldHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("hello world"))
	})

	http.ListenAndServe(":8080", shieldMiddleware.Handler(helloWorldHandler))
}

Options

Shield middleware can be configured with the following options:

// Options configures the middleware returned by shield.New.
type Options struct {
	// Block is a predicate responsible for blocking the request.
	// It returns true when the request should be blocked, false otherwise.
	Block func(r *http.Request) bool
	// Code is the status code of the response sent when the request is blocked.
	Code int
	// Headers are the headers of the response sent when the request is blocked.
	Headers http.Header
	// Body is the body of the response sent when the request is blocked.
	Body []byte
}

List of predefined block methods

Block based on a list of query parameter regexes

	// Map keys are query parameter names; values are the regular expressions
	// their values are checked against.
	// NOTE(review): confirm in the blocker package whether a missing or
	// non-matching parameter is what triggers blocking, and whether patterns
	// are anchored — unanchored, `\d+` would also match "abc1".
	queryBlock := blocker.NewQuery(map[string]string{
		"op":       "search",
		"page":     `\d+`,
		"v":        `[a-zA-Z]+`,
		"optional": `^$|\d`,
	})
	// queryBlock.Block is handed to Shield as the Options.Block predicate.
	shieldMiddleware := shield.New(shield.Options{
		Block:   queryBlock.Block,
		Code:    http.StatusBadRequest,
		Headers: http.Header{"Content-Type": {"text/plain"}},
		Body:    []byte(http.StatusText(http.StatusBadRequest)),
	})

Block based on a list of HTTP methods

	// NOTE(review): confirm in the blocker package whether the listed methods
	// are the allowed set or the blocked set.
	methodBlock := blocker.NewMethod([]string{http.MethodGet, http.MethodPost})
	// methodBlock.Block is handed to Shield as the Options.Block predicate.
	shieldMiddleware := shield.New(shield.Options{
		Block:   methodBlock.Block,
		Code:    http.StatusBadRequest,
		Headers: http.Header{"Content-Type": {"text/plain"}},
		Body:    []byte(http.StatusText(http.StatusBadRequest)),
	})

Block based on a list of HTTP schemes

	// NOTE(review): confirm in the blocker package whether the listed schemes
	// are the allowed set or the blocked set, and how the scheme is derived
	// from the incoming request (e.g. behind a TLS-terminating proxy).
	schemeBlock := blocker.NewScheme([]string{"https"})
	// schemeBlock.Block is handed to Shield as the Options.Block predicate.
	shieldMiddleware := shield.New(shield.Options{
		Block:   schemeBlock.Block,
		Code:    http.StatusBadRequest,
		Headers: http.Header{"Content-Type": {"text/plain"}},
		Body:    []byte(http.StatusText(http.StatusBadRequest)),
	})

Integration with popular routers

Gorilla Mux

package main

import (
	"net/http"

	"github.com/psampaz/shield"

	"github.com/gorilla/mux"
)

func main() {
	shieldMiddleware := shield.New(shield.Options{
		Block: func(r *http.Request) bool {
			return true
		},
		Code:    http.StatusMethodNotAllowed,
		Headers: http.Header{"Content-Type": {"text/plain"}},
		Body:    []byte(http.StatusText(http.StatusMethodNotAllowed)),
	})

	r := mux.NewRouter()
	r.Use(shieldMiddleware.Handler)
	r.HandleFunc("/", HelloHandler)

	http.ListenAndServe(":8080", r)
}

func HelloHandler(w http.ResponseWriter, r *http.Request) {
	w.Write([]byte("hello world"))
}

Chi

package main

import (
	"net/http"

	"github.com/psampaz/shield"

	"github.com/go-chi/chi"
)

func main() {
	shieldMiddleware := shield.New(shield.Options{
		Block: func(r *http.Request) bool {
			return true
		},
		Code:    http.StatusMethodNotAllowed,
		Headers: http.Header{"Content-Type": {"text/plain"}},
		Body:    []byte(http.StatusText(http.StatusMethodNotAllowed)),
	})

	r := chi.NewRouter()
	r.Use(shieldMiddleware.Handler)
	r.Get("/", HelloHandler)

	http.ListenAndServe(":8080", r)
}

func HelloHandler(w http.ResponseWriter, r *http.Request) {
	w.Write([]byte("hello world"))
}

Echo

package main

import (
	"net/http"

	"github.com/psampaz/shield"

	"github.com/labstack/echo"
)

func main() {
	shieldMiddleware := shield.New(shield.Options{
		Block: func(r *http.Request) bool {
			return true
		},
		Code:    http.StatusMethodNotAllowed,
		Headers: http.Header{"Content-Type": {"text/plain"}},
		Body:    []byte(http.StatusText(http.StatusMethodNotAllowed)),
	})

	e := echo.New()
	e.Use(echo.WrapMiddleware(shieldMiddleware.Handler))

	e.GET("/", func(c echo.Context) error {
		return c.String(http.StatusOK, "Hello world")
	})

	e.Start((":8080"))
}