diff --git a/CHANGELOG.md b/CHANGELOG.md index e749c3372c8..c2599444d2d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,7 @@ * [ENHANCEMENT] Ingester: Instrument Ingester CPU profile with userID for read APIs. #7184 * [ENHANCEMENT] Ingester: Add fetch timeout for Ingester expanded postings cache. #7185 * [ENHANCEMENT] Ingester: Add feature flag to collect metrics of how expensive an unoptimized regex matcher is and new limits to protect Ingester query path against expensive unoptimized regex matchers. #7194 #7210 +* [ENHANCEMENT] Querier: Add active API requests tracker logging to help with OOMKill troubleshooting. #7216 * [ENHANCEMENT] Compactor: Add partition group creation time to visit marker. #7217 * [BUGFIX] Ring: Change DynamoDB KV to retry indefinitely for WatchKey. #7088 * [BUGFIX] Ruler: Add XFunctions validation support. #7111 diff --git a/pkg/api/handlers.go b/pkg/api/handlers.go index 2b30e8aa587..7f170b5f4f0 100644 --- a/pkg/api/handlers.go +++ b/pkg/api/handlers.go @@ -32,6 +32,7 @@ import ( "github.com/cortexproject/cortex/pkg/querier/stats" "github.com/cortexproject/cortex/pkg/util" util_log "github.com/cortexproject/cortex/pkg/util/log" + "github.com/cortexproject/cortex/pkg/util/request_tracker" ) const ( @@ -285,35 +286,61 @@ func NewQuerierHandler( queryAPI := queryapi.NewQueryAPI(engine, translateSampleAndChunkQueryable, statsRenderer, logger, codecs, corsOrigin) + requestTracker := request_tracker.NewRequestTracker(querierCfg.ActiveQueryTrackerDir, "apis.active", querierCfg.MaxConcurrent, util_log.GoKitLogToSlog(logger)) + var apiHandler http.Handler + var instantQueryHandler http.Handler + var rangedQueryHandler http.Handler + var legacyAPIHandler http.Handler + if requestTracker != nil { + apiHandler = request_tracker.NewRequestWrapper(promRouter, requestTracker, &request_tracker.ApiExtractor{}) + legacyAPIHandler = request_tracker.NewRequestWrapper(legacyPromRouter, requestTracker, &request_tracker.ApiExtractor{}) + 
instantQueryHandler = request_tracker.NewRequestWrapper(queryAPI.Wrap(queryAPI.InstantQueryHandler), requestTracker, &request_tracker.InstantQueryExtractor{}) + rangedQueryHandler = request_tracker.NewRequestWrapper(queryAPI.Wrap(queryAPI.RangeQueryHandler), requestTracker, &request_tracker.RangedQueryExtractor{}) + + httpHeaderMiddleware := &HTTPHeaderMiddleware{ + TargetHeaders: cfg.HTTPRequestHeadersToLog, + RequestIdHeader: cfg.RequestIdHeader, + } + apiHandler = httpHeaderMiddleware.Wrap(apiHandler) + legacyAPIHandler = httpHeaderMiddleware.Wrap(legacyAPIHandler) + instantQueryHandler = httpHeaderMiddleware.Wrap(instantQueryHandler) + rangedQueryHandler = httpHeaderMiddleware.Wrap(rangedQueryHandler) + } else { + apiHandler = promRouter + legacyAPIHandler = legacyPromRouter + instantQueryHandler = queryAPI.Wrap(queryAPI.InstantQueryHandler) + rangedQueryHandler = queryAPI.Wrap(queryAPI.RangeQueryHandler) + } + // TODO(gotjosh): This custom handler is temporary until we're able to vendor the changes in: // https://github.com/prometheus/prometheus/pull/7125/files router.Path(path.Join(prefix, "/api/v1/metadata")).Handler(querier.MetadataHandler(metadataQuerier)) router.Path(path.Join(prefix, "/api/v1/read")).Handler(querier.RemoteReadHandler(queryable, logger)) router.Path(path.Join(prefix, "/api/v1/read")).Methods("POST").Handler(promRouter) - router.Path(path.Join(prefix, "/api/v1/query")).Methods("GET", "POST").Handler(queryAPI.Wrap(queryAPI.InstantQueryHandler)) - router.Path(path.Join(prefix, "/api/v1/query_range")).Methods("GET", "POST").Handler(queryAPI.Wrap(queryAPI.RangeQueryHandler)) + router.Path(path.Join(prefix, "/api/v1/query")).Methods("GET", "POST").Handler(instantQueryHandler) + router.Path(path.Join(prefix, "/api/v1/query_range")).Methods("GET", "POST").Handler(rangedQueryHandler) router.Path(path.Join(prefix, "/api/v1/query_exemplars")).Methods("GET", "POST").Handler(promRouter) router.Path(path.Join(prefix, 
"/api/v1/format_query")).Methods("GET", "POST").Handler(promRouter) router.Path(path.Join(prefix, "/api/v1/parse_query")).Methods("GET", "POST").Handler(promRouter) - router.Path(path.Join(prefix, "/api/v1/labels")).Methods("GET", "POST").Handler(promRouter) - router.Path(path.Join(prefix, "/api/v1/label/{name}/values")).Methods("GET").Handler(promRouter) - router.Path(path.Join(prefix, "/api/v1/series")).Methods("GET", "POST", "DELETE").Handler(promRouter) - router.Path(path.Join(prefix, "/api/v1/metadata")).Methods("GET").Handler(promRouter) + router.Path(path.Join(prefix, "/api/v1/labels")).Methods("GET", "POST").Handler(apiHandler) + router.Path(path.Join(prefix, "/api/v1/label/{name}/values")).Methods("GET").Handler(apiHandler) + router.Path(path.Join(prefix, "/api/v1/series")).Methods("GET", "POST", "DELETE").Handler(apiHandler) + router.Path(path.Join(prefix, "/api/v1/metadata")).Methods("GET").Handler(apiHandler) // TODO(gotjosh): This custom handler is temporary until we're able to vendor the changes in: // https://github.com/prometheus/prometheus/pull/7125/files router.Path(path.Join(legacyPrefix, "/api/v1/metadata")).Handler(querier.MetadataHandler(metadataQuerier)) router.Path(path.Join(legacyPrefix, "/api/v1/read")).Handler(querier.RemoteReadHandler(queryable, logger)) router.Path(path.Join(legacyPrefix, "/api/v1/read")).Methods("POST").Handler(legacyPromRouter) - router.Path(path.Join(legacyPrefix, "/api/v1/query")).Methods("GET", "POST").Handler(queryAPI.Wrap(queryAPI.InstantQueryHandler)) - router.Path(path.Join(legacyPrefix, "/api/v1/query_range")).Methods("GET", "POST").Handler(queryAPI.Wrap(queryAPI.RangeQueryHandler)) + router.Path(path.Join(legacyPrefix, "/api/v1/query")).Methods("GET", "POST").Handler(instantQueryHandler) + router.Path(path.Join(legacyPrefix, "/api/v1/query_range")).Methods("GET", "POST").Handler(rangedQueryHandler) router.Path(path.Join(legacyPrefix, "/api/v1/query_exemplars")).Methods("GET", "POST").Handler(legacyPromRouter) 
router.Path(path.Join(legacyPrefix, "/api/v1/format_query")).Methods("GET", "POST").Handler(legacyPromRouter) router.Path(path.Join(legacyPrefix, "/api/v1/parse_query")).Methods("GET", "POST").Handler(legacyPromRouter) - router.Path(path.Join(legacyPrefix, "/api/v1/labels")).Methods("GET", "POST").Handler(legacyPromRouter) - router.Path(path.Join(legacyPrefix, "/api/v1/label/{name}/values")).Methods("GET").Handler(legacyPromRouter) - router.Path(path.Join(legacyPrefix, "/api/v1/series")).Methods("GET", "POST", "DELETE").Handler(legacyPromRouter) - router.Path(path.Join(legacyPrefix, "/api/v1/metadata")).Methods("GET").Handler(legacyPromRouter) + router.Path(path.Join(legacyPrefix, "/api/v1/labels")).Methods("GET", "POST").Handler(legacyAPIHandler) + router.Path(path.Join(legacyPrefix, "/api/v1/label/{name}/values")).Methods("GET").Handler(legacyAPIHandler) + router.Path(path.Join(legacyPrefix, "/api/v1/series")).Methods("GET", "POST", "DELETE").Handler(legacyAPIHandler) + router.Path(path.Join(legacyPrefix, "/api/v1/metadata")).Methods("GET").Handler(legacyAPIHandler) if cfg.buildInfoEnabled { router.Path(path.Join(prefix, "/api/v1/status/buildinfo")).Methods("GET").Handler(promRouter) diff --git a/pkg/util/request_tracker/request_extractor.go b/pkg/util/request_tracker/request_extractor.go new file mode 100644 index 00000000000..dbf1a7a71e0 --- /dev/null +++ b/pkg/util/request_tracker/request_extractor.go @@ -0,0 +1,125 @@ +package request_tracker + +import ( + "encoding/json" + "net/http" + "strings" + "time" + "unicode/utf8" + + "github.com/cortexproject/cortex/pkg/util/requestmeta" + "github.com/cortexproject/cortex/pkg/util/users" +) + +type Extractor interface { + Extract(r *http.Request) []byte +} + +type DefaultExtractor struct{} + +type ApiExtractor struct{} + +type InstantQueryExtractor struct{} + +type RangedQueryExtractor struct{} + +func generateCommonMap(r *http.Request) map[string]interface{} { + ctx := r.Context() + entryMap := 
make(map[string]interface{}) + entryMap["timestamp_sec"] = time.Now().Unix() + entryMap["Path"] = r.URL.Path + entryMap["Method"] = r.Method + entryMap["TenantID"], _ = users.TenantID(ctx) + entryMap["RequestID"] = requestmeta.RequestIdFromContext(ctx) + entryMap["UserAgent"] = r.Header.Get("User-Agent") + entryMap["DashboardUID"] = r.Header.Get("X-Dashboard-UID") + entryMap["PanelId"] = r.Header.Get("X-Panel-Id") + + return entryMap +} + +func (e *DefaultExtractor) Extract(r *http.Request) []byte { + entryMap := generateCommonMap(r) + + return generateJSONEntry(entryMap) +} + +func (e *ApiExtractor) Extract(r *http.Request) []byte { + entryMap := generateCommonMap(r) + entryMap["limit"] = r.URL.Query().Get("limit") + entryMap["start"] = r.URL.Query().Get("start") + entryMap["end"] = r.URL.Query().Get("end") + + matches := r.URL.Query()["match[]"] + entryMap["number-of-matches"] = len(matches) + matchesStr := strings.Join(matches, ",") + + return generateJSONEntryWithTruncatedField(entryMap, "matches", matchesStr) +} + +func (e *InstantQueryExtractor) Extract(r *http.Request) []byte { + entryMap := generateCommonMap(r) + entryMap["time"] = r.URL.Query().Get("time") + return generateJSONEntryWithTruncatedField(entryMap, "query", r.URL.Query().Get("query")) +} + +func (e *RangedQueryExtractor) Extract(r *http.Request) []byte { + entryMap := generateCommonMap(r) + entryMap["start"] = r.URL.Query().Get("start") + entryMap["end"] = r.URL.Query().Get("end") + entryMap["step"] = r.URL.Query().Get("step") + return generateJSONEntryWithTruncatedField(entryMap, "query", r.URL.Query().Get("query")) +} + +func generateJSONEntry(entryMap map[string]interface{}) []byte { + jsonEntry, err := json.Marshal(entryMap) + if err != nil { + return []byte{} + } + + return jsonEntry +} + +func generateJSONEntryWithTruncatedField(entryMap map[string]interface{}, fieldName, fieldValue string) []byte { + entryMap[fieldName] = "" + minEntryJSON := generateJSONEntry(entryMap) + 
entryMap[fieldName] = trimForJsonMarshal(fieldValue, maxEntrySize-(len(minEntryJSON)+1)) + return generateJSONEntry(entryMap) +} + +func trimStringByBytes(bytesStr []byte, size int) string { + trimIndex := len(bytesStr) + if size < len(bytesStr) { + for !utf8.RuneStart(bytesStr[size]) { + size-- + } + trimIndex = size + } + + return string(bytesStr[:trimIndex]) +} + +func trimForJsonMarshal(field string, size int) string { + fieldValueEncoded, err := json.Marshal(field) + if err != nil { + return "" + } + fieldValueEncoded = fieldValueEncoded[1 : len(fieldValueEncoded)-1] + fieldValueEncodedTrimmed := trimStringByBytes(fieldValueEncoded, size) + fieldValueEncodedTrimmed = "\"" + removeHalfCutEscapeChar(fieldValueEncodedTrimmed) + "\"" + var fieldValue string + err = json.Unmarshal([]byte(fieldValueEncodedTrimmed), &fieldValue) + if err != nil { + return "" + } + + return fieldValue +} + +func removeHalfCutEscapeChar(str string) string { + trailingBackslashCount := len(str) - len(strings.TrimRight(str, "\\")) + if trailingBackslashCount%2 == 1 { + str = str[0 : len(str)-1] + } + return str +} diff --git a/pkg/util/request_tracker/request_tracker.go b/pkg/util/request_tracker/request_tracker.go new file mode 100644 index 00000000000..1c739bd24bd --- /dev/null +++ b/pkg/util/request_tracker/request_tracker.go @@ -0,0 +1,202 @@ +// Copyright The Prometheus Authors +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package request_tracker + +import ( + "context" + "errors" + "fmt" + "io" + "log/slog" + "os" + "path/filepath" + "strings" + "time" + + "github.com/edsrzf/mmap-go" +) + +type RequestTracker struct { + mmappedFile []byte + getNextIndex chan int + logger *slog.Logger + closer io.Closer + maxConcurrent int +} + +var _ io.Closer = &RequestTracker{} + +const ( + maxEntrySize int = 1000 +) + +func parseRequestBrokenJSON(brokenJSON []byte) (string, bool) { + requests := strings.ReplaceAll(string(brokenJSON), "\x00", "") + if len(requests) > 0 { + requests = requests[:len(requests)-1] + "]" + } + + if len(requests) <= 1 { + return "[]", false + } + + return requests, true +} + +func logUnfinishedRequests(filename string, filesize int, logger *slog.Logger) { + if _, err := os.Stat(filename); err == nil { + fd, err := os.Open(filename) + if err != nil { + logger.Error("Failed to open request log file", "err", err) + return + } + defer fd.Close() + + brokenJSON := make([]byte, filesize) + _, err = fd.Read(brokenJSON) + if err != nil { + logger.Error("Failed to read request log file", "err", err) + return + } + + requests, requestsExist := parseRequestBrokenJSON(brokenJSON) + if !requestsExist { + return + } + logger.Info("These requests didn't finish in cortex's last run:", "requests", requests) + } +} + +type mmappedRequestFile struct { + f io.Closer + m mmap.MMap +} + +func (f *mmappedRequestFile) Close() error { + err := f.m.Unmap() + if err != nil { + err = fmt.Errorf("mmappedRequestFile: unmapping: %w", err) + } + if fErr := f.f.Close(); fErr != nil { + return errors.Join(fmt.Errorf("close mmappedRequestFile.f: %w", fErr), err) + } + + return err +} + +func getRequestMMappedFile(filename string, filesize int, logger *slog.Logger) ([]byte, io.Closer, error) { + file, err := os.OpenFile(filename, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0o666) + if err != nil { + absPath, pathErr := filepath.Abs(filename) + if pathErr != nil { + absPath = filename + } + logger.Error("Error 
opening request log file", "file", absPath, "err", err) + return nil, nil, err + } + + err = file.Truncate(int64(filesize)) + if err != nil { + file.Close() + logger.Error("Error setting filesize.", "filesize", filesize, "err", err) + return nil, nil, err + } + + fileAsBytes, err := mmap.Map(file, mmap.RDWR, 0) + if err != nil { + file.Close() + logger.Error("Failed to mmap", "file", filename, "Attempted size", filesize, "err", err) + return nil, nil, err + } + + return fileAsBytes, &mmappedRequestFile{f: file, m: fileAsBytes}, err +} + +func NewRequestTracker(localStoragePath string, fileName string, maxConcurrent int, logger *slog.Logger) *RequestTracker { + if localStoragePath == "" { + return nil + } + + err := os.MkdirAll(localStoragePath, 0o777) + if err != nil { + logger.Error("Failed to create directory for logging active requests") + return nil + } + + filename, filesize := filepath.Join(localStoragePath, fileName), 1+maxConcurrent*maxEntrySize + logUnfinishedRequests(filename, filesize, logger) + + fileAsBytes, closer, err := getRequestMMappedFile(filename, filesize, logger) + if err != nil { + logger.Error("Unable to create mmap-ed active request log", "err", err) + return nil + } + + copy(fileAsBytes, "[") + requestTracker := &RequestTracker{ + mmappedFile: fileAsBytes, + closer: closer, + getNextIndex: make(chan int, maxConcurrent), + logger: logger, + maxConcurrent: maxConcurrent, + } + + requestTracker.generateIndices(maxConcurrent) + + return requestTracker +} + +func (tracker *RequestTracker) generateIndices(maxConcurrent int) { + for i := 0; i < maxConcurrent; i++ { + tracker.getNextIndex <- 1 + (i * maxEntrySize) + } +} + +func (tracker *RequestTracker) Delete(insertIndex int) { + copy(tracker.mmappedFile[insertIndex:], strings.Repeat("\x00", maxEntrySize)) + tracker.getNextIndex <- insertIndex +} + +func (tracker *RequestTracker) Insert(ctx context.Context, entry []byte) (int, error) { + if len(entry) > maxEntrySize { + entry = 
generateMinEntry() + } + select { + case i := <-tracker.getNextIndex: + fileBytes := tracker.mmappedFile + start, end := i, i+maxEntrySize + + copy(fileBytes[start:], entry) + copy(fileBytes[end-1:], ",") + return i, nil + case <-ctx.Done(): + return 0, ctx.Err() + } +} + +func generateMinEntry() []byte { + entryMap := make(map[string]interface{}) + entryMap["timestamp_sec"] = time.Now().Unix() + return generateJSONEntry(entryMap) +} + +func (tracker *RequestTracker) Close() error { + if tracker == nil || tracker.closer == nil { + return nil + } + if err := tracker.closer.Close(); err != nil { + return fmt.Errorf("close RequestTracker.closer: %w", err) + } + return nil +} diff --git a/pkg/util/request_tracker/request_tracker_test.go b/pkg/util/request_tracker/request_tracker_test.go new file mode 100644 index 00000000000..e4d7a009ceb --- /dev/null +++ b/pkg/util/request_tracker/request_tracker_test.go @@ -0,0 +1,155 @@ +package request_tracker + +import ( + "context" + "log/slog" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAPITracker(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "api-tracker-test") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + tracker := NewRequestTracker(tmpDir, "apis.active", 10, logger) + require.NotNil(t, tracker) + defer tracker.Close() + + ctx := context.Background() + insertIndex, err := tracker.Insert(ctx, []byte{}) + require.NoError(t, err) + assert.Greater(t, insertIndex, 0) + + tracker.Delete(insertIndex) +} + +func TestAPITrackerLogUnfinished(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "api-tracker-test") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + filename := filepath.Join(tmpDir, "apis.active") + content := `[{"path":"/api/v1/series","method":"GET","timestamp_sec":1234567890},` + err = 
os.WriteFile(filename, []byte(content), 0644) + require.NoError(t, err) + + var logOutput strings.Builder + logger := slog.New(slog.NewTextHandler(&logOutput, nil)) + + tracker := NewRequestTracker(tmpDir, "apis.active", 10, logger) + require.NotNil(t, tracker) + defer tracker.Close() + output := logOutput.String() + assert.Contains(t, output, "These requests didn't finish in cortex's last run") + assert.Contains(t, output, "/api/v1/series") +} + +func TestAPITrackerNilDirectory(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + tracker := NewRequestTracker("", "apis.active", 10, logger) + assert.Nil(t, tracker) +} + +func TestAPIWrapper(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "api-wrapper-test") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + tracker := NewRequestTracker(tmpDir, "apis.active", 10, logger) + require.NotNil(t, tracker) + defer tracker.Close() + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + wrapper := NewRequestWrapper(handler, tracker, &ApiExtractor{}) + + req := httptest.NewRequest("GET", "/api/v1/series?match[]=up", nil) + rr := httptest.NewRecorder() + wrapper.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code) +} + +func TestAPIWrapperNilTracker(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + wrapper := NewRequestWrapper(handler, nil, &ApiExtractor{}) + + req := httptest.NewRequest("GET", "/api/v1/series?match[]=up", nil) + rr := httptest.NewRecorder() + wrapper.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code) +} + +func TestAPITrackerAboveMaxConcurrency(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "api-tracker-test") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + + tracker := 
NewRequestTracker(tmpDir, "apis.active", 2, logger) + require.NotNil(t, tracker) + defer tracker.Close() + ctx := context.Background() + + index1, err := tracker.Insert(ctx, []byte{}) + require.NoError(t, err) + + index2, err := tracker.Insert(ctx, []byte{}) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) + defer cancel() + + _, err = tracker.Insert(ctx, []byte{}) + assert.Error(t, err) // Should timeout + + tracker.Delete(index1) + ctx = context.Background() + index3, err := tracker.Insert(ctx, []byte{}) + require.NoError(t, err) + + tracker.Delete(index2) + tracker.Delete(index3) +} + +func TestAPITrackerLongQueryTruncate(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "api-tracker-test") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + logger := slog.New(slog.NewTextHandler(os.Stdout, nil)) + tracker := NewRequestTracker(tmpDir, "apis.active", 10, logger) + require.NotNil(t, tracker) + defer tracker.Close() + + longQuery := strings.Repeat("metric_name{label=\"value\"} or ", maxEntrySize*2) + "final_metric" + req := httptest.NewRequest("GET", "/api/v1/query", nil) + q := req.URL.Query() + q.Add("query", longQuery) + req.URL.RawQuery = q.Encode() + + extractor := &InstantQueryExtractor{} + extractedData := extractor.Extract(req) + + require.NotEmpty(t, extractedData) + assert.True(t, len(extractedData) > 0) + assert.LessOrEqual(t, len(extractedData), maxEntrySize) + assert.Contains(t, string(extractedData), "metric_name") + assert.NotContains(t, string(extractedData), "final_metric") +} diff --git a/pkg/util/request_tracker/request_wrapper.go b/pkg/util/request_tracker/request_wrapper.go new file mode 100644 index 00000000000..d7023239261 --- /dev/null +++ b/pkg/util/request_tracker/request_wrapper.go @@ -0,0 +1,30 @@ +package request_tracker + +import ( + "net/http" +) + +type RequestWrapper struct { + handler http.Handler + requestTracker *RequestTracker + extractor Extractor +} + +func 
NewRequestWrapper(handler http.Handler, requestTracker *RequestTracker, extractor Extractor) *RequestWrapper { + return &RequestWrapper{ + handler: handler, + requestTracker: requestTracker, + extractor: extractor, + } +} + +func (w *RequestWrapper) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + if w.requestTracker != nil { + insertIndex, err := w.requestTracker.Insert(r.Context(), w.extractor.Extract(r)) + if err == nil { + defer w.requestTracker.Delete(insertIndex) + } + } + + w.handler.ServeHTTP(rw, r) +}