diff --git a/middleware/csrf.go b/middleware/csrf.go index 33757b760..b163da796 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -152,6 +152,42 @@ func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) { return nil, cErr } + // setTokenInContext retrieves the existing CSRF token from the cookie or generates a new one, + // then sets the cookie and stores the token in the context so handlers can access it + // (e.g. to render it in forms). + setTokenInContext := func(c *echo.Context) { + token := "" + if k, err := c.Cookie(config.CookieName); err != nil { + token = config.Generator() // Generate token + } else { + token = k.Value // Reuse token + } + + // Set CSRF cookie + cookie := new(http.Cookie) + cookie.Name = config.CookieName + cookie.Value = token + if config.CookiePath != "" { + cookie.Path = config.CookiePath + } + if config.CookieDomain != "" { + cookie.Domain = config.CookieDomain + } + if config.CookieSameSite != http.SameSiteDefaultMode { + cookie.SameSite = config.CookieSameSite + } + cookie.Expires = time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second) + cookie.Secure = config.CookieSecure + cookie.HttpOnly = config.CookieHTTPOnly + c.SetCookie(cookie) + + // Store token in the context + c.Set(config.ContextKey, token) + + // Protect clients from caching the response + c.Response().Header().Add(echo.HeaderVary, echo.HeaderCookie) + } + return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c *echo.Context) error { if config.Skipper(c) { @@ -164,6 +200,10 @@ func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) { return err } if allow { + // Even though the Sec-Fetch-Site check passed, we still need to set + // the CSRF token in the context and cookie so that handlers can access + // the token (e.g. to render it in forms for subsequent POST requests). + setTokenInContext(c) return next(c) } diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go index ddecc10e3..40bd7aba2 100644 --- a/middleware/csrf_test.go +++ b/middleware/csrf_test.go @@ -288,6 +288,7 @@ func TestCSRFWithConfig(t *testing.T) { echo.HeaderSecFetchSite: "same-origin", }, whenMethod: http.MethodPost, + expectCookieContains: "_csrf", }, { name: "nok, unsafe method + SecFetchSite=same-cross blocked", @@ -298,6 +299,22 @@ func TestCSRFWithConfig(t *testing.T) { expectEmptyBody: true, expectErr: `code=403, message=cross-site request blocked by CSRF`, }, + { + name: "ok, safe GET + SecFetchSite=none sets token and cookie", + whenHeaders: map[string]string{ + echo.HeaderSecFetchSite: "none", + }, + whenMethod: http.MethodGet, + expectCookieContains: "_csrf", + }, + { + name: "ok, safe GET + SecFetchSite=same-origin sets token and cookie", + whenHeaders: map[string]string{ + echo.HeaderSecFetchSite: "same-origin", + }, + whenMethod: http.MethodGet, + expectCookieContains: "_csrf", + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { @@ -852,3 +869,81 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) { }) } } + +func TestCSRF_SecFetchSiteSetsTokenInContext(t *testing.T) { + // Regression test for https://github.com/labstack/echo/issues/2874 + // When Sec-Fetch-Site validation passes (e.g. direct navigation with "none"), + // the middleware must still set the CSRF token in context and cookie so that + // handlers can render forms with the token for subsequent POST requests. + var testCases = []struct { + name string + whenMethod string + whenSecFetchSite string + whenExistingCookie string + }{ + { + name: "GET with Sec-Fetch-Site: none (direct navigation)", + whenMethod: http.MethodGet, + whenSecFetchSite: "none", + }, + { + name: "GET with Sec-Fetch-Site: same-origin", + whenMethod: http.MethodGet, + whenSecFetchSite: "same-origin", + }, + { + name: "POST with Sec-Fetch-Site: same-origin", + whenMethod: http.MethodPost, + whenSecFetchSite: "same-origin", + }, + { + name: "GET with Sec-Fetch-Site: none reuses existing cookie token", + whenMethod: http.MethodGet, + whenSecFetchSite: "none", + whenExistingCookie: "existing_token_value", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + req := httptest.NewRequest(tc.whenMethod, "/", nil) + req.Header.Set(echo.HeaderSecFetchSite, tc.whenSecFetchSite) + if tc.whenExistingCookie != "" { + req.Header.Set(echo.HeaderCookie, "_csrf="+tc.whenExistingCookie) + } + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + csrf := CSRF() + var contextToken string + h := csrf(func(c *echo.Context) error { + token, ok := c.Get("csrf").(string) + if !ok { + t.Fatal("CSRF token not found in context") + } + contextToken = token + return c.String(http.StatusOK, "test") + }) + + err := h(c) + assert.NoError(t, err) + + // Token must be set in context + assert.NotEmpty(t, contextToken, "CSRF token should be set in context") + + // Cookie must be set in response + setCookie := rec.Header().Get(echo.HeaderSetCookie) + assert.Contains(t, setCookie, "_csrf", "CSRF cookie should be set in response") + + // Vary header must include Cookie + assert.Contains(t, rec.Header().Get(echo.HeaderVary), echo.HeaderCookie, "Vary header should include Cookie") + + // If there was an existing cookie, the token should match + if tc.whenExistingCookie != "" { + assert.Equal(t, tc.whenExistingCookie, contextToken, "should reuse existing cookie token") + } + }) + } +}