MicahParks/keyfunc

keyfunc's background refresh goroutine does not end when Options.Ctx is canceled

tho opened this issue · 2 comments

tho commented

The keyfunc.Options Ctx field lets users pass a context to the functions that load the JWKS. According to the field's description, keyfunc's background refresh uses this context: when the context expires or is canceled, the background goroutine should end. As far as I can tell, this does not happen. At the end of keyfunc.Get, after the options have been applied, jwks.ctx and jwks.cancel are overwritten with a new context derived from context.Background(), so the caller's context is discarded.

Illustration of the unexpected behavior

I've modified the examples/recommended_options program to illustrate the behavior.

package main

import (
	"context"
	_ "embed"
	"fmt"
	"log"
	"net/http"
	"net/http/httptest"
	"sync/atomic"
	"time"

	"github.com/MicahParks/keyfunc/v2"
)

//go:embed example_jwks.json
var jwksJSON string

func main() {
	// Create a test server to serve an example JWKS.
	var counter uint64
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Use the value returned by AddUint64; reading counter directly would be a data race.
		n := atomic.AddUint64(&counter, 1)
		fmt.Printf("JWKS endpoint called %d time(s)\n", n)
		_, err := w.Write([]byte(jwksJSON))
		if err != nil {
			http.Error(w, err.Error(), http.StatusInternalServerError)
		}
	}))
	defer server.Close()
	jwksURL := server.URL

	// Create a context that, when canceled, *should* end the JWKS background refresh goroutine.
	ctx, cancel := context.WithCancel(context.Background())

	// Create the keyfunc options. Use the previously created context and refresh JWKS every second.
	options := keyfunc.Options{
		Ctx:             ctx,
		RefreshInterval: time.Second,
	}

	// Create the JWKS from the resource at the given URL.
	jwks, err := keyfunc.Get(jwksURL, options)
	if err != nil {
		log.Fatalf("Failed to create JWKS from resource at the given URL.\nError: %s", err.Error())
	}

	// Sleep for 3s.
	// Expectation: See 3 JWKS refreshes.
	time.Sleep(3 * time.Second)
	fmt.Println("---")

	// End the background refresh goroutine.
	cancel()

	// Sleep for 3s.
	// Expectation: See *no more* JWKS refreshes since the passed context got canceled.
	time.Sleep(3 * time.Second)
	fmt.Println("---")

	// This will be ineffectual because the line above this canceled the parent context.Context.
	// This method call is idempotent similar to context.CancelFunc.
	jwks.EndBackground()

	// Sleep for 3s.
	// Expectation: See *no* JWKS refreshes.
	time.Sleep(3 * time.Second)
	fmt.Println("END")
}

Output:

JWKS endpoint called 1 time(s)
JWKS endpoint called 2 time(s)
JWKS endpoint called 3 time(s)
---
JWKS endpoint called 4 time(s)
JWKS endpoint called 5 time(s)
JWKS endpoint called 6 time(s)
---
END

Potential fix:

diff --git a/get.go b/get.go
index 5dd754b..00f2f2f 100644
--- a/get.go
+++ b/get.go
@@ -53,7 +53,10 @@ func Get(jwksURL string, options Options) (jwks *JWKS, err error) {
        }
 
        if jwks.refreshInterval != 0 || jwks.refreshUnknownKID {
-               jwks.ctx, jwks.cancel = context.WithCancel(context.Background())
+               if jwks.ctx == nil {
+                       jwks.ctx = context.Background()
+               }
+               jwks.ctx, jwks.cancel = context.WithCancel(jwks.ctx)
                jwks.refreshRequests = make(chan refreshRequest, 1)
                go jwks.backgroundRefresh()
        }

Output after fix:

JWKS endpoint called 1 time(s)
JWKS endpoint called 2 time(s)
JWKS endpoint called 3 time(s)
---
---
END

Thank you for opening this issue.

If you would please make a pull request, I'll review after work today.

Fixed in #86, thank you @tho.