Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/middleware/requestid.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ app.Use(requestid.New(requestid.Config{
}))
```

If the request already includes the configured header, that value is reused instead of generating a new one.
If the request already includes the configured header, that value is reused instead of generating a new one. The middleware
rejects IDs containing characters outside the visible ASCII range (for example, control characters or obs-text bytes) and
will regenerate the value using the configured generator or a UUID to keep headers RFC-compliant across transports.

Retrieve the request ID

Expand Down
55 changes: 52 additions & 3 deletions middleware/requestid/requestid.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package requestid

import (
"github.com/gofiber/fiber/v3"
"github.com/gofiber/utils/v2"
)

// The contextKey type is unexported to prevent collisions with context keys defined in
Expand All @@ -24,10 +25,9 @@ func New(config ...Config) fiber.Handler {
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Get id from request, else we generate one
rid := c.Get(cfg.Header)
rid := sanitizeRequestID(c.Get(cfg.Header), cfg.Generator)
if rid == "" {
rid = cfg.Generator()
rid = utils.UUID()
}

// Set new id to response header
Expand All @@ -41,6 +41,55 @@ func New(config ...Config) fiber.Handler {
}
}

// sanitizeRequestID returns the provided request ID when it is valid, otherwise
// it tries up to three values from the configured generator, then three UUIDs,
// falling back to an empty string when no visible ASCII ID is produced.
func sanitizeRequestID(rid string, generator func() string) string {
if isValidRequestID(rid) {
return rid
}

generatorFn := generator
if generatorFn == nil {
generatorFn = utils.UUID
}

for range 3 {
rid = generatorFn()
if isValidRequestID(rid) {
return rid
}
}

if generator != nil {
for range 3 {
rid = utils.UUID()
if isValidRequestID(rid) {
return rid
}
}
}

return ""
}

// isValidRequestID reports whether the request ID contains only visible ASCII
// characters (0x20–0x7E) and is non-empty.
func isValidRequestID(rid string) bool {
if rid == "" {
return false
}

for i := 0; i < len(rid); i++ {
c := rid[i]
if c < 0x20 || c > 0x7e {
return false
}
}

return true
}

// FromContext returns the request ID from context.
// If there is no request ID, an empty string is returned.
func FromContext(c fiber.Ctx) string {
Expand Down
46 changes: 46 additions & 0 deletions middleware/requestid/requestid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,52 @@ func Test_RequestID(t *testing.T) {
require.Equal(t, reqid, resp.Header.Get(fiber.HeaderXRequestID))
}

func Test_RequestID_InvalidHeaderValue(t *testing.T) {
t.Parallel()

rid := sanitizeRequestID("bad\r\nid", func() string {
return "clean-generated-id"
})

require.Equal(t, "clean-generated-id", rid)
}

func Test_RequestID_InvalidGeneratedValue(t *testing.T) {
t.Parallel()

app := fiber.New()
app.Use(New(Config{
Generator: func() string {
return "bad\r\nid"
},
}))

app.Get("/", func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})

resp, err := app.Test(httptest.NewRequest(fiber.MethodGet, "/", http.NoBody))
require.NoError(t, err)
require.Equal(t, fiber.StatusOK, resp.StatusCode)

rid := resp.Header.Get(fiber.HeaderXRequestID)
require.NotEmpty(t, rid)
require.NotContains(t, rid, "\r")
require.NotContains(t, rid, "\n")
}

func Test_isValidRequestID_VisibleASCII(t *testing.T) {
t.Parallel()

require.True(t, isValidRequestID("request-id-09AZaz ~"))
}

func Test_isValidRequestID_RejectsObsText(t *testing.T) {
t.Parallel()

require.False(t, isValidRequestID("valid\xff"))
}

// go test -run Test_RequestID_Next
func Test_RequestID_Next(t *testing.T) {
t.Parallel()
Expand Down
Loading