diff --git a/pkg/app/app.go b/pkg/app/app.go index c70bc9a..8ca8142 100644 --- a/pkg/app/app.go +++ b/pkg/app/app.go @@ -3,6 +3,7 @@ package app import ( "context" "errors" + "net/http" "github.com/rs/zerolog" "go.opentelemetry.io/otel/codes" @@ -25,6 +26,7 @@ type App struct { type AppHTTP struct { Funcs []srv.HTTPFunc + Middleware []http.Handler HealthChecks []srv.HealthCheckFunc httpDone <-chan interface{} } @@ -78,9 +80,12 @@ func (a *App) MustRun() { func (a *App) initHTTP() { var httpShutdown shutdownFunc httpShutdown, a.HTTP.httpDone = srv.MustInitHTTPServer( - a.AppContext, - a.HTTP.Funcs, - a.HTTP.HealthChecks..., + &srv.HTTPServerOpts{ + Ctx: a.AppContext, + HandleFuncs: a.HTTP.Funcs, + Middleware: a.HTTP.Middleware, + HealthCheckFuncs: a.HTTP.HealthChecks, + }, ) a.shutdownFuncs = append(a.shutdownFuncs, httpShutdown) } diff --git a/pkg/srv/http.go b/pkg/srv/http.go index 0b43962..6c9e983 100644 --- a/pkg/srv/http.go +++ b/pkg/srv/http.go @@ -29,10 +29,17 @@ type HTTPFunc struct { HandlerFunc http.HandlerFunc } -func prepHTTPServer(ctx context.Context, handleFuncs []HTTPFunc, hcFuncs ...HealthCheckFunc) *http.Server { +type HTTPServerOpts struct { + Ctx context.Context + HandleFuncs []HTTPFunc + Middleware []http.Handler + HealthCheckFuncs []HealthCheckFunc +} + +func prepHTTPServer(opts *HTTPServerOpts) *http.Server { var ( - cfg = config.MustFromCtx(ctx) - l = zerolog.Ctx(ctx) + cfg = config.MustFromCtx(opts.Ctx) + l = zerolog.Ctx(opts.Ctx) mux = &http.ServeMux{} ) @@ -43,11 +50,11 @@ func prepHTTPServer(ctx context.Context, handleFuncs []HTTPFunc, hcFuncs ...Heal mux.Handle(pattern, handler) // Associate pattern with handler } - healthChecks := handleHealthCheckFunc(ctx, hcFuncs...) + healthChecks := handleHealthCheckFunc(opts.Ctx, opts.HealthCheckFuncs...) otelHandleFunc("/health", healthChecks) otelHandleFunc("/", healthChecks) - for _, f := range handleFuncs { + for _, f := range opts.HandleFuncs { otelHandleFunc(f.Path, f.HandlerFunc) } @@ -89,9 +96,19 @@ func prepHTTPServer(ctx context.Context, handleFuncs []HTTPFunc, hcFuncs ...Heal idleTimeout = *iT } + // Inject any supplied middleware + for i := len(opts.Middleware) - 1; i >= 0; i-- { + mw := opts.Middleware[i] + next := handler + handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mw.ServeHTTP(w, r) + next.ServeHTTP(w, r) + }) + } + // Inject logging middleware if cfg.HTTP.LogRequests { - handler = loggingMiddleware(ctx, handler) + handler = loggingMiddleware(opts.Ctx, handler) } return &http.Server{ @@ -101,17 +118,17 @@ func prepHTTPServer(ctx context.Context, handleFuncs []HTTPFunc, hcFuncs ...Heal IdleTimeout: idleTimeout, Handler: handler, BaseContext: func(_ net.Listener) context.Context { - return ctx + return opts.Ctx }, } } // Returns a shutdown func and a done channel if the // server aborts abnormally. Panics on error. -func MustInitHTTPServer(ctx context.Context, funcs []HTTPFunc, hcFuncs ...HealthCheckFunc) ( +func MustInitHTTPServer(opts *HTTPServerOpts) ( func(context.Context) error, <-chan interface{}, ) { - shutdownFunc, doneChan, err := InitHTTPServer(ctx, funcs, hcFuncs...) + shutdownFunc, doneChan, err := InitHTTPServer(opts) if err != nil { panic(err) } @@ -120,18 +137,18 @@ func MustInitHTTPServer(ctx context.Context, funcs []HTTPFunc, hcFuncs ...Health // Returns a shutdown func and a done channel if the // server aborts abnormally. Returns error on failure to start -func InitHTTPServer(ctx context.Context, funcs []HTTPFunc, hcFuncs ...HealthCheckFunc) ( +func InitHTTPServer(opts *HTTPServerOpts) ( func(context.Context) error, <-chan interface{}, error, ) { - l := zerolog.Ctx(ctx) + l := zerolog.Ctx(opts.Ctx) doneChan := make(chan interface{}) var server *http.Server - httpMeter = otel.GetMeter(ctx, "http") - httpTracer = otel.GetTracer(ctx, "http") + httpMeter = otel.GetMeter(opts.Ctx, "http") + httpTracer = otel.GetTracer(opts.Ctx, "http") - server = prepHTTPServer(ctx, funcs, hcFuncs...) + server = prepHTTPServer(opts) go func() { l.Debug().Msg("HTTP Server Started")