Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions middleware/csrf.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,42 @@ func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
return nil, cErr
}

// setTokenInContext retrieves the existing CSRF token from the cookie or generates a new one,
// then sets the cookie and stores the token in the context so handlers can access it
// (e.g. to render it in forms).
setTokenInContext := func(c *echo.Context) {
token := ""
if k, err := c.Cookie(config.CookieName); err != nil {
token = config.Generator() // Generate token
} else {
token = k.Value // Reuse token
}

// Set CSRF cookie
cookie := new(http.Cookie)
cookie.Name = config.CookieName
cookie.Value = token
if config.CookiePath != "" {
cookie.Path = config.CookiePath
}
if config.CookieDomain != "" {
cookie.Domain = config.CookieDomain
}
if config.CookieSameSite != http.SameSiteDefaultMode {
cookie.SameSite = config.CookieSameSite
}
cookie.Expires = time.Now().Add(time.Duration(config.CookieMaxAge) * time.Second)
cookie.Secure = config.CookieSecure
cookie.HttpOnly = config.CookieHTTPOnly
c.SetCookie(cookie)

// Store token in the context
c.Set(config.ContextKey, token)

// Protect clients from caching the response
c.Response().Header().Add(echo.HeaderVary, echo.HeaderCookie)
}

return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c *echo.Context) error {
if config.Skipper(c) {
Expand All @@ -164,6 +200,10 @@ func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) {
return err
}
if allow {
// Even though the Sec-Fetch-Site check passed, we still need to set
// the CSRF token in the context and cookie so that handlers can access
// the token (e.g. to render it in forms for subsequent POST requests).
setTokenInContext(c)
return next(c)
}

Expand Down
95 changes: 95 additions & 0 deletions middleware/csrf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ func TestCSRFWithConfig(t *testing.T) {
echo.HeaderSecFetchSite: "same-origin",
},
whenMethod: http.MethodPost,
expectCookieContains: "_csrf",
},
{
name: "nok, unsafe method + SecFetchSite=same-cross blocked",
Expand All @@ -298,6 +299,22 @@ func TestCSRFWithConfig(t *testing.T) {
expectEmptyBody: true,
expectErr: `code=403, message=cross-site request blocked by CSRF`,
},
{
name: "ok, safe GET + SecFetchSite=none sets token and cookie",
whenHeaders: map[string]string{
echo.HeaderSecFetchSite: "none",
},
whenMethod: http.MethodGet,
expectCookieContains: "_csrf",
},
{
name: "ok, safe GET + SecFetchSite=same-origin sets token and cookie",
whenHeaders: map[string]string{
echo.HeaderSecFetchSite: "same-origin",
},
whenMethod: http.MethodGet,
expectCookieContains: "_csrf",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
Expand Down Expand Up @@ -852,3 +869,81 @@ func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) {
})
}
}

func TestCSRF_SecFetchSiteSetsTokenInContext(t *testing.T) {
// Regression test for https://github.com/labstack/echo/issues/2874
// When Sec-Fetch-Site validation passes (e.g. direct navigation with "none"),
// the middleware must still set the CSRF token in context and cookie so that
// handlers can render forms with the token for subsequent POST requests.
var testCases = []struct {
name string
whenMethod string
whenSecFetchSite string
whenExistingCookie string
}{
{
name: "GET with Sec-Fetch-Site: none (direct navigation)",
whenMethod: http.MethodGet,
whenSecFetchSite: "none",
},
{
name: "GET with Sec-Fetch-Site: same-origin",
whenMethod: http.MethodGet,
whenSecFetchSite: "same-origin",
},
{
name: "POST with Sec-Fetch-Site: same-origin",
whenMethod: http.MethodPost,
whenSecFetchSite: "same-origin",
},
{
name: "GET with Sec-Fetch-Site: none reuses existing cookie token",
whenMethod: http.MethodGet,
whenSecFetchSite: "none",
whenExistingCookie: "existing_token_value",
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
e := echo.New()

req := httptest.NewRequest(tc.whenMethod, "/", nil)
req.Header.Set(echo.HeaderSecFetchSite, tc.whenSecFetchSite)
if tc.whenExistingCookie != "" {
req.Header.Set(echo.HeaderCookie, "_csrf="+tc.whenExistingCookie)
}
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)

csrf := CSRF()
var contextToken string
h := csrf(func(c *echo.Context) error {
token, ok := c.Get("csrf").(string)
if !ok {
t.Fatal("CSRF token not found in context")
}
contextToken = token
return c.String(http.StatusOK, "test")
})

err := h(c)
assert.NoError(t, err)

// Token must be set in context
assert.NotEmpty(t, contextToken, "CSRF token should be set in context")

// Cookie must be set in response
setCookie := rec.Header().Get(echo.HeaderSetCookie)
assert.Contains(t, setCookie, "_csrf", "CSRF cookie should be set in response")

// Vary header must include Cookie
assert.Contains(t, rec.Header().Get(echo.HeaderVary), echo.HeaderCookie, "Vary header should include Cookie")

// If there was an existing cookie, the token should match
if tc.whenExistingCookie != "" {
assert.Equal(t, tc.whenExistingCookie, contextToken, "should reuse existing cookie token")
}
})
}
}