diff --git a/pkg/inference/models/http_handler.go b/pkg/inference/models/http_handler.go
index d8eef014..ca5dbf21 100644
--- a/pkg/inference/models/http_handler.go
+++ b/pkg/inference/models/http_handler.go
@@ -23,6 +23,21 @@ import (
 	"github.com/docker/model-runner/pkg/middleware"
 )
 
+// parseBoolQueryParam parses a boolean query parameter from the request.
+// Returns the parsed value, or false if the parameter is absent or unparseable
+// (logging a warning in the latter case).
+func parseBoolQueryParam(r *http.Request, log logging.Logger, name string) bool {
+	if !r.URL.Query().Has(name) {
+		return false
+	}
+	val, err := strconv.ParseBool(r.URL.Query().Get(name))
+	if err != nil {
+		log.Warn("error while parsing query parameter", "param", name, "error", err)
+		return false
+	}
+	return val
+}
+
 // HTTPHandler manages inference model pulls and storage.
 type HTTPHandler struct {
 	// log is the associated logger.
@@ -195,16 +210,7 @@ func (h *HTTPHandler) handleGetModel(w http.ResponseWriter, r *http.Request) {
 }
 
 func (h *HTTPHandler) handleGetModelByRef(w http.ResponseWriter, r *http.Request, modelRef string) {
-	// Parse remote query parameter
-	remote := false
-	if r.URL.Query().Has("remote") {
-		val, err := strconv.ParseBool(r.URL.Query().Get("remote"))
-		if err != nil {
-			h.log.Warn("error while parsing remote query parameter", "error", err)
-		} else {
-			remote = val
-		}
-	}
+	remote := parseBoolQueryParam(r, h.log, "remote")
 
 	var (
 		apiModel *Model
@@ -309,14 +315,7 @@ func (h *HTTPHandler) handleDeleteModel(w http.ResponseWriter, r *http.Request)
 
 	modelRef := r.PathValue("name")
 
-	var force bool
-	if r.URL.Query().Has("force") {
-		if val, err := strconv.ParseBool(r.URL.Query().Get("force")); err != nil {
-			h.log.Warn("error while parsing force query parameter", "error", err)
-		} else {
-			force = val
-		}
-	}
+	force := parseBoolQueryParam(r, h.log, "force")
 
 	// First try to delete without normalization (as ID), then with normalization if not found
 	resp, err := h.manager.Delete(modelRef, force)
diff --git a/pkg/inference/scheduling/http_handler.go b/pkg/inference/scheduling/http_handler.go
index a9f3077b..93738bf8 100644
--- a/pkg/inference/scheduling/http_handler.go
+++ b/pkg/inference/scheduling/http_handler.go
@@ -23,6 +23,23 @@ import (
 
 type contextKey bool
 
+// readRequestBody reads up to maxSize bytes from the request body and writes
+// an appropriate HTTP error if reading fails. Returns (body, true) on success
+// or (nil, false) after writing the error response.
+func readRequestBody(w http.ResponseWriter, r *http.Request, maxSize int64) ([]byte, bool) {
+	body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxSize))
+	if err != nil {
+		var maxBytesError *http.MaxBytesError
+		if errors.As(err, &maxBytesError) {
+			http.Error(w, "request too large", http.StatusBadRequest)
+		} else {
+			http.Error(w, "failed to read request body", http.StatusInternalServerError)
+		}
+		return nil, false
+	}
+	return body, true
+}
+
 const preloadOnlyKey contextKey = false
 
 // HTTPHandler handles HTTP requests for the scheduler.
@@ -132,14 +149,8 @@ func (h *HTTPHandler) handleOpenAIInference(w http.ResponseWriter, r *http.Reque
 
 	// Read the entire request body. We put some basic size constraints in place
 	// to avoid DoS attacks. We do this early to avoid client write timeouts.
-	body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maximumOpenAIInferenceRequestSize))
-	if err != nil {
-		var maxBytesError *http.MaxBytesError
-		if errors.As(err, &maxBytesError) {
-			http.Error(w, "request too large", http.StatusBadRequest)
-		} else {
-			http.Error(w, "failed to read request body", http.StatusInternalServerError)
-		}
+	body, ok := readRequestBody(w, r, maximumOpenAIInferenceRequestSize)
+	if !ok {
 		return
 	}
 
@@ -338,14 +349,8 @@ func (h *HTTPHandler) GetDiskUsage(w http.ResponseWriter, _ *http.Request) {
 
 // Unload unloads the specified runners (backend, model) from the backend.
 // Currently, this doesn't work for runners that are handling an OpenAI request.
 func (h *HTTPHandler) Unload(w http.ResponseWriter, r *http.Request) {
-	body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maximumOpenAIInferenceRequestSize))
-	if err != nil {
-		var maxBytesError *http.MaxBytesError
-		if errors.As(err, &maxBytesError) {
-			http.Error(w, "request too large", http.StatusBadRequest)
-		} else {
-			http.Error(w, "failed to read request body", http.StatusInternalServerError)
-		}
+	body, ok := readRequestBody(w, r, maximumOpenAIInferenceRequestSize)
+	if !ok {
 		return
 	}
 
@@ -371,14 +376,8 @@ type installBackendRequest struct {
 // InstallBackend handles POST /install-backend requests.
 // It triggers on-demand installation of a deferred backend.
 func (h *HTTPHandler) InstallBackend(w http.ResponseWriter, r *http.Request) {
-	body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maximumOpenAIInferenceRequestSize))
-	if err != nil {
-		var maxBytesError *http.MaxBytesError
-		if errors.As(err, &maxBytesError) {
-			http.Error(w, "request too large", http.StatusBadRequest)
-		} else {
-			http.Error(w, "failed to read request body", http.StatusInternalServerError)
-		}
+	body, ok := readRequestBody(w, r, maximumOpenAIInferenceRequestSize)
+	if !ok {
 		return
 	}
 
@@ -414,14 +413,8 @@ func (h *HTTPHandler) Configure(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maximumOpenAIInferenceRequestSize))
-	if err != nil {
-		var maxBytesError *http.MaxBytesError
-		if errors.As(err, &maxBytesError) {
-			http.Error(w, "request too large", http.StatusBadRequest)
-		} else {
-			http.Error(w, "failed to read request body", http.StatusInternalServerError)
-		}
+	body, ok := readRequestBody(w, r, maximumOpenAIInferenceRequestSize)
+	if !ok {
 		return
 	}
 
@@ -433,7 +426,7 @@ func (h *HTTPHandler) Configure(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	backend, err = h.scheduler.ConfigureRunner(r.Context(), backend, configureRequest, r.UserAgent())
+	backend, err := h.scheduler.ConfigureRunner(r.Context(), backend, configureRequest, r.UserAgent())
 	if err != nil {
 		if errors.Is(err, errRunnerAlreadyActive) {
 			http.Error(w, err.Error(), http.StatusConflict)