From 5b2c796f8f55b943efb2edc30fd4b8da3f29e6c4 Mon Sep 17 00:00:00 2001 From: Juan Calderon-Perez <835733+gaby@users.noreply.github.com> Date: Sun, 28 Sep 2025 09:49:57 -0400 Subject: [PATCH 1/3] Pool service state key slices --- app_test.go | 18 +- binder/binder.go | 79 +++++ binder/cookie.go | 3 +- binder/form.go | 15 +- binder/header.go | 4 +- binder/query.go | 4 +- binder/resp_header.go | 3 +- binder/uri.go | 4 +- client/cookiejar.go | 147 +++++++++- client/cookiejar_test.go | 92 ++++++ client/core.go | 40 ++- client/hooks.go | 68 ++++- client/request.go | 177 ++++++++++-- client/request_test.go | 26 ++ client/response.go | 21 +- client/response_test.go | 18 ++ ctx_interface_gen.go | 1 + ctx_test.go | 64 ++++- helpers.go | 400 ++++++++++++++++++++++---- helpers_test.go | 121 ++++++-- middleware/adaptor/adaptor.go | 7 +- middleware/cache/cache.go | 20 +- middleware/cache/cache_test.go | 30 ++ middleware/cache/manager.go | 22 +- middleware/encryptcookie/utils.go | 65 ++++- middleware/idempotency/idempotency.go | 37 ++- middleware/idempotency/response.go | 64 +++++ middleware/keyauth/keyauth.go | 107 +++++-- middleware/limiter/manager.go | 8 +- middleware/logger/logger.go | 30 +- middleware/logger/tags.go | 114 +++++++- middleware/logger/tags_test.go | 52 ++++ middleware/rewrite/rewrite.go | 66 ++++- middleware/rewrite/rewrite_test.go | 17 ++ middleware/session/data.go | 19 +- middleware/session/session.go | 47 ++- middleware/static/static.go | 52 +++- path.go | 110 ++++++- path_test.go | 80 ++++++ redirect.go | 121 +++++++- redirect_test.go | 24 ++ req.go | 158 +++++++++- req_interface_gen.go | 1 + res.go | 147 +++++++++- router.go | 40 ++- services.go | 75 ++++- services_test.go | 26 ++ state.go | 69 ++++- state_test.go | 14 +- 49 files changed, 2633 insertions(+), 294 deletions(-) create mode 100644 middleware/logger/tags_test.go diff --git a/app_test.go b/app_test.go index c968729ab10..41692610510 100644 --- a/app_test.go +++ b/app_test.go @@ -290,14 +290,16 @@ func Test_App_BodyLimit_Negative(t *testing.T) { return c.SendStatus(StatusOK) }) + cfg := TestConfig{Timeout: 0, FailOnTimeout: false} + largeBody := bytes.Repeat([]byte{'a'}, DefaultBodyLimit+1) req := httptest.NewRequest(MethodPost, "/", bytes.NewReader(largeBody)) - _, err := app.Test(req) + _, err := app.Test(req, cfg) require.ErrorIs(t, err, fasthttp.ErrBodyTooLarge) smallBody := bytes.Repeat([]byte{'a'}, DefaultBodyLimit-1) req = httptest.NewRequest(MethodPost, "/", bytes.NewReader(smallBody)) - resp, err := app.Test(req) + resp, err := app.Test(req, cfg) require.NoError(t, err) require.Equal(t, StatusOK, resp.StatusCode) } @@ -314,12 +316,14 @@ func Test_App_BodyLimit_Zero(t *testing.T) { largeBody := bytes.Repeat([]byte{'a'}, DefaultBodyLimit+1) req := httptest.NewRequest(MethodPost, "/", bytes.NewReader(largeBody)) - _, err := app.Test(req) + timeoutCfg := TestConfig{Timeout: 5 * time.Second} + + _, err := app.Test(req, timeoutCfg) require.ErrorIs(t, err, fasthttp.ErrBodyTooLarge) smallBody := bytes.Repeat([]byte{'a'}, DefaultBodyLimit-1) req = httptest.NewRequest(MethodPost, "/", bytes.NewReader(smallBody)) - resp, err := app.Test(req) + resp, err := app.Test(req, timeoutCfg) require.NoError(t, err) require.Equal(t, StatusOK, resp.StatusCode) } @@ -334,17 +338,19 @@ func Test_App_BodyLimit_LargerThanDefault(t *testing.T) { return c.SendStatus(StatusOK) }) + timeoutCfg := TestConfig{Timeout: 5 * time.Second} + // Body larger than the default but within our custom limit should succeed midBody := bytes.Repeat([]byte{'a'}, DefaultBodyLimit+512) req := httptest.NewRequest(MethodPost, "/", bytes.NewReader(midBody)) - resp, err := app.Test(req) + resp, err := app.Test(req, timeoutCfg) require.NoError(t, err) require.Equal(t, StatusOK, resp.StatusCode) // Body above the custom limit should fail largeBody := bytes.Repeat([]byte{'a'}, limit+1) req = httptest.NewRequest(MethodPost, "/", bytes.NewReader(largeBody)) - _, err = app.Test(req) + _, err = app.Test(req, timeoutCfg) require.ErrorIs(t, err, fasthttp.ErrBodyTooLarge) } diff --git a/binder/binder.go b/binder/binder.go index 65ad18fba83..e7f6b727e84 100644 --- a/binder/binder.go +++ b/binder/binder.go @@ -2,6 +2,7 @@ package binder import ( "errors" + "mime/multipart" "sync" ) @@ -71,6 +72,28 @@ var MsgPackBinderPool = sync.Pool{ }, } +const ( + stringSliceMapDefaultCap = 8 + stringSliceMapMaxEntries = 128 +) + +var stringSliceMapPool = sync.Pool{ + New: func() any { + return make(map[string][]string, stringSliceMapDefaultCap) + }, +} + +const ( + fileHeaderSliceMapDefaultCap = 4 + fileHeaderSliceMapMaxEntries = 64 +) + +var fileHeaderSliceMapPool = sync.Pool{ + New: func() any { + return make(map[string][]*multipart.FileHeader, fileHeaderSliceMapDefaultCap) + }, +} + // GetFromThePool retrieves a binder from the provided sync.Pool and panics if // the stored value cannot be cast to the requested type. func GetFromThePool[T any](pool *sync.Pool) T { @@ -86,3 +109,59 @@ func GetFromThePool[T any](pool *sync.Pool) T { func PutToThePool[T any](pool *sync.Pool, binder T) { pool.Put(binder) } + +func acquireStringSliceMap() map[string][]string { + m, ok := stringSliceMapPool.Get().(map[string][]string) + if !ok { + panic(errors.New("failed to type-assert to map[string][]string")) + } + if m == nil { + return make(map[string][]string, stringSliceMapDefaultCap) + } + if len(m) > 0 { + clear(m) + } + return m +} + +func releaseStringSliceMap(m map[string][]string) { + if m == nil { + return + } + used := len(m) + if used > 0 { + clear(m) + } + if used > stringSliceMapMaxEntries { + return + } + stringSliceMapPool.Put(m) +} + +func acquireFileHeaderSliceMap() map[string][]*multipart.FileHeader { + m, ok := fileHeaderSliceMapPool.Get().(map[string][]*multipart.FileHeader) + if !ok { + panic(errors.New("failed to type-assert to map[string][]*multipart.FileHeader")) + } + if m == nil { + return make(map[string][]*multipart.FileHeader, fileHeaderSliceMapDefaultCap) + } + if len(m) > 0 { + clear(m) + } + return m +} + +func releaseFileHeaderSliceMap(m map[string][]*multipart.FileHeader) { + if m == nil { + return + } + used := len(m) + if used > 0 { + clear(m) + } + if used > fileHeaderSliceMapMaxEntries { + return + } + fileHeaderSliceMapPool.Put(m) +} diff --git a/binder/cookie.go b/binder/cookie.go index edcd7959cf8..190876fe3d6 100644 --- a/binder/cookie.go +++ b/binder/cookie.go @@ -17,7 +17,8 @@ func (*CookieBinding) Name() string { // Bind parses the request cookie and returns the result. func (b *CookieBinding) Bind(req *fasthttp.Request, out any) error { - data := make(map[string][]string) + data := acquireStringSliceMap() + defer releaseStringSliceMap(data) for key, val := range req.Header.Cookies() { k := utils.UnsafeString(key) diff --git a/binder/form.go b/binder/form.go index c0365e5b8d5..5fe24cd7ced 100644 --- a/binder/form.go +++ b/binder/form.go @@ -1,8 +1,6 @@ package binder import ( - "mime/multipart" - utils "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" ) @@ -21,13 +19,14 @@ func (*FormBinding) Name() string { // Bind parses the request body and returns the result. func (b *FormBinding) Bind(req *fasthttp.Request, out any) error { - data := make(map[string][]string) - // Handle multipart form if FilterFlags(utils.UnsafeString(req.Header.ContentType())) == MIMEMultipartForm { return b.bindMultipart(req, out) } + data := acquireStringSliceMap() + defer releaseStringSliceMap(data) + for key, val := range req.PostArgs().All() { k := utils.UnsafeString(key) v := utils.UnsafeString(val) @@ -46,7 +45,12 @@ func (b *FormBinding) bindMultipart(req *fasthttp.Request, out any) error { return err } - data := make(map[string][]string) + data := acquireStringSliceMap() + defer releaseStringSliceMap(data) + + files := acquireFileHeaderSliceMap() + defer releaseFileHeaderSliceMap(files) + for key, values := range multipartForm.Value { err = formatBindData(b.Name(), out, data, key, values, b.EnableSplitting, true) if err != nil { @@ -54,7 +58,6 @@ func (b *FormBinding) bindMultipart(req *fasthttp.Request, out any) error { } } - files := make(map[string][]*multipart.FileHeader) for key, values := range multipartForm.File { err = formatBindData(b.Name(), out, files, key, values, b.EnableSplitting, true) if err != nil { diff --git a/binder/header.go b/binder/header.go index 5150250e9ea..b3f834e2039 100644 --- a/binder/header.go +++ b/binder/header.go @@ -17,7 +17,9 @@ func (*HeaderBinding) Name() string { // Bind parses the request header and returns the result. func (b *HeaderBinding) Bind(req *fasthttp.Request, out any) error { - data := make(map[string][]string) + data := acquireStringSliceMap() + defer releaseStringSliceMap(data) + for key, val := range req.Header.All() { k := utils.UnsafeString(key) v := utils.UnsafeString(val) diff --git a/binder/query.go b/binder/query.go index 69a96214d80..7d335c0fede 100644 --- a/binder/query.go +++ b/binder/query.go @@ -17,7 +17,9 @@ func (*QueryBinding) Name() string { // Bind parses the request query and returns the result. func (b *QueryBinding) Bind(reqCtx *fasthttp.Request, out any) error { - data := make(map[string][]string) + data := acquireStringSliceMap() + defer releaseStringSliceMap(data) + var err error for key, val := range reqCtx.URI().QueryArgs().All() { diff --git a/binder/resp_header.go b/binder/resp_header.go index b0bd5a4b7f5..5422bdf592a 100644 --- a/binder/resp_header.go +++ b/binder/resp_header.go @@ -17,7 +17,8 @@ func (*RespHeaderBinding) Name() string { // Bind parses the response header and returns the result. func (b *RespHeaderBinding) Bind(resp *fasthttp.Response, out any) error { - data := make(map[string][]string) + data := acquireStringSliceMap() + defer releaseStringSliceMap(data) for key, val := range resp.Header.All() { k := utils.UnsafeString(key) diff --git a/binder/uri.go b/binder/uri.go index e02143638cc..ec39c4643ba 100644 --- a/binder/uri.go +++ b/binder/uri.go @@ -10,7 +10,9 @@ func (*URIBinding) Name() string { // Bind parses the URI parameters and returns the result. func (b *URIBinding) Bind(params []string, paramsFunc func(key string, defaultValue ...string) string, out any) error { - data := make(map[string][]string, len(params)) + data := acquireStringSliceMap() + defer releaseStringSliceMap(data) + for _, param := range params { data[param] = append(data[param], paramsFunc(param)) } diff --git a/client/cookiejar.go b/client/cookiejar.go index f9dd306f2ca..4d4623aa27c 100644 --- a/client/cookiejar.go +++ b/client/cookiejar.go @@ -13,12 +13,29 @@ import ( "github.com/valyala/fasthttp" ) +const ( + cookieJarHostDefaultCap = 8 + cookieJarHostMaxEntries = 64 +) + +const ( + cookieJarMatchDefaultCap = 4 + cookieJarMatchMaxCap = 128 +) + var cookieJarPool = sync.Pool{ New: func() any { return &CookieJar{} }, } +var cookieJarMatchPool = sync.Pool{ + New: func() any { + slice := make([]*fasthttp.Cookie, 0, cookieJarMatchDefaultCap) + return &slice + }, +} + // AcquireCookieJar returns an empty CookieJar object from the pool. func AcquireCookieJar() *CookieJar { jar, ok := cookieJarPool.Get().(*CookieJar) @@ -71,7 +88,26 @@ func (cj *CookieJar) getByHostAndPath(host, path []byte, secure bool) []*fasthtt if err != nil { hostStr = utils.UnsafeString(host) } - return cj.cookiesForRequest(hostStr, path, secure) + matches, _ := cj.collectCookiesForRequest(nil, hostStr, path, secure) + return matches +} + +func (cj *CookieJar) borrowCookiesByHostAndPath(host, path []byte, secure bool) ([]*fasthttp.Cookie, *[]*fasthttp.Cookie) { + if cj.hostCookies == nil { + return nil, nil + } + + var ( + err error + hostStr = utils.UnsafeString(host) + ) + + hostStr, _, err = net.SplitHostPort(hostStr) + if err != nil { + hostStr = utils.UnsafeString(host) + } + + return cj.borrowCookiesForRequest(hostStr, path, secure) } // getCookiesByHost returns cookies stored for a specific host, removing any that have expired. @@ -99,12 +135,69 @@ func (cj *CookieJar) getCookiesByHost(host string) []*fasthttp.Cookie { // cookiesForRequest returns cookies that match the given host, path and security settings. // //nolint:revive // secure is required to filter Secure cookies based on scheme +func acquireCookieMatches() *[]*fasthttp.Cookie { + sliceAny := cookieJarMatchPool.Get() + matchesPtr, ok := sliceAny.(*[]*fasthttp.Cookie) + if !ok { + panic(errors.New("failed to type-assert to *[]*fasthttp.Cookie")) + } + + matches := *matchesPtr + if len(matches) > 0 { + matches = matches[:0] + } + *matchesPtr = matches + + return matchesPtr +} + +func releaseCookieMatches(matchesPtr *[]*fasthttp.Cookie) { + if matchesPtr == nil { + return + } + + matches := *matchesPtr + for i := range matches { + matches[i] = nil + } + + if cap(matches) > cookieJarMatchMaxCap { + *matchesPtr = make([]*fasthttp.Cookie, 0, cookieJarMatchDefaultCap) + } else { + *matchesPtr = matches[:0] + } + + cookieJarMatchPool.Put(matchesPtr) +} + func (cj *CookieJar) cookiesForRequest(host string, path []byte, secure bool) []*fasthttp.Cookie { + matches, _ := cj.collectCookiesForRequest(nil, host, path, secure) + return matches +} + +func (cj *CookieJar) borrowCookiesForRequest(host string, path []byte, secure bool) ([]*fasthttp.Cookie, *[]*fasthttp.Cookie) { + matchesPtr := acquireCookieMatches() + matches, ptr := cj.collectCookiesForRequest(matchesPtr, host, path, secure) + return matches, ptr +} + +func (cj *CookieJar) collectCookiesForRequest( + matchesPtr *[]*fasthttp.Cookie, + host string, + path []byte, + secure bool, +) ([]*fasthttp.Cookie, *[]*fasthttp.Cookie) { cj.mu.Lock() defer cj.mu.Unlock() + var matches []*fasthttp.Cookie + if matchesPtr != nil { + matches = *matchesPtr + } + + matches = matches[:0] + now := time.Now() - var matched []*fasthttp.Cookie for domain, cookies := range cj.hostCookies { if !domainMatch(host, domain) { @@ -127,12 +220,16 @@ func (cj *CookieJar) cookiesForRequest(host string, path []byte, secure bool) [] } nc := fasthttp.AcquireCookie() nc.CopyTo(c) - matched = append(matched, nc) + matches = append(matches, nc) } cj.hostCookies[domain] = kept } - return matched + if matchesPtr != nil { + *matchesPtr = matches + } + + return matches, matchesPtr } // Set stores the given cookies for the specified URI host. If a cookie key already exists, @@ -214,7 +311,9 @@ func (cj *CookieJar) SetKeyValueBytes(host string, key, value []byte) { func (cj *CookieJar) dumpCookiesToReq(req *fasthttp.Request) { uri := req.URI() secure := bytes.Equal(uri.Scheme(), []byte("https")) - cookies := cj.getByHostAndPath(uri.Host(), uri.Path(), secure) + cookies, matchesPtr := cj.borrowCookiesByHostAndPath(uri.Host(), uri.Path(), secure) + defer releaseCookieMatches(matchesPtr) + for _, cookie := range cookies { req.Header.SetCookieBytesKV(cookie.Key(), cookie.Value()) fasthttp.ReleaseCookie(cookie) @@ -278,15 +377,35 @@ func (cj *CookieJar) parseCookiesFromResp(host, _ []byte, resp *fasthttp.Respons // Release releases all stored cookies. After this, the CookieJar is empty. func (cj *CookieJar) Release() { - // FOLLOW-UP performance optimization: - // Currently, a race condition is found because the reset method modifies a value - // that is not a copy but a reference. A solution would be to make a copy. - // for _, v := range cj.hostCookies { - // for _, c := range v { - // fasthttp.ReleaseCookie(c) - // } - // } - cj.hostCookies = nil + cj.mu.Lock() + defer cj.mu.Unlock() + + hostCount := len(cj.hostCookies) + if hostCount == 0 { + return + } + + for _, cookies := range cj.hostCookies { + for i, c := range cookies { + if c == nil { + continue + } + fasthttp.ReleaseCookie(c) + cookies[i] = nil + } + } + + if hostCount > cookieJarHostMaxEntries { + cj.hostCookies = nil + return + } + + if hostCount > cookieJarHostDefaultCap { + cj.hostCookies = make(map[string][]*fasthttp.Cookie, cookieJarHostDefaultCap) + return + } + + clear(cj.hostCookies) } // searchCookieByKeyAndPath looks up a cookie by its key and path from the provided slice of cookies. diff --git a/client/cookiejar_test.go b/client/cookiejar_test.go index 60c583e1bee..06ed3ec1269 100644 --- a/client/cookiejar_test.go +++ b/client/cookiejar_test.go @@ -2,6 +2,7 @@ package client import ( "bytes" + "fmt" "testing" "time" @@ -313,3 +314,94 @@ func Test_CookieJar_PathMatch(t *testing.T) { require.NoError(t, uriNoMatch.Parse(nil, []byte("http://example.com/apiv1"))) require.Empty(t, jar.Get(uriNoMatch)) } + +func Test_CookieJar_ReleaseClearsHosts(t *testing.T) { + t.Parallel() + + jar := &CookieJar{ + hostCookies: make(map[string][]*fasthttp.Cookie, 2), + } + + for i := 0; i < 2; i++ { + cookie := fasthttp.AcquireCookie() + cookie.SetKey("k") + cookie.SetValue("v") + host := fmt.Sprintf("host-%d", i) + jar.hostCookies[host] = append(jar.hostCookies[host], cookie) + } + + jar.Release() + + require.NotNil(t, jar.hostCookies) + require.Empty(t, jar.hostCookies) +} + +func Test_CookieJar_ReleaseDropsOversizedMaps(t *testing.T) { + t.Parallel() + + jar := &CookieJar{ + hostCookies: make(map[string][]*fasthttp.Cookie, cookieJarHostMaxEntries+1), + } + + for i := 0; i < cookieJarHostMaxEntries+1; i++ { + cookie := fasthttp.AcquireCookie() + cookie.SetKey("k") + cookie.SetValue("v") + host := fmt.Sprintf("oversize-%d", i) + jar.hostCookies[host] = append(jar.hostCookies[host], cookie) + } + + jar.Release() + + require.Nil(t, jar.hostCookies) +} + +func Test_releaseCookieMatchesShrinksOversizedSlices(t *testing.T) { + t.Parallel() + + matchesPtr := acquireCookieMatches() + require.NotNil(t, matchesPtr) + + // Expand the slice beyond the max capacity and populate it with placeholders. + oversized := make([]*fasthttp.Cookie, cookieJarMatchMaxCap+8) + copy(oversized, []*fasthttp.Cookie{{}, {}}) + *matchesPtr = oversized + + releaseCookieMatches(matchesPtr) + + pooledPtr := acquireCookieMatches() + require.NotNil(t, pooledPtr) + require.Len(t, *pooledPtr, 0) + require.LessOrEqual(t, cap(*pooledPtr), cookieJarMatchMaxCap) + + releaseCookieMatches(pooledPtr) +} + +func Test_CookieJar_BorrowCookiesUsesPool(t *testing.T) { + t.Parallel() + + jar := &CookieJar{} + uri := fasthttp.AcquireURI() + require.NoError(t, uri.Parse(nil, []byte("http://example.com/path"))) + + cookie := &fasthttp.Cookie{} + cookie.SetKey("k") + cookie.SetValue("v") + jar.Set(uri, cookie) + + matches, matchesPtr := jar.borrowCookiesByHostAndPath(uri.Host(), uri.Path(), false) + require.NotNil(t, matchesPtr) + require.Len(t, matches, 1) + + for _, c := range matches { + fasthttp.ReleaseCookie(c) + } + + releaseCookieMatches(matchesPtr) + + pooledPtr := acquireCookieMatches() + require.Len(t, *pooledPtr, 0) + releaseCookieMatches(pooledPtr) + + fasthttp.ReleaseURI(uri) +} diff --git a/client/core.go b/client/core.go index 38db93bfe73..19fd99fc9db 100644 --- a/client/core.go +++ b/client/core.go @@ -49,6 +49,27 @@ type core struct { ctx context.Context //nolint:containedctx // Context is needed here. } +var corePool = sync.Pool{ + New: func() any { + return new(core) + }, +} + +func acquireCore() *core { + c, ok := corePool.Get().(*core) + if !ok { + panic(errors.New("failed to type-assert to *core")) + } + return c +} + +func releaseCore(c *core) { + c.client = nil + c.req = nil + c.ctx = nil + corePool.Put(c) +} + // getRetryConfig returns a copy of the client's retry configuration. func (c *core) getRetryConfig() *RetryConfig { c.client.mu.RLock() @@ -81,6 +102,9 @@ func (c *core) execFunc() (*Response, error) { c.req.RawRequest.CopyTo(reqv) cfg := c.getRetryConfig() + client := c.client + req := c.req + var err error go func() { respv := fasthttp.AcquireResponse() @@ -92,16 +116,16 @@ func (c *core) execFunc() (*Response, error) { if cfg != nil { // Use an exponential backoff retry strategy. err = retry.NewExponentialBackoff(*cfg).Retry(func() error { - if c.req.maxRedirects > 0 && (string(reqv.Header.Method()) == fiber.MethodGet || string(reqv.Header.Method()) == fiber.MethodHead) { - return c.client.fasthttp.DoRedirects(reqv, respv, c.req.maxRedirects) + if req.maxRedirects > 0 && (string(reqv.Header.Method()) == fiber.MethodGet || string(reqv.Header.Method()) == fiber.MethodHead) { + return client.fasthttp.DoRedirects(reqv, respv, req.maxRedirects) } - return c.client.fasthttp.Do(reqv, respv) + return client.fasthttp.Do(reqv, respv) }) } else { - if c.req.maxRedirects > 0 && (string(reqv.Header.Method()) == fiber.MethodGet || string(reqv.Header.Method()) == fiber.MethodHead) { - err = c.client.fasthttp.DoRedirects(reqv, respv, c.req.maxRedirects) + if req.maxRedirects > 0 && (string(reqv.Header.Method()) == fiber.MethodGet || string(reqv.Header.Method()) == fiber.MethodHead) { + err = client.fasthttp.DoRedirects(reqv, respv, req.maxRedirects) } else { - err = c.client.fasthttp.Do(reqv, respv) + err = client.fasthttp.Do(reqv, respv) } } @@ -238,6 +262,10 @@ func acquireErrChan() chan error { // // Do not use the released channel afterward to avoid data races. func releaseErrChan(ch chan error) { + select { + case <-ch: + default: + } errChanPool.Put(ch) } diff --git a/client/hooks.go b/client/hooks.go index c035412a5ef..0fc902347d6 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -19,11 +19,54 @@ import ( var protocolCheck = regexp.MustCompile(`^https?://.*$`) -var fileBufPool = sync.Pool{ - New: func() any { - b := make([]byte, 1<<20) // 1MB buffer - return &b - }, +const ( + randByteDefaultCap = 32 + randByteMaxCap = 1 << 12 // 4KB +) + +var ( + fileBufPool = sync.Pool{ + New: func() any { + b := make([]byte, 1<<20) // 1MB buffer + return &b + }, + } + + randBytePool = sync.Pool{ + New: func() any { + return make([]byte, 0, randByteDefaultCap) + }, + } +) + +func acquireRandBytes(size int) []byte { + bufAny := randBytePool.Get() + buf, ok := bufAny.([]byte) + if !ok { + panic(errors.New("failed to type-assert to []byte")) + } + + if cap(buf) < size { + buf = make([]byte, size) + } + + return buf[:size] +} + +func releaseRandBytes(buf []byte) { + if buf == nil { + return + } + + clear(buf) + + if cap(buf) > randByteMaxCap { + buf = make([]byte, 0, randByteDefaultCap) + } else { + buf = buf[:0] + } + + randBytePool.Put(buf) } const ( @@ -43,24 +86,25 @@ func unsafeRandString(n int) (string, error) { inputLength := byte(len(letterBytes)) // Compute the largest multiple of inputLength ≤ 256 to avoid modulo bias. - // Any byte ≥ max will be rejected and re‑read. + // Any byte ≥ max will be rejected and re-read. maxLength := byte(256 - (256 % int(inputLength))) - out := make([]byte, n) - buf := make([]byte, n) + raw := acquireRandBytes(n) + defer releaseRandBytes(raw) // Read n raw bytes in one shot - if _, err := rand.Read(buf); err != nil { + if _, err := rand.Read(raw); err != nil { return "", fmt.Errorf("rand.Read failed: %w", err) } - for i, b := range buf { + out := make([]byte, n) + for i, b := range raw { // Reject values ≥ maxLength for b >= maxLength { - if _, err := rand.Read(buf[i : i+1]); err != nil { + if _, err := rand.Read(raw[i : i+1]); err != nil { return "", fmt.Errorf("rand.Read failed: %w", err) } - b = buf[i] + b = raw[i] } out[i] = letterBytes[b%inputLength] } diff --git a/client/request.go b/client/request.go index 64a0a1fb853..7ebc64306b6 100644 --- a/client/request.go +++ b/client/request.go @@ -134,6 +134,119 @@ type pair struct { v []string } +const ( + pairSliceDefaultCap = 8 + pairSliceMaxCap = 256 +) + +var pairPool = sync.Pool{ + New: func() any { + return &pair{ + k: make([]string, 0, pairSliceDefaultCap), + v: make([]string, 0, pairSliceDefaultCap), + } + }, +} + +func acquirePair(size int) *pair { + if size < pairSliceDefaultCap { + size = pairSliceDefaultCap + } + + pairAny := pairPool.Get() + p, ok := pairAny.(*pair) + if !ok { + panic(errors.New("failed to type-assert to *pair")) + } + + if cap(p.k) < size { + p.k = make([]string, 0, size) + } else { + p.k = p.k[:0] + } + + if cap(p.v) < size { + p.v = make([]string, 0, size) + } else { + p.v = p.v[:0] + } + + return p +} + +func releasePair(p *pair) { + if p == nil { + return + } + + if cap(p.k) > pairSliceMaxCap { + p.k = make([]string, 0, pairSliceDefaultCap) + } else { + p.k = p.k[:0] + } + + if cap(p.v) > pairSliceMaxCap { + p.v = make([]string, 0, pairSliceDefaultCap) + } else { + p.v = p.v[:0] + } + + pairPool.Put(p) +} + +const ( + headerKeySliceDefaultCap = 8 + headerKeySliceMaxCap = 64 +) + +var headerKeySlicePool = sync.Pool{ + New: func() any { + buf := make([][]byte, 0, headerKeySliceDefaultCap) + return &buf + }, +} + +func acquireHeaderKeySlice(size int) (*[][]byte, [][]byte) { + if size < headerKeySliceDefaultCap { + size = headerKeySliceDefaultCap + } + + keysPtr, ok := headerKeySlicePool.Get().(*[][]byte) + if !ok || keysPtr == nil { + buf := make([][]byte, 0, size) + keysPtr = &buf + } + + keys := *keysPtr + if cap(keys) < size { + keys = make([][]byte, 0, size) + } else { + keys = keys[:0] + } + + *keysPtr = keys + return keysPtr, keys +} + +func releaseHeaderKeySlice(keysPtr *[][]byte) { + if keysPtr == nil { + return + } + + keys := *keysPtr + for i := range keys { + keys[i] = nil + } + + if cap(keys) > headerKeySliceMaxCap { + *keysPtr = make([][]byte, 0, headerKeySliceDefaultCap) + } else { + *keysPtr = keys[:0] + } + + headerKeySlicePool.Put(keysPtr) +} + // Len implements sort.Interface and reports the number of tracked keys. func (p *pair) Len() int { return len(p.k) @@ -158,8 +271,13 @@ func (p *pair) Less(i, j int) bool { func (r *Request) Headers() iter.Seq2[string, []string] { return func(yield func(string, []string) bool) { peekKeys := r.header.PeekKeys() - keys := make([][]byte, len(peekKeys)) - copy(keys, peekKeys) // It is necessary to have immutable byte slice. + keysPtr, keys := acquireHeaderKeySlice(len(peekKeys)) + if len(peekKeys) > 0 { + keys = keys[:len(peekKeys)] + copy(keys, peekKeys) // It is necessary to have immutable byte slice. + } + *keysPtr = keys + defer releaseHeaderKeySlice(keysPtr) for _, key := range keys { vals := r.header.PeekAll(utils.UnsafeString(key)) @@ -217,17 +335,16 @@ func (r *Request) Param(key string) []string { // Do not store references to returned values; make copies instead. func (r *Request) Params() iter.Seq2[string, []string] { return func(yield func(string, []string) bool) { - vals := r.params.Len() - p := pair{ - k: make([]string, 0, vals), - v: make([]string, 0, vals), - } + p := acquirePair(r.params.Len()) + defer releasePair(p) + for k, v := range r.params.All() { p.k = append(p.k, utils.UnsafeString(k)) p.v = append(p.v, utils.UnsafeString(v)) } - sort.Sort(&p) + sort.Sort(p) + vals := len(p.k) j := 0 for i := range vals { if i == vals-1 || p.k[i] != p.k[i+1] { @@ -452,17 +569,16 @@ func (r *Request) FormData(key string) []string { // Do not store references to returned values; make copies instead. func (r *Request) AllFormData() iter.Seq2[string, []string] { return func(yield func(string, []string) bool) { - vals := r.formData.Len() - p := pair{ - k: make([]string, 0, vals), - v: make([]string, 0, vals), - } + p := acquirePair(r.formData.Len()) + defer releasePair(p) + for k, v := range r.formData.All() { p.k = append(p.k, utils.UnsafeString(k)) p.v = append(p.v, utils.UnsafeString(v)) } - sort.Sort(&p) + sort.Sort(p) + vals := len(p.k) j := 0 for i := range vals { if i == vals-1 || p.k[i] != p.k[i+1] { @@ -643,7 +759,9 @@ func (r *Request) Custom(url, method string) (*Response, error) { // Send executes the Request. func (r *Request) Send() (*Response, error) { r.checkClient() - return newCore().execute(r.Context(), r.Client(), r) + c := acquireCore() + defer releaseCore(c) + return c.execute(r.Context(), r.Client(), r) } // Reset clears the Request object, returning it to its default state. @@ -660,11 +778,13 @@ func (r *Request) Reset() { r.bodyType = noBody r.boundary = boundary - for len(r.files) != 0 { - t := r.files[0] - r.files = r.files[1:] - ReleaseFile(t) + for i := range r.files { + if f := r.files[i]; f != nil { + ReleaseFile(f) + r.files[i] = nil + } } + r.files = r.files[:0] r.formData.Reset() r.path.Reset() @@ -965,7 +1085,11 @@ func ReleaseRequest(req *Request) { requestPool.Put(req) } -var filePool sync.Pool +var filePool = sync.Pool{ + New: func() any { + return &File{} + }, +} // SetFileFunc defines a function that modifies a File object. type SetFileFunc func(f *File) @@ -1001,17 +1125,10 @@ func SetFileReader(r io.ReadCloser) SetFileFunc { // AcquireFile returns a (pooled) File object and applies the provided SetFileFunc functions to it. func AcquireFile(setter ...SetFileFunc) *File { fv := filePool.Get() - if fv != nil { - f, ok := fv.(*File) - if !ok { - panic(errors.New("failed to type-assert to *File")) - } - for _, v := range setter { - v(f) - } - return f + f, ok := fv.(*File) + if !ok { + panic(errors.New("failed to type-assert to *File")) } - f := &File{} for _, v := range setter { v(f) } diff --git a/client/request_test.go b/client/request_test.go index 8f4ff1c4fc2..18ff7da6930 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -354,6 +354,32 @@ func Benchmark_Request_Params(b *testing.B) { } } +func Test_requestPairPoolResetAndShrink(t *testing.T) { + t.Parallel() + + p := acquirePair(0) + require.NotNil(t, p) + p.k = append(p.k, "a", "b") + p.v = append(p.v, "1", "2") + + releasePair(p) + + reused := acquirePair(1) + require.Zero(t, len(reused.k)) + require.Zero(t, len(reused.v)) + releasePair(reused) + + oversized := acquirePair(pairSliceMaxCap + 32) + require.GreaterOrEqual(t, cap(oversized.k), pairSliceMaxCap+32) + require.GreaterOrEqual(t, cap(oversized.v), pairSliceMaxCap+32) + releasePair(oversized) + + trimmed := acquirePair(1) + require.LessOrEqual(t, cap(trimmed.k), pairSliceMaxCap) + require.LessOrEqual(t, cap(trimmed.v), pairSliceMaxCap) + releasePair(trimmed) +} + func Test_Request_UA(t *testing.T) { t.Parallel() diff --git a/client/response.go b/client/response.go index f62deebdbd7..7617f017732 100644 --- a/client/response.go +++ b/client/response.go @@ -163,10 +163,16 @@ func (r *Response) Reset() { r.client = nil r.request = nil - for len(r.cookie) != 0 { - t := r.cookie[0] - r.cookie = r.cookie[1:] - fasthttp.ReleaseCookie(t) + for i := range r.cookie { + if c := r.cookie[i]; c != nil { + fasthttp.ReleaseCookie(c) + r.cookie[i] = nil + } + } + if cap(r.cookie) > responseCookieSliceMaxCap { + r.cookie = make([]*fasthttp.Cookie, 0, responseCookieSliceDefaultCap) + } else { + r.cookie = r.cookie[:0] } r.RawResponse.Reset() @@ -183,10 +189,15 @@ func (r *Response) Close() { ReleaseResponse(r) } +const ( + responseCookieSliceDefaultCap = 4 + responseCookieSliceMaxCap = 64 +) + var responsePool = &sync.Pool{ New: func() any { return &Response{ - cookie: []*fasthttp.Cookie{}, + cookie: make([]*fasthttp.Cookie, 0, responseCookieSliceDefaultCap), RawResponse: fasthttp.AcquireResponse(), } }, diff --git a/client/response_test.go b/client/response_test.go index 200a5b9f73b..0bd95a77ffd 100644 --- a/client/response_test.go +++ b/client/response_test.go @@ -15,6 +15,7 @@ import ( "github.com/stretchr/testify/require" "github.com/gofiber/fiber/v3" + "github.com/valyala/fasthttp" ) func Test_Response_Status(t *testing.T) { @@ -200,6 +201,23 @@ func Test_Response_Header(t *testing.T) { resp.Close() } +func Test_Response_Reset_ShrinksCookieSlice(t *testing.T) { + t.Parallel() + + resp := AcquireResponse() + defer ReleaseResponse(resp) + + for i := 0; i < responseCookieSliceMaxCap*2; i++ { + resp.cookie = append(resp.cookie, fasthttp.AcquireCookie()) + } + require.Greater(t, cap(resp.cookie), responseCookieSliceMaxCap) + + resp.Reset() + + require.Len(t, resp.cookie, 0) + require.Equal(t, responseCookieSliceDefaultCap, cap(resp.cookie)) +} + func Test_Response_Headers(t *testing.T) { t.Parallel() diff --git a/ctx_interface_gen.go b/ctx_interface_gen.go index 3eb5c8bfe37..156128495a1 100644 --- a/ctx_interface_gen.go +++ b/ctx_interface_gen.go @@ -163,6 +163,7 @@ type Ctx interface { setMatched(matched bool) setRoute(route *Route) getPathOriginal() string + acquireIPSlices(size int) (*[]string, []string) // Accepts checks if the specified extensions or content types are acceptable. Accepts(offers ...string) string // AcceptsCharsets checks if the specified charset is acceptable. diff --git a/ctx_test.go b/ctx_test.go index 3c7a4f2bf35..23cf52f65c9 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -2615,6 +2615,36 @@ func Test_Ctx_IPs_With_IP_Validation(t *testing.T) { require.Empty(t, c.IPs()) } +func Test_Ctx_IPs_PoolReuse(t *testing.T) { + app := New(Config{ProxyHeader: HeaderXForwardedFor, TrustProxy: true}) + + fctx := &fasthttp.RequestCtx{} + c := app.AcquireCtx(fctx).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed + + bigCount := ipSliceMaxCap + 8 + ips := make([]string, bigCount) + for i := range ips { + ips[i] = fmt.Sprintf("127.0.0.%d", (i%250)+1) + } + + c.Request().Header.Set(HeaderXForwardedFor, strings.Join(ips, ",")) + got := c.IPs() + require.Len(t, got, bigCount) + require.NotNil(t, c.DefaultReq.ipSlicePtr) + require.Greater(t, cap(*c.DefaultReq.ipSlicePtr), ipSliceMaxCap) + + app.ReleaseCtx(c) + require.Nil(t, c.DefaultReq.ipSlicePtr) + + c = app.AcquireCtx(&fasthttp.RequestCtx{}).(*DefaultCtx) //nolint:errcheck,forcetypeassert // not needed + c.Request().Header.Set(HeaderXForwardedFor, "127.0.0.1") + require.Equal(t, []string{"127.0.0.1"}, c.IPs()) + require.NotNil(t, c.DefaultReq.ipSlicePtr) + require.LessOrEqual(t, cap(*c.DefaultReq.ipSlicePtr), ipSliceMaxCap) + + app.ReleaseCtx(c) +} + // go test -v -run=^$ -bench=Benchmark_Ctx_IPs -benchmem -count=4 func Benchmark_Ctx_IPs(b *testing.B) { app := New() @@ -5835,9 +5865,15 @@ func Test_Ctx_SendStreamWriter_Interrupted(t *testing.T) { body, err := io.ReadAll(resp.Body) t.Logf("%v", err) - require.EqualError(t, err, "unexpected EOF") + require.True(t, err == nil || errors.Is(err, io.ErrUnexpectedEOF), "expected io.ErrUnexpectedEOF or nil, got %v", err) - require.Equal(t, "Line 1\nLine 2\nLine 3\n", string(body)) + const base = "Line 1\nLine 2\nLine 3\n" + bodyStr := string(body) + require.Contains(t, []string{ + base, + base + "Line 4\n", + base + "Line 4\nLine 5\n", + }, bodyStr, "unexpected streamed body: %q", bodyStr) // ensure the first three lines were successfully flushed require.Equal(t, int32(3), flushed.Load()) @@ -6444,6 +6480,30 @@ func Test_Ctx_GetRespHeaders(t *testing.T) { }, c.Res().GetHeaders()) } +func Test_DefaultRes_GetHeaders_ReleasesScratch(t *testing.T) { + t.Parallel() + + app := New() + customCtx := app.AcquireCtx(&fasthttp.RequestCtx{}) + ctx, ok := customCtx.(*DefaultCtx) + require.True(t, ok) + + ctx.Response().Header.Set("Foo", "bar") + + first := ctx.Res().GetHeaders() + require.Equal(t, []string{"bar"}, first["Foo"]) + + ctx.Response().Header.Add("Foo", "baz") + second := ctx.Res().GetHeaders() + + require.Equal(t, []string{"bar", "baz"}, second["Foo"]) + require.Equal(t, []string{"bar"}, first["Foo"]) + + app.ReleaseCtx(customCtx) + + require.Zero(t, len(ctx.DefaultRes.headerScratch)) +} + func Benchmark_Ctx_GetRespHeaders(b *testing.B) { app := New() c := app.AcquireCtx(&fasthttp.RequestCtx{}) diff --git a/helpers.go b/helpers.go index 7d0ec3b3101..997cc633f0d 100644 --- a/helpers.go +++ b/helpers.go @@ -33,7 +33,7 @@ import ( // along with quality, specificity, parameters, and order. // Used for sorting accept headers. type acceptedType struct { - params headerParams + params *headerParams spec string quality float64 specificity int @@ -42,7 +42,132 @@ type acceptedType struct { const noCacheValue = "no-cache" -type headerParams map[string][]byte +type headerParams struct { + values map[string][]byte + pooled []*[]byte +} + +const ( + routeSetDefaultCap = 16 + routeSetMaxEntries = 256 +) + +var routeSetPool = sync.Pool{ + New: func() any { + return make(map[*Route]struct{}, routeSetDefaultCap) + }, +} + +func acquireRouteSet() map[*Route]struct{} { + m, ok := routeSetPool.Get().(map[*Route]struct{}) + if !ok { + panic(errors.New("failed to type-assert to map[*Route]struct{}")) + } + if m == nil { + return make(map[*Route]struct{}, routeSetDefaultCap) + } + if len(m) > 0 { + clear(m) + } + return m +} + +func releaseRouteSet(m map[*Route]struct{}) { + if m == nil { + return + } + used := len(m) + if used > 0 { + clear(m) + } + if used > routeSetMaxEntries { + return + } + routeSetPool.Put(m) +} + +// acceptedTypeSlicePool reuses the scratch slice used when parsing Accept +// headers so repeated negotiations avoid allocating temporary slices. +const ( + acceptedTypeSliceDefaultCap = 8 + acceptedTypeSliceMaxCap = 64 +) + +var acceptedTypeSlicePool = sync.Pool{ + New: func() any { + slice := make([]acceptedType, 0, acceptedTypeSliceDefaultCap) + return &slice + }, +} + +const ( + headerValueDefaultCap = 32 + headerValueMaxCap = 256 +) + +var headerValuePool = sync.Pool{ + New: func() any { + buf := make([]byte, 0, headerValueDefaultCap) + return &buf + }, +} + +const ( + languageTagSliceDefaultCap = 4 + languageTagSliceMaxCap = 32 +) + +var languageTagSlicePool = sync.Pool{ + New: func() any { + slice := make([]string, 0, languageTagSliceDefaultCap) + return &slice + }, +} + +func splitLanguageTags(s string) (*[]string, []string) { + sliceAny := languageTagSlicePool.Get() + tagsPtr, ok := sliceAny.(*[]string) + if !ok || tagsPtr == nil { + slice := make([]string, 0, languageTagSliceDefaultCap) + tagsPtr = &slice + } + + tags := (*tagsPtr)[:0] + if len(s) == 0 { + tags = append(tags, "") + } else { + start := 0 + for i := 0; i < len(s); i++ { + if s[i] == '-' { + tags = append(tags, s[start:i]) + start = i + 1 + } + } + tags = append(tags, s[start:]) + } + + *tagsPtr = tags + return tagsPtr, tags +} + +func releaseLanguageTagSlice(tagsPtr *([]string)) { + if tagsPtr == nil { + return + } + + tags := *tagsPtr + for i := range tags { + tags[i] = "" + } + + if cap(tags) > languageTagSliceMaxCap { + *tagsPtr = make([]string, 0, languageTagSliceDefaultCap) + } else { + *tagsPtr = tags[:0] + } + + languageTagSlicePool.Put(tagsPtr) +} // getTLSConfig returns a net listener's tls config func getTLSConfig(ln net.Listener) *tls.Config { @@ -151,13 +276,16 @@ func (*App) isASCII(s string) bool { // uniqueRouteStack drop all not unique routes from the slice func uniqueRouteStack(stack []*Route) []*Route { - var unique []*Route - m := make(map[*Route]struct{}) + routeSet := acquireRouteSet() + defer releaseRouteSet(routeSet) + + unique := make([]*Route, 0, len(stack)) for _, v := range stack { - if _, ok := m[v]; !ok { - m[v] = struct{}{} - unique = append(unique, v) + if _, ok := routeSet[v]; ok { + continue } + routeSet[v] = struct{}{} + unique = append(unique, v) } return unique @@ -186,7 +314,7 @@ func getGroupPath(prefix, path string) string { // acceptsOffer determines if an offer matches a given specification. // It supports a trailing '*' wildcard and performs case-insensitive exact matching. // Returns true if the offer matches the specification, false otherwise. -func acceptsOffer(spec, offer string, _ headerParams) bool { +func acceptsOffer(spec, offer string, _ *headerParams) bool { if len(spec) >= 1 && spec[len(spec)-1] == '*' { return true } @@ -200,7 +328,7 @@ func acceptsOffer(spec, offer string, _ headerParams) bool { // followed by a hyphen. The comparison is case-insensitive. Only a single "*" // as the entire range is allowed. Any "*" appearing after a hyphen renders the // range invalid and will not match. -func acceptsLanguageOfferBasic(spec, offer string, _ headerParams) bool { +func acceptsLanguageOfferBasic(spec, offer string, _ *headerParams) bool { if spec == "*" { return true } @@ -221,7 +349,7 @@ func acceptsLanguageOfferBasic(spec, offer string, _ headerParams) bool { // - '*' matches zero or more subtags (can “slide”) // - Unspecified subtags are treated like '*' (so trailing/extraneous tag subtags are fine) // - Matching fails if sliding encounters a singleton (incl. 'x') -func acceptsLanguageOfferExtended(spec, offer string, _ headerParams) bool { +func acceptsLanguageOfferExtended(spec, offer string, _ *headerParams) bool { if spec == "*" { return true } @@ -229,30 +357,37 @@ func acceptsLanguageOfferExtended(spec, offer string, _ headerParams) bool { return false } - rs := strings.Split(spec, "-") - ts := strings.Split(offer, "-") + specPtr, specTags := splitLanguageTags(spec) + defer releaseLanguageTagSlice(specPtr) + + offerPtr, offerTags := splitLanguageTags(offer) + defer releaseLanguageTagSlice(offerPtr) + + if len(specTags) == 0 || len(offerTags) == 0 { + return false + } // Step 2: first subtag must match (or be '*') - if !(rs[0] == "*" || utils.EqualFold(rs[0], ts[0])) { + if !(specTags[0] == "*" || utils.EqualFold(specTags[0], offerTags[0])) { return false } i, j := 1, 1 // i = range index, j = tag index - for i < len(rs) { - if rs[i] == "*" { // 3.A: '*' matches zero or more subtags + for i < len(specTags) { + if specTags[i] == "*" { // 3.A: '*' matches zero or more subtags i++ continue } - if j >= len(ts) { // 3.B: ran out of tag subtags + if j >= len(offerTags) { // 3.B: ran out of tag subtags return false } - if utils.EqualFold(rs[i], ts[j]) { // 3.C: exact subtag match + if utils.EqualFold(specTags[i], offerTags[j]) { // 3.C: exact subtag match i++ j++ continue } // 3.D: singleton barrier (one letter or digit, incl. 'x') - if len(ts[j]) == 1 { + if len(offerTags[j]) == 1 { return false } // 3.E: slide forward in the tag and try again @@ -268,7 +403,7 @@ func acceptsLanguageOfferExtended(spec, offer string, _ headerParams) bool { // It checks if the offer MIME type matches the specification MIME type or if the specification is of the form /* and the offer MIME type has the same MIME type. // It checks if the offer contains every parameter present in the specification. // Returns true if the offer type matches the specification, false otherwise. -func acceptsOfferType(spec, offerType string, specParams headerParams) bool { +func acceptsOfferType(spec, offerType string, specParams *headerParams) bool { var offerMime, offerParams string if i := strings.IndexByte(offerType, ';'); i == -1 { @@ -317,23 +452,24 @@ func acceptsOfferType(spec, offerType string, specParams headerParams) bool { // For the sake of simplicity, we forgo this and compare the value as-is. Besides, it would // be highly unusual for a client to escape something other than a double quote or backslash. // See https://www.rfc-editor.org/rfc/rfc9110#name-parameters -func paramsMatch(specParamStr headerParams, offerParams string) bool { - if len(specParamStr) == 0 { +func paramsMatch(specParamStr *headerParams, offerParams string) bool { + if specParamStr == nil || len(specParamStr.values) == 0 { return true } allSpecParamsMatch := true - for specParam, specVal := range specParamStr { + for specParam, specVal := range specParamStr.values { foundParam := false fasthttp.VisitHeaderParams(utils.UnsafeBytes(offerParams), func(key, value []byte) bool { if utils.EqualFold(specParam, utils.UnsafeString(key)) { foundParam = true - unescaped, err := unescapeHeaderValue(value) + unescaped, bufPtr, err := unescapeHeaderValue(value) if err != nil { allSpecParamsMatch = false return false } allSpecParamsMatch = utils.EqualFold(specVal, unescaped) + releaseHeaderValueBuffer(bufPtr) return false } return true @@ -353,7 +489,7 @@ func paramsMatch(specParamStr headerParams, offerParams string) bool { // If the given slice hasn't enough space, it will allocate more and return. func getSplicedStrList(headerValue string, dst []string) []string { if headerValue == "" { - return nil + return dst[:0] } dst = dst[:0] @@ -376,43 +512,91 @@ func getSplicedStrList(headerValue string, dst []string) []string { return dst } -func joinHeaderValues(headers [][]byte) []byte { +func joinHeaderValues(headers [][]byte) ([]byte, *[]byte) { switch len(headers) { case 0: - return nil + return nil, nil case 1: - return headers[0] - default: - return bytes.Join(headers, []byte{','}) + return headers[0], nil } + + bufAny := headerValuePool.Get() + bufPtr, ok := bufAny.(*[]byte) + if !ok { + panic(errors.New("failed to type-assert to *[]byte")) + } + buf := (*bufPtr)[:0] + + for i, header := range headers { + if i > 0 { + buf = append(buf, ',') + } + buf = append(buf, header...) + } + + *bufPtr = buf + return buf, bufPtr } -func unescapeHeaderValue(v []byte) ([]byte, error) { +func unescapeHeaderValue(v []byte) ([]byte, *[]byte, error) { if bytes.IndexByte(v, '\\') == -1 { - return v, nil + return v, nil, nil + } + + bufAny := headerValuePool.Get() + bufPtr, ok := bufAny.(*[]byte) + if !ok { + panic(errors.New("failed to type-assert to *[]byte")) } - res := make([]byte, 0, len(v)) + buf := (*bufPtr)[:0] + if cap(buf) < len(v) { + *bufPtr = buf[:0] + headerValuePool.Put(bufPtr) + bufPtr = nil + buf = make([]byte, 0, len(v)) + } + escaping := false for i, c := range v { if escaping { - res = append(res, c) + buf = append(buf, c) escaping = false continue } if c == '\\' { - // invalid escape at end of string if i == len(v)-1 { - return nil, errors.New("invalid escape sequence") + if bufPtr != nil { + *bufPtr = buf + releaseHeaderValueBuffer(bufPtr) + } + return nil, nil, errors.New("invalid escape sequence") } escaping = true continue } - res = append(res, c) + buf = append(buf, c) } if escaping { - return nil, errors.New("invalid escape sequence") + if bufPtr != nil { + *bufPtr = buf + releaseHeaderValueBuffer(bufPtr) + } + return nil, nil, errors.New("invalid escape sequence") + } + + if bufPtr == nil { + return buf, nil, nil + } + + if cap(buf) > headerValueMaxCap { + copied := append([]byte(nil), buf...) + *bufPtr = buf + releaseHeaderValueBuffer(bufPtr) + return copied, nil, nil } - return res, nil + + *bufPtr = buf + return buf, bufPtr, nil } // forEachMediaRange parses an Accept or Content-Type header, calling functor @@ -467,14 +651,93 @@ func forEachMediaRange(header []byte, functor func([]byte)) { // be cleared before being returned to the pool. var headerParamPool = sync.Pool{ New: func() any { - return make(headerParams) + return &headerParams{ + values: make(map[string][]byte, headerParamsValuesDefaultCap), + pooled: make([]*[]byte, 0, headerParamsPooledDefaultCap), + } }, } +const ( + headerParamsPooledDefaultCap = 4 + headerParamsPooledMaxCap = 32 +) + +const ( + headerParamsValuesDefaultCap = 4 + headerParamsValuesMaxEntries = 32 +) + +func acquireHeaderParams() *headerParams { + params, ok := headerParamPool.Get().(*headerParams) + if !ok || params == nil { + return &headerParams{ + values: make(map[string][]byte, headerParamsValuesDefaultCap), + pooled: make([]*[]byte, 0, headerParamsPooledDefaultCap), + } + } + if params.values == nil { + params.values = make(map[string][]byte, headerParamsValuesDefaultCap) + } + return params +} + +func releaseHeaderParams(params *headerParams) { + if params == nil { + return + } + pooled := params.pooled + for _, bufPtr := range pooled { + releaseHeaderValueBuffer(bufPtr) + } + if cap(pooled) > headerParamsPooledMaxCap { + params.pooled = make([]*[]byte, 0, headerParamsPooledDefaultCap) + } else { + params.pooled = pooled[:0] + } + if len(params.values) > headerParamsValuesMaxEntries { + params.values = make(map[string][]byte, headerParamsValuesDefaultCap) + } else { + for k := range params.values { + delete(params.values, k) + } + } + headerParamPool.Put(params) +} + +func (params *headerParams) set(key string, value []byte, pool *[]byte) { + if params == nil { + return + } + if params.values == nil { + params.values = make(map[string][]byte) + } + params.values[key] = value + if pool != nil { + params.pooled = append(params.pooled, pool) + } +} + +func releaseHeaderValueBuffer(bufPtr *[]byte) { + if bufPtr == nil { + return + } + buf := *bufPtr + for i := range buf { + buf[i] = 0 + } + if cap(buf) > headerValueMaxCap { + *bufPtr = make([]byte, 0, headerValueDefaultCap) + } else { + *bufPtr = buf[:0] + } + headerValuePool.Put(bufPtr) +} + // getOffer return valid offer for header negotiation. // Do not pass header using utils.UnsafeBytes - this can cause a panic due // to the use of utils.ToLowerBytes. -func getOffer(header []byte, isAccepted func(spec, offer string, specParams headerParams) bool, offers ...string) string { +func getOffer(header []byte, isAccepted func(spec, offer string, specParams *headerParams) bool, offers ...string) string { if len(offers) == 0 { return "" } @@ -482,7 +745,23 @@ func getOffer(header []byte, isAccepted func(spec, offer string, specParams head return offers[0] } - acceptedTypes := make([]acceptedType, 0, 8) + acceptedTypesPtr, ok := acceptedTypeSlicePool.Get().(*[]acceptedType) + if !ok { + panic(errors.New("failed to type-assert to *[]acceptedType")) + } + acceptedTypes := (*acceptedTypesPtr)[:0] + defer func() { + for i := range acceptedTypes { + acceptedTypes[i] = acceptedType{} + } + if cap(acceptedTypes) > acceptedTypeSliceMaxCap { + *acceptedTypesPtr = make([]acceptedType, 0, acceptedTypeSliceDefaultCap) + } else { + *acceptedTypesPtr = acceptedTypes[:0] + } + acceptedTypeSlicePool.Put(acceptedTypesPtr) + }() + order := 0 // Parse header and get accepted types with their quality and specificity @@ -490,7 +769,7 @@ func getOffer(header []byte, isAccepted func(spec, offer string, specParams head forEachMediaRange(header, func(accept []byte) { order++ spec, quality := accept, 1.0 - var params headerParams + var params *headerParams if i := bytes.IndexByte(accept, ';'); i != -1 { spec = accept[:i] @@ -502,10 +781,7 @@ func getOffer(header []byte, isAccepted func(spec, offer string, specParams head quality = q } } else { - params, _ = headerParamPool.Get().(headerParams) //nolint:errcheck // only contains headerParams - for k := range params { - delete(params, k) - } + params = acquireHeaderParams() fasthttp.VisitHeaderParams(accept[i:], func(key, value []byte) bool { if len(key) == 1 && key[0] == 'q' { if q, err := fasthttp.ParseUfloat(value); err == nil { @@ -514,11 +790,11 @@ func getOffer(header []byte, isAccepted func(spec, offer string, specParams head return false } lowerKey := utils.UnsafeString(utils.ToLowerBytes(key)) - val, err := unescapeHeaderValue(value) + val, bufPtr, err := unescapeHeaderValue(value) if err != nil { return true } - params[lowerKey] = val + params.set(lowerKey, val, bufPtr) return true }) } @@ -526,6 +802,7 @@ func getOffer(header []byte, isAccepted func(spec, offer string, specParams head // Skip this accept type if quality is 0.0 // See: https://www.rfc-editor.org/rfc/rfc9110#quality.values if quality == 0.0 { + releaseHeaderParams(params) return } } @@ -557,6 +834,7 @@ func getOffer(header []byte, isAccepted func(spec, offer string, specParams head order: order, params: params, }) + params = nil }) if len(acceptedTypes) > 1 { @@ -571,15 +849,11 @@ func getOffer(header []byte, isAccepted func(spec, offer string, specParams head continue } if isAccepted(acceptedType.spec, offer, acceptedType.params) { - if acceptedType.params != nil { - headerParamPool.Put(acceptedType.params) - } + releaseHeaderParams(acceptedType.params) return offer } } - if acceptedType.params != nil { - headerParamPool.Put(acceptedType.params) - } + releaseHeaderParams(acceptedType.params) } return "" @@ -590,17 +864,23 @@ func getOffer(header []byte, isAccepted func(spec, offer string, specParams head // e.g., text/html;a=1;b=2 comes before text/html;a=1 // See: https://www.rfc-editor.org/rfc/rfc9110#name-content-negotiation-fields func sortAcceptedTypes(at []acceptedType) { + paramsLen := func(hp *headerParams) int { + if hp == nil { + return 0 + } + return len(hp.values) + } for i := 1; i < len(at); i++ { lo, hi := 0, i-1 for lo <= hi { mid := (lo + hi) / 2 - if at[i].quality < at[mid].quality || - (at[i].quality == at[mid].quality && at[i].specificity < at[mid].specificity) || - (at[i].quality == at[mid].quality && at[i].specificity == at[mid].specificity && len(at[i].params) < len(at[mid].params)) || - (at[i].quality == at[mid].quality && at[i].specificity == at[mid].specificity && len(at[i].params) == len(at[mid].params) && at[i].order > at[mid].order) { - lo = mid + 1 - } else { + if at[i].quality > at[mid].quality || + (at[i].quality == at[mid].quality && at[i].specificity > at[mid].specificity) || + (at[i].quality == at[mid].quality && at[i].specificity == at[mid].specificity && paramsLen(at[i].params) > paramsLen(at[mid].params)) || + (at[i].quality == at[mid].quality && at[i].specificity == at[mid].specificity && paramsLen(at[i].params) == paramsLen(at[mid].params) && at[i].order < at[mid].order) { hi = mid - 1 + } else { + lo = mid + 1 } } for j := i; j > lo; j-- { diff --git a/helpers_test.go b/helpers_test.go index fb4c984dd1a..4ff9f046d26 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -6,7 +6,9 @@ package fiber import ( "math" + "reflect" "strconv" + "strings" "testing" "time" "unsafe" @@ -16,6 +18,17 @@ import ( "github.com/valyala/fasthttp" ) +func newHeaderParams(entries map[string]string) *headerParams { + if entries == nil { + return nil + } + params := &headerParams{values: make(map[string][]byte, len(entries))} + for k, v := range entries { + params.values[k] = []byte(v) + } + return params +} + func Test_Utils_GetOffer(t *testing.T) { t.Parallel() require.Equal(t, "", getOffer([]byte("hello"), acceptsOffer)) @@ -97,6 +110,61 @@ func Test_Utils_GetOffer(t *testing.T) { require.True(t, acceptsLanguageOfferExtended("*-CH", "de-Latn-CH", nil)) } +func Test_splitLanguageTags(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + input string + want []string + }{ + {name: "simple", input: "en-US", want: []string{"en", "US"}}, + {name: "leading hyphen", input: "-en", want: []string{"", "en"}}, + {name: "trailing hyphen", input: "en-", want: []string{"en", ""}}, + {name: "double hyphen", input: "en--US", want: []string{"en", "", "US"}}, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + ptr, tags := splitLanguageTags(tc.input) + require.Equal(t, tc.want, tags) + releaseLanguageTagSlice(ptr) + }) + } + + // Oversized slices should be trimmed back down on release. + large := strings.Repeat("x-", languageTagSliceMaxCap+1) + "z" + ptr, tags := splitLanguageTags(large) + require.Greater(t, cap(tags), languageTagSliceMaxCap) + releaseLanguageTagSlice(ptr) + + ptr, tags = splitLanguageTags("en") + require.LessOrEqual(t, cap(tags), languageTagSliceMaxCap) + releaseLanguageTagSlice(ptr) +} + +func Test_releaseHeaderParams(t *testing.T) { + params := acquireHeaderParams() + params.values["foo"] = []byte("bar") + releaseHeaderParams(params) + + params = acquireHeaderParams() + require.Len(t, params.values, 0) + + oldMapPtr := reflect.ValueOf(params.values).Pointer() + for i := 0; i < headerParamsValuesMaxEntries+5; i++ { + params.values[strconv.Itoa(i)] = []byte("v") + } + releaseHeaderParams(params) + + params = acquireHeaderParams() + require.Len(t, params.values, 0) + require.NotEqual(t, oldMapPtr, reflect.ValueOf(params.values).Pointer()) + releaseHeaderParams(params) +} + // go test -v -run=^$ -bench=Benchmark_Utils_GetOffer -benchmem -count=4 func Benchmark_Utils_GetOffer(b *testing.B) { testCases := []struct { @@ -187,7 +255,7 @@ func Benchmark_Utils_GetOffer(b *testing.B) { func Test_Utils_ParamsMatch(t *testing.T) { testCases := []struct { description string - accept headerParams + accept *headerParams offer string match bool }{ @@ -199,31 +267,31 @@ func Test_Utils_ParamsMatch(t *testing.T) { }, { description: "accept is empty, offer has params", - accept: make(headerParams), + accept: &headerParams{values: make(map[string][]byte)}, offer: ";foo=bar", match: true, }, { description: "offer is empty, accept has params", - accept: headerParams{"foo": []byte("bar")}, + accept: newHeaderParams(map[string]string{"foo": "bar"}), offer: "", match: false, }, { description: "accept has extra parameters", - accept: headerParams{"foo": []byte("bar"), "a": []byte("1")}, + accept: newHeaderParams(map[string]string{"foo": "bar", "a": "1"}), offer: ";foo=bar", match: false, }, { description: "matches regardless of order", - accept: headerParams{"b": []byte("2"), "a": []byte("1")}, + accept: newHeaderParams(map[string]string{"b": "2", "a": "1"}), offer: ";b=2;a=1", match: true, }, { description: "case insensitive", - accept: headerParams{"ParaM": []byte("FoO")}, + accept: newHeaderParams(map[string]string{"ParaM": "FoO"}), offer: ";pAram=foO", match: true, }, @@ -237,10 +305,7 @@ func Test_Utils_ParamsMatch(t *testing.T) { func Benchmark_Utils_ParamsMatch(b *testing.B) { var match bool - specParams := headerParams{ - "appLe": []byte("orange"), - "param": []byte("foo"), - } + specParams := newHeaderParams(map[string]string{"appLe": "orange", "param": "foo"}) b.ReportAllocs() for b.Loop() { match = paramsMatch(specParams, `;param=foo; apple=orange`) @@ -253,7 +318,7 @@ func Test_Utils_AcceptsOfferType(t *testing.T) { testCases := []struct { description string spec string - specParams headerParams + specParams *headerParams offerType string accepts bool }{ @@ -278,14 +343,14 @@ func Test_Utils_AcceptsOfferType(t *testing.T) { { description: "params match", spec: "application/json", - specParams: headerParams{"format": []byte("foo"), "version": []byte("1")}, + specParams: newHeaderParams(map[string]string{"format": "foo", "version": "1"}), offerType: "application/json;version=1;format=foo;q=0.1", accepts: true, }, { description: "spec has extra params", spec: "text/html", - specParams: headerParams{"charset": []byte("utf-8")}, + specParams: newHeaderParams(map[string]string{"charset": "utf-8"}), offerType: "text/html", accepts: false, }, @@ -298,14 +363,14 @@ func Test_Utils_AcceptsOfferType(t *testing.T) { { description: "ignores optional whitespace", spec: "application/json", - specParams: headerParams{"format": []byte("foo"), "version": []byte("1")}, + specParams: newHeaderParams(map[string]string{"format": "foo", "version": "1"}), offerType: "application/json; version=1 ; format=foo ", accepts: true, }, { description: "ignores optional whitespace", spec: "application/json", - specParams: headerParams{"format": []byte("foo bar"), "version": []byte("1")}, + specParams: newHeaderParams(map[string]string{"format": "foo bar", "version": "1"}), offerType: `application/json;version="1";format="foo bar"`, accepts: true, }, @@ -336,7 +401,7 @@ func Test_Utils_GetSplicedStrList(t *testing.T) { { description: "headerValue is empty", headerValue: "", - expectedList: nil, + expectedList: []string{}, }, { description: "has a comma without element", @@ -400,7 +465,7 @@ func Test_Utils_SortAcceptedTypes(t *testing.T) { {spec: "image/*", quality: 1, specificity: 2, order: 8}, {spec: "image/gif", quality: 1, specificity: 3, order: 9}, {spec: "text/plain", quality: 1, specificity: 3, order: 10}, - {spec: "application/json", quality: 0.999, specificity: 3, params: headerParams{"a": []byte("1")}, order: 11}, + {spec: "application/json", quality: 0.999, specificity: 3, params: newHeaderParams(map[string]string{"a": "1"}), order: 11}, {spec: "application/json", quality: 0.999, specificity: 3, order: 3}, } sortAcceptedTypes(acceptedTypes) @@ -413,7 +478,7 @@ func Test_Utils_SortAcceptedTypes(t *testing.T) { {spec: "image/gif", quality: 1, specificity: 3, order: 9}, {spec: "text/plain", quality: 1, specificity: 3, order: 10}, {spec: "image/*", quality: 1, specificity: 2, order: 8}, - {spec: "application/json", quality: 0.999, specificity: 3, params: headerParams{"a": []byte("1")}, order: 11}, + {spec: "application/json", quality: 0.999, specificity: 3, params: newHeaderParams(map[string]string{"a": "1"}), order: 11}, {spec: "application/json", quality: 0.999, specificity: 3, order: 3}, {spec: "text/*", quality: 0.5, specificity: 2, order: 1}, {spec: "*/*", quality: 0.1, specificity: 1, order: 2}, @@ -1334,26 +1399,36 @@ func Test_UnescapeHeaderValue(t *testing.T) { {in: "bad\\", ok: false}, } for _, tc := range cases { - out, err := unescapeHeaderValue([]byte(tc.in)) + out, bufPtr, err := unescapeHeaderValue([]byte(tc.in)) if tc.ok { require.NoError(t, err, tc.in) require.Equal(t, tc.out, out, tc.in) } else { require.Error(t, err, tc.in) } + releaseHeaderValueBuffer(bufPtr) } } func Test_JoinHeaderValues(t *testing.T) { t.Parallel() - require.Nil(t, joinHeaderValues(nil)) - require.Equal(t, []byte("a"), joinHeaderValues([][]byte{[]byte("a")})) - require.Equal(t, []byte("a,b"), joinHeaderValues([][]byte{[]byte("a"), []byte("b")})) + + buf, ptr := joinHeaderValues(nil) + require.Nil(t, buf) + require.Nil(t, ptr) + + buf, ptr = joinHeaderValues([][]byte{[]byte("a")}) + require.Equal(t, []byte("a"), buf) + require.Nil(t, ptr) + + buf, ptr = joinHeaderValues([][]byte{[]byte("a"), []byte("b")}) + require.Equal(t, []byte("a,b"), buf) + releaseHeaderValueBuffer(ptr) } func Test_ParamsMatch_InvalidEscape(t *testing.T) { t.Parallel() - match := paramsMatch(headerParams{"foo": []byte("bar")}, `;foo="bar\\`) + match := paramsMatch(newHeaderParams(map[string]string{"foo": "bar"}), `;foo="bar\\`) require.False(t, match) } diff --git a/middleware/adaptor/adaptor.go b/middleware/adaptor/adaptor.go index 2b2bc53faa7..068ee595543 100644 --- a/middleware/adaptor/adaptor.go +++ b/middleware/adaptor/adaptor.go @@ -1,6 +1,7 @@ package adaptor import ( + "errors" "io" "net" "net/http" @@ -175,7 +176,11 @@ func handlerFunc(app *fiber.App, h ...fiber.Handler) http.HandlerFunc { } // New fasthttp Ctx from pool - fctx := ctxPool.Get().(*fasthttp.RequestCtx) //nolint:forcetypeassert,errcheck // overlinting + ctxAny := ctxPool.Get() + fctx, ok := ctxAny.(*fasthttp.RequestCtx) + if !ok { + panic(errors.New("failed to type-assert to *fasthttp.RequestCtx")) + } fctx.Response.Reset() fctx.Request.Reset() defer ctxPool.Put(fctx) diff --git a/middleware/cache/cache.go b/middleware/cache/cache.go index d933f687cee..d97f8afc36b 100644 --- a/middleware/cache/cache.go +++ b/middleware/cache/cache.go @@ -290,14 +290,20 @@ func New(config ...Config) fiber.Handler { // Store all response headers // (more: https://datatracker.ietf.org/doc/html/rfc2616#section-13.5.1) if cfg.StoreResponseHeaders { - e.headers = make(map[string][]byte) - for key, value := range c.Response().Header.All() { - // create real copy - keyS := string(key) - if _, ok := ignoreHeaders[keyS]; !ok { - e.headers[keyS] = utils.CopyBytes(value) - } + if e.headers == nil { + e.headers = make(map[string][]byte) + } else if len(e.headers) > 0 { + clear(e.headers) } + c.Response().Header.VisitAll(func(key, value []byte) { + keyStr := string(key) + if _, ok := ignoreHeaders[keyStr]; ok { + return + } + e.headers[keyStr] = utils.CopyBytes(value) + }) + } else if len(e.headers) > 0 { + clear(e.headers) } // default cache expiration diff --git a/middleware/cache/cache_test.go b/middleware/cache/cache_test.go index c344191066a..b22623dfda2 100644 --- a/middleware/cache/cache_test.go +++ b/middleware/cache/cache_test.go @@ -11,6 +11,7 @@ import ( "math" "net/http/httptest" "os" + "reflect" "strconv" "testing" "time" @@ -85,6 +86,35 @@ func (s *failingCacheStorage) Reset() error { func (*failingCacheStorage) Close() error { return nil } +func TestManagerReleaseShrinksHeaderMap(t *testing.T) { + t.Parallel() + + storage := newFailingCacheStorage() + mgr := newManager(storage, false) + + item := mgr.acquire() + require.NotNil(t, item) + + headerCount := cacheHeaderMaxEntries + 16 + item.headers = make(map[string][]byte, headerCount) + for i := 0; i < headerCount; i++ { + key := fmt.Sprintf("Key-%d", i) + item.headers[key] = []byte("value") + } + + originalPtr := reflect.ValueOf(item.headers).Pointer() + + mgr.release(item) + + reacquired := mgr.acquire() + require.Equal(t, 0, len(reacquired.headers)) + + newPtr := reflect.ValueOf(reacquired.headers).Pointer() + require.NotEqual(t, originalPtr, newPtr) + + mgr.release(reacquired) +} + func TestCacheStorageGetError(t *testing.T) { t.Parallel() diff --git a/middleware/cache/manager.go b/middleware/cache/manager.go index bffc49eaf07..9ec7e5c6e82 100644 --- a/middleware/cache/manager.go +++ b/middleware/cache/manager.go @@ -37,6 +37,11 @@ type manager struct { const redactedKey = "[redacted]" +const ( + cacheHeaderDefaultCap = 8 + cacheHeaderMaxEntries = 128 +) + var errCacheMiss = errors.New("cache: miss") func newManager(storage fiber.Storage, redactKeys bool) *manager { @@ -61,7 +66,12 @@ func newManager(storage fiber.Storage, redactKeys bool) *manager { // acquire returns an *entry from the sync.Pool func (m *manager) acquire() *item { - return m.pool.Get().(*item) //nolint:forcetypeassert,errcheck // We store nothing else in the pool + entryAny := m.pool.Get() + entry, ok := entryAny.(*item) + if !ok { + panic(errors.New("failed to type-assert to *item")) + } + return entry } // release and reset *entry to sync.Pool @@ -72,11 +82,19 @@ func (m *manager) release(e *item) { } e.body = nil e.ctype = nil + e.cencoding = nil e.status = 0 e.age = 0 e.exp = 0 e.ttl = 0 - e.headers = nil + e.heapidx = 0 + headersLen := len(e.headers) + if headersLen > 0 { + clear(e.headers) + } + if headersLen > cacheHeaderMaxEntries { + e.headers = make(map[string][]byte, cacheHeaderDefaultCap) + } m.pool.Put(e) } diff --git a/middleware/encryptcookie/utils.go b/middleware/encryptcookie/utils.go index 867d590b031..9fbe419b4b0 100644 --- a/middleware/encryptcookie/utils.go +++ b/middleware/encryptcookie/utils.go @@ -9,10 +9,61 @@ import ( "fmt" "io" "slices" + "sync" ) var ErrInvalidKeyLength = errors.New("encryption key must be 16, 24, or 32 bytes") +const ( + encryptCookieBufferDefaultCap = 32 + encryptCookieBufferMaxCap = 4096 +) + +var encryptCookieBufferPool = sync.Pool{ + New: func() any { + buf := make([]byte, 0, encryptCookieBufferDefaultCap) + return &buf + }, +} + +func acquireEncryptCookieBuffer(requiredCap, nonceSize int) *[]byte { + bufAny := encryptCookieBufferPool.Get() + bufPtr, ok := bufAny.(*[]byte) + if !ok { + panic(errors.New("failed to type-assert to *[]byte")) + } + + buf := *bufPtr + if cap(buf) < requiredCap { + buf = make([]byte, 0, requiredCap) + } + + buf = buf[:nonceSize] + *bufPtr = buf + + return bufPtr +} + +func releaseEncryptCookieBuffer(bufPtr *[]byte) { + if bufPtr == nil { + return + } + + buf := *bufPtr + if cap(buf) > encryptCookieBufferMaxCap { + *bufPtr = make([]byte, 0, encryptCookieBufferDefaultCap) + encryptCookieBufferPool.Put(bufPtr) + return + } + + for i := range buf { + buf[i] = 0 + } + + *bufPtr = buf[:0] + encryptCookieBufferPool.Put(bufPtr) +} + // decodeKey decodes the provided base64-encoded key and validates its length. // It returns the decoded key bytes or an error when invalid. func decodeKey(key string) ([]byte, error) { @@ -52,13 +103,23 @@ func EncryptCookie(value, key string) (string, error) { return "", fmt.Errorf("failed to create GCM mode: %w", err) } - nonce := make([]byte, gcm.NonceSize()) + nonceSize := gcm.NonceSize() + requiredCap := nonceSize + len(value) + gcm.Overhead() + noncePtr := acquireEncryptCookieBuffer(requiredCap, nonceSize) + nonce := *noncePtr + if _, err = io.ReadFull(rand.Reader, nonce); err != nil { + releaseEncryptCookieBuffer(noncePtr) return "", fmt.Errorf("failed to read nonce: %w", err) } ciphertext := gcm.Seal(nonce, nonce, []byte(value), nil) - return base64.StdEncoding.EncodeToString(ciphertext), nil + *noncePtr = ciphertext + + encoded := base64.StdEncoding.EncodeToString(ciphertext) + releaseEncryptCookieBuffer(noncePtr) + + return encoded, nil } // DecryptCookie Decrypts a cookie value with specific encryption key diff --git a/middleware/idempotency/idempotency.go b/middleware/idempotency/idempotency.go index 05adc96d741..1797e9e0621 100644 --- a/middleware/idempotency/idempotency.go +++ b/middleware/idempotency/idempotency.go @@ -59,7 +59,9 @@ func New(config ...Config) fiber.Handler { if val, err := cfg.Storage.GetWithContext(c, key); err != nil { return false, fmt.Errorf("failed to read response: %w", err) } else if val != nil { - var res response + res := acquireCachedResponse() + defer releaseCachedResponse(res) + if _, err := res.UnmarshalMsg(val); err != nil { return false, fmt.Errorf("failed to unmarshal response: %w", err) } @@ -74,8 +76,10 @@ func New(config ...Config) fiber.Handler { if len(res.Body) != 0 { if err := c.Send(res.Body); err != nil { + res.Body = nil return true, err } + res.Body = nil } _ = c.Locals(localsKeyIsFromCache, true) @@ -133,27 +137,22 @@ func New(config ...Config) fiber.Handler { } // Construct response - res := &response{ - StatusCode: c.Response().StatusCode(), + res := acquireCachedResponse() + defer releaseCachedResponse(res) + + res.StatusCode = c.Response().StatusCode() + res.Body = append(res.Body[:0], c.Response().Body()...) - Body: utils.CopyBytes(c.Response().Body()), + headers := res.Headers + resetCachedResponseHeaders(headers) + if err := c.Bind().RespHeader(headers); err != nil { + return fmt.Errorf("failed to bind to response headers: %w", err) } - { - headers := make(map[string][]string) - if err := c.Bind().RespHeader(headers); err != nil { - return fmt.Errorf("failed to bind to response headers: %w", err) - } - if cfg.KeepResponseHeaders == nil { - // Keep all - res.Headers = headers - } else { - // Filter - res.Headers = make(map[string][]string) - for h := range headers { - if _, ok := keepResponseHeadersMap[utils.ToLower(h)]; ok { - res.Headers[h] = headers[h] - } + if len(keepResponseHeadersMap) > 0 { + for h := range headers { + if _, ok := keepResponseHeadersMap[strings.ToLower(h)]; !ok { + delete(headers, h) } } } diff --git a/middleware/idempotency/response.go b/middleware/idempotency/response.go index aafbca60468..4ba1f8accd7 100644 --- a/middleware/idempotency/response.go +++ b/middleware/idempotency/response.go @@ -1,5 +1,7 @@ package idempotency +import "sync" + // response is a struct that represents the response of a request. // generation tool `go install github.com/tinylib/msgp@latest` // @@ -10,3 +12,65 @@ type response struct { Body []byte `msg:"b"` StatusCode int `msg:"sc"` } + +const ( + cachedResponseBodyDefaultCap = 4 << 10 // 4 KiB default body buffer + cachedResponseBodyMaxCap = 256 << 10 // 256 KiB maximum retained buffer + cachedResponseHeaderHint = 8 +) + +var cachedResponsePool = sync.Pool{ + New: func() any { + return &response{ + Headers: make(map[string][]string, cachedResponseHeaderHint), + Body: make([]byte, 0, cachedResponseBodyDefaultCap), + } + }, +} + +func acquireCachedResponse() *response { + res, ok := cachedResponsePool.Get().(*response) + if !ok { + panic("failed to type-assert to *response") + } + if res.Headers == nil { + res.Headers = make(map[string][]string, cachedResponseHeaderHint) + } + if res.Body == nil { + res.Body = make([]byte, 0, cachedResponseBodyDefaultCap) + } + return res +} + +func releaseCachedResponse(res *response) { + if res == nil { + return + } + res.StatusCode = 0 + if res.Body != nil { + res.Body = resetCachedResponseBody(res.Body) + } + if res.Headers != nil { + resetCachedResponseHeaders(res.Headers) + } + cachedResponsePool.Put(res) +} + +func resetCachedResponseHeaders(headers map[string][]string) { + if headers == nil { + return + } + for key, values := range headers { + for i := range values { + values[i] = "" + } + delete(headers, key) + } +} + +func resetCachedResponseBody(body []byte) []byte { + if cap(body) > cachedResponseBodyMaxCap { + return make([]byte, 0, cachedResponseBodyDefaultCap) + } + return body[:0] +} diff --git a/middleware/keyauth/keyauth.go b/middleware/keyauth/keyauth.go index b7d84f14edb..1be5b542ad5 100644 --- a/middleware/keyauth/keyauth.go +++ b/middleware/keyauth/keyauth.go @@ -2,8 +2,9 @@ package keyauth import ( "errors" - "fmt" + "strconv" "strings" + "sync" "github.com/gofiber/fiber/v3" "github.com/gofiber/fiber/v3/extractors" @@ -22,6 +23,49 @@ const ( // ErrMissingOrMalformedAPIKey is returned when the API key is missing or invalid. var ErrMissingOrMalformedAPIKey = errors.New("missing or invalid API Key") +const ( + challengeBufferDefaultCap = 128 + challengeBufferMaxCap = 1024 +) + +var ( + challengeSlicePool = sync.Pool{ + New: func() any { + s := make([]string, 0, 4) + return &s + }, + } + + challengeBufferPool = sync.Pool{ + New: func() any { + buf := make([]byte, 0, challengeBufferDefaultCap) + return &buf + }, + } +) + +func releaseChallengeBuffer(bufPtr *[]byte, used int) { + if bufPtr == nil { + return + } + + buf := *bufPtr + if used > len(buf) { + used = len(buf) + } + for i := 0; i < used; i++ { + buf[i] = 0 + } + + if cap(buf) > challengeBufferMaxCap { + *bufPtr = make([]byte, 0, challengeBufferDefaultCap) + } else { + *bufPtr = buf[:0] + } + + challengeBufferPool.Put(bufPtr) +} + // New creates a new middleware handler func New(config ...Config) fiber.Handler { // Init config @@ -63,26 +107,55 @@ func New(config ...Config) fiber.Handler { header = fiber.HeaderProxyAuthenticate } if len(authSchemes) > 0 { - challenges := make([]string, 0, len(authSchemes)) + challengesAny := challengeSlicePool.Get() + challengesPtr, ok := challengesAny.(*[]string) + if !ok { + panic(errors.New("failed to type-assert to *[]string")) + } + challenges := (*challengesPtr)[:0] + defer func() { + for i := range challenges { + challenges[i] = "" + } + *challengesPtr = challenges[:0] + challengeSlicePool.Put(challengesPtr) + }() + for _, scheme := range authSchemes { - var b strings.Builder - fmt.Fprintf(&b, "%s realm=%q", scheme, cfg.Realm) - if utils.EqualFold(scheme, "Bearer") { - if cfg.Error != "" { - fmt.Fprintf(&b, ", error=%q", cfg.Error) - if cfg.ErrorDescription != "" { - fmt.Fprintf(&b, ", error_description=%q", cfg.ErrorDescription) - } - if cfg.ErrorURI != "" { - fmt.Fprintf(&b, ", error_uri=%q", cfg.ErrorURI) - } - if cfg.Error == ErrorInsufficientScope { - fmt.Fprintf(&b, ", scope=%q", cfg.Scope) - } + bufPtr, ok := challengeBufferPool.Get().(*[]byte) + if !ok { + panic(errors.New("failed to type-assert to *[]byte")) + } + buf := (*bufPtr)[:0] + + buf = append(buf, scheme...) + buf = append(buf, ' ') + buf = append(buf, "realm="...) + buf = strconv.AppendQuote(buf, cfg.Realm) + + if utils.EqualFold(scheme, "Bearer") && cfg.Error != "" { + buf = append(buf, ", error="...) + buf = strconv.AppendQuote(buf, cfg.Error) + + if cfg.ErrorDescription != "" { + buf = append(buf, ", error_description="...) + buf = strconv.AppendQuote(buf, cfg.ErrorDescription) + } + if cfg.ErrorURI != "" { + buf = append(buf, ", error_uri="...) + buf = strconv.AppendQuote(buf, cfg.ErrorURI) + } + if cfg.Error == ErrorInsufficientScope { + buf = append(buf, ", scope="...) + buf = strconv.AppendQuote(buf, cfg.Scope) } } - challenges = append(challenges, b.String()) + + challenge := string(buf) + releaseChallengeBuffer(bufPtr, len(buf)) + challenges = append(challenges, challenge) } + c.Set(header, strings.Join(challenges, ", ")) } else if cfg.Challenge != "" { c.Set(header, cfg.Challenge) diff --git a/middleware/limiter/manager.go b/middleware/limiter/manager.go index 13e1eb9df4d..b7545dab31b 100644 --- a/middleware/limiter/manager.go +++ b/middleware/limiter/manager.go @@ -2,6 +2,7 @@ package limiter import ( "context" + "errors" "fmt" "sync" "time" @@ -51,7 +52,12 @@ func newManager(storage fiber.Storage, redactKeys bool) *manager { // acquire returns an *entry from the sync.Pool func (m *manager) acquire() *item { - return m.pool.Get().(*item) //nolint:forcetypeassert,errcheck // We store nothing else in the pool + entryAny := m.pool.Get() + entry, ok := entryAny.(*item) + if !ok { + panic(errors.New("failed to type-assert to *item")) + } + return entry } // release and reset *entry to sync.Pool diff --git a/middleware/logger/logger.go b/middleware/logger/logger.go index 7d4befc9213..1c7e9d92492 100644 --- a/middleware/logger/logger.go +++ b/middleware/logger/logger.go @@ -1,6 +1,7 @@ package logger import ( + "errors" "os" "strconv" "strings" @@ -51,6 +52,29 @@ func New(config ...Config) fiber.Handler { dataPool = sync.Pool{New: func() any { return new(Data) }} ) + acquireData := func() *Data { + dataAny := dataPool.Get() + data, ok := dataAny.(*Data) + if !ok { + panic(errors.New("failed to type-assert to *Data")) + } + return data + } + + releaseData := func(data *Data) { + if data == nil { + return + } + + data.Start = time.Time{} + data.Stop = time.Time{} + data.ChainErr = nil + data.TemplateChain = nil + data.LogFuncChain = nil + + dataPool.Put(data) + } + // Err padding errPadding := 15 errPaddingStr := strconv.Itoa(errPadding) @@ -90,15 +114,15 @@ func New(config ...Config) fiber.Handler { }) // Logger data - data := dataPool.Get().(*Data) //nolint:forcetypeassert,errcheck // We store nothing else in the pool - // no need for a reset, as long as we always override everything + data := acquireData() + // Fields are overwritten on each request and releaseData clears residual state after logging. data.Pid = pid data.ErrPaddingStr = errPaddingStr data.Timestamp = timestamp data.TemplateChain = templateChain data.LogFuncChain = logFunChain // put data back in the pool - defer dataPool.Put(data) + defer releaseData(data) // Set latency start time if cfg.enableLatency { diff --git a/middleware/logger/tags.go b/middleware/logger/tags.go index 67a10a51225..7d48983d836 100644 --- a/middleware/logger/tags.go +++ b/middleware/logger/tags.go @@ -1,13 +1,70 @@ package logger import ( + "errors" "fmt" "maps" - "strings" + "sync" "github.com/gofiber/fiber/v3" ) +const ( + reqHeaderMapDefaultCap = 16 + reqHeaderMapMaxEntries = 256 +) + +var reqHeaderMapPool = sync.Pool{ + New: func() any { + m := make(map[string][]string, reqHeaderMapDefaultCap) + return &m + }, +} + +func acquireReqHeaderMap() *map[string][]string { + mapAny := reqHeaderMapPool.Get() + headerMapPtr, ok := mapAny.(*map[string][]string) + if !ok { + panic(errors.New("failed to type-assert to *map[string][]string")) + } + + headers := *headerMapPtr + if headers == nil { + headers = make(map[string][]string, reqHeaderMapDefaultCap) + } else if len(headers) > 0 { + clear(headers) + } + + *headerMapPtr = headers + + return headerMapPtr +} + +func releaseReqHeaderMap(headerMapPtr *map[string][]string) { + if headerMapPtr == nil { + return + } + + headers := *headerMapPtr + if headers == nil { + *headerMapPtr = make(map[string][]string, reqHeaderMapDefaultCap) + reqHeaderMapPool.Put(headerMapPtr) + return + } + + entryCount := len(headers) + if entryCount > 0 { + clear(headers) + } + + if entryCount > reqHeaderMapMaxEntries { + headers = nil + } + + *headerMapPtr = headers + reqHeaderMapPool.Put(headerMapPtr) +} + // Logger variables const ( TagPid = "pid" @@ -100,16 +157,59 @@ func createTagMap(cfg *Config) map[string]LogFunc { return output.Write(c.Response().Body()) }, TagReqHeaders: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) { - out := make(map[string][]string, 0) - if err := c.Bind().Header(&out); err != nil { + headerMapPtr := acquireReqHeaderMap() + headers := *headerMapPtr + defer func() { + *headerMapPtr = headers + releaseReqHeaderMap(headerMapPtr) + }() + + if err := c.Bind().Header(&headers); err != nil { return 0, err } - reqHeaders := make([]string, 0) - for k, v := range out { - reqHeaders = append(reqHeaders, k+"="+strings.Join(v, ",")) + var ( + written int + firstPair = true + ) + + for key, values := range headers { + if !firstPair { + if err := output.WriteByte('&'); err != nil { + return written, err + } + written++ + } + firstPair = false + + n, err := output.WriteString(key) + written += n + if err != nil { + return written, err + } + + if err := output.WriteByte('='); err != nil { + return written, err + } + written++ + + for i, value := range values { + if i > 0 { + if err := output.WriteByte(','); err != nil { + return written, err + } + written++ + } + + n, err = output.WriteString(value) + written += n + if err != nil { + return written, err + } + } } - return output.Write([]byte(strings.Join(reqHeaders, "&"))) + + return written, nil }, TagQueryStringParams: func(output Buffer, c fiber.Ctx, _ *Data, _ string) (int, error) { return output.WriteString(c.Request().URI().QueryArgs().String()) diff --git a/middleware/logger/tags_test.go b/middleware/logger/tags_test.go new file mode 100644 index 00000000000..7334a755865 --- /dev/null +++ b/middleware/logger/tags_test.go @@ -0,0 +1,52 @@ +package logger + +import ( + "fmt" + "strconv" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAcquireReqHeaderMapResets(t *testing.T) { + t.Parallel() + + ptr := acquireReqHeaderMap() + headers := *ptr + + headers["X-Test"] = []string{"value"} + + *ptr = headers + releaseReqHeaderMap(ptr) + + ptr = acquireReqHeaderMap() + headers = *ptr + require.Empty(t, headers) + + *ptr = headers + releaseReqHeaderMap(ptr) +} + +func TestReleaseReqHeaderMapDropsOversized(t *testing.T) { + t.Parallel() + + ptr := acquireReqHeaderMap() + headers := *ptr + firstAddr := fmt.Sprintf("%p", headers) + + for i := 0; i <= reqHeaderMapMaxEntries; i++ { + key := "Header-" + strconv.Itoa(i) + headers[key] = []string{"value"} + } + + *ptr = headers + releaseReqHeaderMap(ptr) + + ptr = acquireReqHeaderMap() + headers = *ptr + secondAddr := fmt.Sprintf("%p", headers) + require.NotEqual(t, firstAddr, secondAddr) + + *ptr = headers + releaseReqHeaderMap(ptr) +} diff --git a/middleware/rewrite/rewrite.go b/middleware/rewrite/rewrite.go index 5a7ad882886..368cde14ae4 100644 --- a/middleware/rewrite/rewrite.go +++ b/middleware/rewrite/rewrite.go @@ -1,13 +1,68 @@ package rewrite import ( + "errors" "regexp" "strconv" "strings" + "sync" "github.com/gofiber/fiber/v3" ) +const ( + rewriteReplaceDefaultCap = 8 + rewriteReplaceMaxCap = 128 +) + +var replacerArgPool = sync.Pool{ //nolint:gochecknoglobals // shared argument pool + New: func() any { + slice := make([]string, 0, rewriteReplaceDefaultCap) + return &slice + }, +} + +func acquireReplacerArgs(pairCount int) *[]string { + argsAny := replacerArgPool.Get() + argsPtr, ok := argsAny.(*[]string) + if !ok { + panic(errors.New("failed to type-assert to *[]string")) + } + + needed := pairCount * 2 + args := *argsPtr + + if cap(args) < needed { + args = make([]string, needed) + } else { + args = args[:needed] + } + + *argsPtr = args + + return argsPtr +} + +func releaseReplacerArgs(argsPtr *[]string) { + if argsPtr == nil { + return + } + + args := *argsPtr + if len(args) > 0 { + clear(args) + } + + if cap(args) > rewriteReplaceMaxCap { + args = make([]string, 0, rewriteReplaceDefaultCap) + } else { + args = args[:0] + } + + *argsPtr = args + replacerArgPool.Put(argsPtr) +} + // New creates a new middleware handler func New(config ...Config) fiber.Handler { cfg := configDefault(config...) @@ -44,11 +99,18 @@ func captureTokens(pattern *regexp.Regexp, input string) *strings.Replacer { return nil } values := groups[0][1:] - replace := make([]string, 2*len(values)) + if len(values) == 0 { + return strings.NewReplacer() + } + + replacePtr := acquireReplacerArgs(len(values)) + replace := *replacePtr for i, v := range values { j := 2 * i replace[j] = "$" + strconv.Itoa(i+1) replace[j+1] = v } - return strings.NewReplacer(replace...) + replacer := strings.NewReplacer(replace...) + releaseReplacerArgs(replacePtr) + return replacer } diff --git a/middleware/rewrite/rewrite_test.go b/middleware/rewrite/rewrite_test.go index ce35cd6ba4d..b8516617485 100644 --- a/middleware/rewrite/rewrite_test.go +++ b/middleware/rewrite/rewrite_test.go @@ -46,6 +46,23 @@ func Test_New(t *testing.T) { } } +func TestAcquireReleaseReplacerArgs(t *testing.T) { + ptr := acquireReplacerArgs(3) + + args := *ptr + require.Len(t, args, 6) + + releaseReplacerArgs(ptr) + + ptr = acquireReplacerArgs(1) + args = *ptr + require.Len(t, args, 2) + for _, value := range args { + require.Empty(t, value) + } + releaseReplacerArgs(ptr) +} + func Test_Rewrite(t *testing.T) { // Case 1: Next function always returns true app := fiber.New() diff --git a/middleware/session/data.go b/middleware/session/data.go index c9eb82fa31f..1417e7240de 100644 --- a/middleware/session/data.go +++ b/middleware/session/data.go @@ -4,6 +4,11 @@ import ( "sync" ) +const ( + sessionDataDefaultCap = 8 + sessionDataMaxEntries = 256 +) + // msgp -file="data.go" -o="data_msgp.go" -tests=true -unexported // //go:generate msgp -o=data_msgp.go -tests=true -unexported @@ -16,7 +21,7 @@ type data struct { var dataPool = sync.Pool{ New: func() any { d := new(data) - d.Data = make(map[any]any) + d.Data = make(map[any]any, sessionDataDefaultCap) return d }, } @@ -46,7 +51,17 @@ func acquireData() *data { func (d *data) Reset() { d.Lock() defer d.Unlock() - d.Data = make(map[any]any) + if d.Data == nil { + d.Data = make(map[any]any, sessionDataDefaultCap) + return + } + used := len(d.Data) + if used > 0 { + clear(d.Data) + } + if used > sessionDataMaxEntries { + d.Data = make(map[any]any, sessionDataDefaultCap) + } } // Get retrieves a value from the data map by key. diff --git a/middleware/session/session.go b/middleware/session/session.go index a3beed2c97e..6a986edb928 100644 --- a/middleware/session/session.go +++ b/middleware/session/session.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/gob" + "errors" "fmt" "sync" "time" @@ -33,12 +34,40 @@ const ( ) // Session pool for reusing byte buffers. +const ( + sessionByteBufferDefaultCap = 256 + sessionByteBufferMaxCap = 4096 +) + var byteBufferPool = sync.Pool{ New: func() any { - return new(bytes.Buffer) + buf := bytes.NewBuffer(make([]byte, 0, sessionByteBufferDefaultCap)) + buf.Reset() + return buf }, } +func acquireByteBuffer() *bytes.Buffer { + buf, ok := byteBufferPool.Get().(*bytes.Buffer) + if !ok { + panic(errors.New("failed to type-assert to *bytes.Buffer")) + } + buf.Reset() + return buf +} + +func releaseByteBuffer(buf *bytes.Buffer) { + if buf == nil { + return + } + if buf.Cap() > sessionByteBufferMaxCap { + buf = bytes.NewBuffer(make([]byte, 0, sessionByteBufferDefaultCap)) + } else { + buf.Reset() + } + byteBufferPool.Put(buf) +} + var sessionPool = sync.Pool{ New: func() any { return &Session{} @@ -54,7 +83,11 @@ var sessionPool = sync.Pool{ // // s := acquireSession() func acquireSession() *Session { - s := sessionPool.Get().(*Session) //nolint:forcetypeassert,errcheck // We store nothing else in the pool + sessionAny := sessionPool.Get() + s, ok := sessionAny.(*Session) + if !ok { + panic(errors.New("failed to type-assert to *Session")) + } if s.data == nil { s.data = acquireData() } @@ -502,9 +535,8 @@ func (s *Session) setCookieAttributes(fcookie *fasthttp.Cookie) { // // err := s.decodeSessionData(rawData) func (s *Session) decodeSessionData(rawData []byte) error { - byteBuffer := byteBufferPool.Get().(*bytes.Buffer) //nolint:forcetypeassert,errcheck // We store nothing else in the pool - defer byteBufferPool.Put(byteBuffer) - defer byteBuffer.Reset() + byteBuffer := acquireByteBuffer() + defer releaseByteBuffer(byteBuffer) _, _ = byteBuffer.Write(rawData) decCache := gob.NewDecoder(byteBuffer) if err := decCache.Decode(&s.data.Data); err != nil { @@ -525,9 +557,8 @@ func (s *Session) decodeSessionData(rawData []byte) error { // // err := s.encodeSessionData(rawData) func (s *Session) encodeSessionData() ([]byte, error) { - byteBuffer := byteBufferPool.Get().(*bytes.Buffer) //nolint:forcetypeassert,errcheck // We store nothing else in the pool - defer byteBufferPool.Put(byteBuffer) - defer byteBuffer.Reset() + byteBuffer := acquireByteBuffer() + defer releaseByteBuffer(byteBuffer) encCache := gob.NewEncoder(byteBuffer) if err := encCache.Encode(&s.data.Data); err != nil { return nil, fmt.Errorf("failed to encode session data: %w", err) diff --git a/middleware/static/static.go b/middleware/static/static.go index e6a4730173d..f66c4dd03b9 100644 --- a/middleware/static/static.go +++ b/middleware/static/static.go @@ -19,6 +19,34 @@ import ( "github.com/gofiber/fiber/v3" ) +const ( + sanitizePathDefaultCap = 64 + sanitizePathMaxCap = 4096 +) + +var sanitizePathBufPool = sync.Pool{ + New: func() any { + buf := make([]byte, 0, sanitizePathDefaultCap) + return &buf + }, +} + +func releaseSanitizePathBuf(bufPtr *[]byte, buf []byte) { + if bufPtr == nil { + return + } + + clear(buf) + if cap(buf) > sanitizePathMaxCap { + buf = make([]byte, 0, sanitizePathDefaultCap) + } else { + buf = buf[:0] + } + + *bufPtr = buf + sanitizePathBufPool.Put(bufPtr) +} + // sanitizePath validates and cleans the requested path. // It returns an error if the path attempts to traverse directories. func sanitizePath(p []byte, filesystem fs.FS) ([]byte, error) { @@ -27,14 +55,26 @@ func sanitizePath(p []byte, filesystem fs.FS) ([]byte, error) { hasTrailingSlash := len(p) > 0 && p[len(p)-1] == '/' if bytes.IndexByte(p, '\\') >= 0 { - b := make([]byte, len(p)) - copy(b, p) - for i := range b { - if b[i] == '\\' { - b[i] = '/' + bufAny := sanitizePathBufPool.Get() + bufPtr, ok := bufAny.(*[]byte) + if !ok { + panic(errors.New("failed to type-assert to *[]byte")) + } + buf := *bufPtr + if cap(buf) < len(p) { + buf = make([]byte, len(p)) + } else { + buf = buf[:len(p)] + } + + copy(buf, p) + for i := range buf { + if buf[i] == '\\' { + buf[i] = '/' } } - s = utils.UnsafeString(b) + s = utils.UnsafeString(buf) + defer releaseSanitizePathBuf(bufPtr, buf) } else { s = utils.UnsafeString(p) } diff --git a/path.go b/path.go index e3cb16d0741..2088e4bd4c8 100644 --- a/path.go +++ b/path.go @@ -8,6 +8,7 @@ package fiber import ( "bytes" + "errors" "regexp" "strconv" "strings" @@ -27,9 +28,25 @@ type routeParser struct { plusCount int // number of plus parameters, used internally to give the plus parameter its number } +const ( + routeParserSegDefaultCap = 16 + routeParserSegMaxCap = 128 + routeParserParamDefaultCap = 16 + routeParserParamMaxCap = 128 +) + var routerParserPool = &sync.Pool{ New: func() any { - return &routeParser{} + return &routeParser{ + segs: make([]*routeSegment, 0, routeParserSegDefaultCap), + params: make([]string, 0, routeParserParamDefaultCap), + } + }, +} + +var routeSegmentPool = sync.Pool{ + New: func() any { + return &routeSegment{} }, } @@ -52,6 +69,50 @@ type routeSegment struct { HasOptionalSlash bool // segment has the possibility of an optional slash } +func acquireRouteSegment() *routeSegment { + segAny := routeSegmentPool.Get() + seg, ok := segAny.(*routeSegment) + if !ok { + panic(errors.New("failed to type-assert to *routeSegment")) + } + + resetRouteSegment(seg) + + return seg +} + +func releaseRouteSegment(seg *routeSegment) { + if seg == nil { + return + } + + resetRouteSegment(seg) + routeSegmentPool.Put(seg) +} + +func resetRouteSegment(seg *routeSegment) { + if seg == nil { + return + } + + seg.Const = "" + seg.ParamName = "" + seg.ComparePart = "" + if len(seg.Constraints) > 0 { + for i := range seg.Constraints { + seg.Constraints[i] = nil + } + seg.Constraints = nil + } + seg.PartCount = 0 + seg.Length = 0 + seg.IsParam = false + seg.IsGreedy = false + seg.IsOptional = false + seg.IsLast = false + seg.HasOptionalSlash = false +} + // different special routing signs const ( wildcardParam byte = '*' // indicates an optional greedy parameter @@ -179,7 +240,11 @@ func RoutePatternMatch(path, pattern string, cfg ...Config) bool { patternPretty = utils.TrimRight(patternPretty, '/') } - parser, _ := routerParserPool.Get().(*routeParser) //nolint:errcheck // only contains routeParser + parserAny := routerParserPool.Get() + parser, ok := parserAny.(*routeParser) + if !ok { + panic(errors.New("failed to type-assert to *routeParser")) + } parser.reset() parser.parseRoute(string(patternPretty)) defer routerParserPool.Put(parser) @@ -204,8 +269,25 @@ func RoutePatternMatch(path, pattern string, cfg ...Config) bool { } func (parser *routeParser) reset() { - parser.segs = parser.segs[:0] - parser.params = parser.params[:0] + for i, seg := range parser.segs { + releaseRouteSegment(seg) + parser.segs[i] = nil + } + if cap(parser.segs) > routeParserSegMaxCap { + parser.segs = make([]*routeSegment, 0, routeParserSegDefaultCap) + } else { + parser.segs = parser.segs[:0] + } + + if len(parser.params) > 0 { + clear(parser.params) + } + if cap(parser.params) > routeParserParamMaxCap { + parser.params = make([]string, 0, routeParserParamDefaultCap) + } else { + parser.params = parser.params[:0] + } + parser.wildCardCount = 0 parser.plusCount = 0 } @@ -319,10 +401,11 @@ func (*routeParser) analyseConstantPart(pattern string, nextParamPosition int) ( processedPart = pattern[:nextParamPosition] } constPart := RemoveEscapeChar(processedPart) - return len(processedPart), &routeSegment{ - Const: constPart, - Length: len(constPart), - } + segment := acquireRouteSegment() + segment.Const = constPart + segment.Length = len(constPart) + + return len(processedPart), segment } // analyseParameterPart find the parameter end and create the route segment @@ -428,12 +511,11 @@ func (parser *routeParser) analyseParameterPart(pattern string, customConstraint paramName += strconv.Itoa(parser.plusCount) } - segment := &routeSegment{ - ParamName: paramName, - IsParam: true, - IsOptional: isWildCard || pattern[paramEndPosition] == optionalParam, - IsGreedy: isWildCard || isPlusParam, - } + segment := acquireRouteSegment() + segment.ParamName = paramName + segment.IsParam = true + segment.IsOptional = isWildCard || pattern[paramEndPosition] == optionalParam + segment.IsGreedy = isWildCard || isPlusParam if len(constraints) > 0 { segment.Constraints = constraints diff --git a/path_test.go b/path_test.go index c340c7fc33b..3b35e2a291c 100644 --- a/path_test.go +++ b/path_test.go @@ -165,11 +165,91 @@ func Test_RoutePatternMatch(t *testing.T) { require.Equal(t, c.match, match, "route: '%s', url: '%s'", pattern, c.url) } } + for _, testCase := range routeTestCases { testCaseFn(testCase.pattern, testCase.testCases) } } +func TestRouteParserResetBounds(t *testing.T) { + t.Parallel() + + parser := &routeParser{ + segs: make([]*routeSegment, 1, routeParserSegMaxCap+32), + params: make([]string, 1, routeParserParamMaxCap+32), + wildCardCount: 3, + plusCount: 2, + } + + parser.segs[0] = &routeSegment{Const: "value"} + parser.params[0] = "param" + + parser.reset() + + require.Zero(t, len(parser.segs)) + require.Zero(t, len(parser.params)) + require.Equal(t, routeParserSegDefaultCap, cap(parser.segs)) + require.Equal(t, routeParserParamDefaultCap, cap(parser.params)) + require.Zero(t, parser.wildCardCount) + require.Zero(t, parser.plusCount) +} + +func TestRouteSegmentPoolResetsState(t *testing.T) { + t.Parallel() + + seg := acquireRouteSegment() + seg.Const = "value" + seg.ParamName = "param" + seg.ComparePart = "compare" + seg.Constraints = []*Constraint{{Name: "id", Data: []string{"1"}}} + seg.PartCount = 3 + seg.Length = 42 + seg.IsParam = true + seg.IsGreedy = true + seg.IsOptional = true + seg.IsLast = true + seg.HasOptionalSlash = true + + releaseRouteSegment(seg) + + reused := acquireRouteSegment() + require.Empty(t, reused.Const) + require.Empty(t, reused.ParamName) + require.Empty(t, reused.ComparePart) + require.Nil(t, reused.Constraints) + require.Zero(t, reused.PartCount) + require.Zero(t, reused.Length) + require.False(t, reused.IsParam) + require.False(t, reused.IsGreedy) + require.False(t, reused.IsOptional) + require.False(t, reused.IsLast) + require.False(t, reused.HasOptionalSlash) + + releaseRouteSegment(reused) +} + +func TestRouteParserResetReleasesSegments(t *testing.T) { + t.Parallel() + + parser := &routeParser{ + segs: make([]*routeSegment, 0, 1), + params: make([]string, 0, 1), + } + + seg := acquireRouteSegment() + seg.Const = "value" + parser.segs = append(parser.segs, seg) + parser.params = append(parser.params, "param") + + parser.reset() + + require.Empty(t, parser.segs) + require.Empty(t, parser.params) + require.Empty(t, seg.Const) + require.False(t, seg.IsParam) + require.False(t, seg.IsGreedy) +} + func TestHasPartialMatchBoundary(t *testing.T) { t.Parallel() diff --git a/redirect.go b/redirect.go index ac17c48b343..db086b90918 100644 --- a/redirect.go +++ b/redirect.go @@ -15,15 +15,113 @@ import ( ) // Pool for redirection +const ( + redirectMessagesDefaultCap = 4 + redirectMessagesMaxCap = 64 +) + var redirectPool = sync.Pool{ New: func() any { return &Redirect{ status: StatusSeeOther, - messages: make(redirectionMsgs, 0), + messages: make(redirectionMsgs, 0, redirectMessagesDefaultCap), } }, } +const ( + redirectInputDefaultCap = 8 + redirectInputMaxEntries = 64 +) + +const ( + redirectMsgBufferDefaultCap = 64 + redirectMsgBufferMaxCap = 1024 +) + +var redirectMsgBufferPool = sync.Pool{ + New: func() any { + buf := make([]byte, 0, redirectMsgBufferDefaultCap) + return &buf + }, +} + +var redirectInputPool = sync.Pool{ + New: func() any { + return make(map[string]string, redirectInputDefaultCap) + }, +} + +func acquireRedirectInputMap() map[string]string { + m, ok := redirectInputPool.Get().(map[string]string) + if !ok { + panic(errors.New("failed to type-assert to map[string]string")) + } + + if len(m) > 0 { + clear(m) + } + + return m +} + +func releaseRedirectInputMap(m map[string]string) { + if m == nil { + return + } + + used := len(m) + if used > 0 { + clear(m) + } + + if used > redirectInputMaxEntries { + redirectInputPool.Put(make(map[string]string, redirectInputDefaultCap)) + return + } + + redirectInputPool.Put(m) +} + +func acquireRedirectMsgBuffer() *[]byte { + bufAny := redirectMsgBufferPool.Get() + bufPtr, ok := bufAny.(*[]byte) + if !ok || bufPtr == nil { + buf := make([]byte, 0, redirectMsgBufferDefaultCap) + return &buf + } + + buf := *bufPtr + if cap(buf) < redirectMsgBufferDefaultCap { + buf = make([]byte, 0, redirectMsgBufferDefaultCap) + } else { + buf = buf[:0] + } + + *bufPtr = buf + return bufPtr +} + +func releaseRedirectMsgBuffer(bufPtr *[]byte) { + if bufPtr == nil { + return + } + + buf := *bufPtr + for i := range buf { + buf[i] = 0 + } + + if cap(buf) > redirectMsgBufferMaxCap { + buf = make([]byte, 0, redirectMsgBufferDefaultCap) + } else { + buf = buf[:0] + } + + *bufPtr = buf + redirectMsgBufferPool.Put(bufPtr) +} + // Cookie name to send flash messages when to use redirection. const ( FlashCookieName = "fiber_flash" @@ -94,7 +192,15 @@ func ReleaseRedirect(r *Redirect) { func (r *Redirect) release() { r.status = StatusSeeOther - r.messages = r.messages[:0] + msgs := r.messages + for i := range msgs { + msgs[i] = redirectionMsg{} + } + if cap(msgs) > redirectMessagesMaxCap { + r.messages = make(redirectionMsgs, 0, redirectMessagesDefaultCap) + } else { + r.messages = msgs[:0] + } r.c = nil } @@ -145,7 +251,9 @@ func (r *Redirect) WithInput() *Redirect { ctype := utils.ToLower(utils.UnsafeString(r.c.RequestCtx().Request.Header.ContentType())) ctype = binder.FilterFlags(utils.ParseVendorSpecificContentType(ctype)) - oldInput := make(map[string]string) + oldInput := acquireRedirectInputMap() + defer releaseRedirectInputMap(oldInput) + switch ctype { case MIMEApplicationForm, MIMEMultipartForm: _ = r.c.Bind().Form(oldInput) //nolint:errcheck // not needed @@ -315,11 +423,16 @@ func (r *Redirect) processFlashMessages() { return } - val, err := r.messages.MarshalMsg(nil) + bufPtr := acquireRedirectMsgBuffer() + defer releaseRedirectMsgBuffer(bufPtr) + + val, err := r.messages.MarshalMsg((*bufPtr)[:0]) if err != nil { return } + *bufPtr = val + dst := make([]byte, hex.EncodedLen(len(val))) hex.Encode(dst, val) diff --git a/redirect_test.go b/redirect_test.go index 89e01127024..f50df7ad80c 100644 --- a/redirect_test.go +++ b/redirect_test.go @@ -61,6 +61,30 @@ func Test_Redirect_To_WithFlashMessages(t *testing.T) { require.Contains(t, msgs, redirectionMsg{key: "message", value: "test", level: 2, isOldInput: false}) } +func Test_redirectMsgBufferPool(t *testing.T) { + t.Parallel() + + bufPtr := acquireRedirectMsgBuffer() + require.NotNil(t, bufPtr) + buf := *bufPtr + require.Equal(t, 0, len(buf)) + require.GreaterOrEqual(t, cap(buf), redirectMsgBufferDefaultCap) + + releaseRedirectMsgBuffer(bufPtr) + + bufPtr = acquireRedirectMsgBuffer() + buf = *bufPtr + // Inflate the buffer beyond the max cap and ensure it resets on release. + big := make([]byte, redirectMsgBufferMaxCap+redirectMsgBufferDefaultCap) + *bufPtr = big + releaseRedirectMsgBuffer(bufPtr) + + bufPtr = acquireRedirectMsgBuffer() + buf = *bufPtr + require.LessOrEqual(t, cap(buf), redirectMsgBufferDefaultCap) + releaseRedirectMsgBuffer(bufPtr) +} + // go test -run Test_Redirect_Route_WithParams func Test_Redirect_Route_WithParams(t *testing.T) { t.Parallel() diff --git a/req.go b/req.go index d6e3d116583..7a710449e83 100644 --- a/req.go +++ b/req.go @@ -7,6 +7,7 @@ import ( "net" "strconv" "strings" + "sync" utils "github.com/gofiber/utils/v2" "github.com/valyala/fasthttp" @@ -29,38 +30,166 @@ type RangeSet struct { // //go:generate ifacemaker --file req.go --struct DefaultReq --iface Req --pkg fiber --output req_interface_gen.go --not-exported true --iface-comment "Req is an interface for request-related Ctx methods." type DefaultReq struct { - c *DefaultCtx + c *DefaultCtx + ipSlicePtr *[]string +} + +const ( + encodingOrderDefaultCap = 4 + encodingOrderMaxCap = 16 +) + +var encodingOrderPool = sync.Pool{ + New: func() any { + slice := make([]string, 0, encodingOrderDefaultCap) + return &slice + }, +} + +func acquireEncodingOrder() *[]string { + orderPtr, ok := encodingOrderPool.Get().(*[]string) + if !ok { + panic(errors.New("failed to type-assert to *[]string")) + } + + order := *orderPtr + order = order[:0] + *orderPtr = order + + return orderPtr +} + +func releaseEncodingOrder(orderPtr *[]string) { + if orderPtr == nil { + return + } + + order := *orderPtr + for i := range order { + order[i] = "" + } + + if cap(order) > encodingOrderMaxCap { + order = make([]string, 0, encodingOrderDefaultCap) + } else { + order = order[:0] + } + + *orderPtr = order + encodingOrderPool.Put(orderPtr) +} + +const ( + ipSliceDefaultCap = 4 + ipSliceMaxCap = 64 +) + +var ipSlicePool = sync.Pool{ + New: func() any { + slice := make([]string, 0, ipSliceDefaultCap) + return &slice + }, +} + +func acquireIPSliceBuffer(size int) (*[]string, []string) { + if size < ipSliceDefaultCap { + size = ipSliceDefaultCap + } + + sliceAny := ipSlicePool.Get() + slicePtr, ok := sliceAny.(*[]string) + if !ok || slicePtr == nil { + slice := make([]string, 0, size) + slicePtr = &slice + } + + slice := *slicePtr + if cap(slice) < size { + slice = make([]string, 0, size) + } else { + slice = slice[:0] + } + + *slicePtr = slice + return slicePtr, slice +} + +func releaseIPSliceBuffer(slicePtr *[]string) { + if slicePtr == nil { + return + } + + slice := *slicePtr + for i := range slice { + slice[i] = "" + } + + if cap(slice) > ipSliceMaxCap { + slice = make([]string, 0, ipSliceDefaultCap) + } else { + slice = slice[:0] + } + + *slicePtr = slice + ipSlicePool.Put(slicePtr) +} + +func (r *DefaultReq) acquireIPSlices(size int) (*[]string, []string) { + if r.ipSlicePtr == nil { + slicePtr, slice := acquireIPSliceBuffer(size) + r.ipSlicePtr = slicePtr + return slicePtr, slice + } + + if size < ipSliceDefaultCap { + size = ipSliceDefaultCap + } + + slice := *r.ipSlicePtr + if cap(slice) < size { + slice = make([]string, 0, size) + } else { + slice = slice[:0] + } + + *r.ipSlicePtr = slice + return r.ipSlicePtr, slice } // Accepts checks if the specified extensions or content types are acceptable. func (r *DefaultReq) Accepts(offers ...string) string { - header := joinHeaderValues(r.c.fasthttp.Request.Header.PeekAll(HeaderAccept)) + header, bufPtr := joinHeaderValues(r.c.fasthttp.Request.Header.PeekAll(HeaderAccept)) + defer releaseHeaderValueBuffer(bufPtr) return getOffer(header, acceptsOfferType, offers...) } // AcceptsCharsets checks if the specified charset is acceptable. func (r *DefaultReq) AcceptsCharsets(offers ...string) string { - header := joinHeaderValues(r.c.fasthttp.Request.Header.PeekAll(HeaderAcceptCharset)) + header, bufPtr := joinHeaderValues(r.c.fasthttp.Request.Header.PeekAll(HeaderAcceptCharset)) + defer releaseHeaderValueBuffer(bufPtr) return getOffer(header, acceptsOffer, offers...) } // AcceptsEncodings checks if the specified encoding is acceptable. func (r *DefaultReq) AcceptsEncodings(offers ...string) string { - header := joinHeaderValues(r.c.fasthttp.Request.Header.PeekAll(HeaderAcceptEncoding)) + header, bufPtr := joinHeaderValues(r.c.fasthttp.Request.Header.PeekAll(HeaderAcceptEncoding)) + defer releaseHeaderValueBuffer(bufPtr) return getOffer(header, acceptsOffer, offers...) } // AcceptsLanguages checks if the specified language is acceptable using // RFC 4647 Basic Filtering. func (r *DefaultReq) AcceptsLanguages(offers ...string) string { - header := joinHeaderValues(r.c.fasthttp.Request.Header.PeekAll(HeaderAcceptLanguage)) + header, bufPtr := joinHeaderValues(r.c.fasthttp.Request.Header.PeekAll(HeaderAcceptLanguage)) + defer releaseHeaderValueBuffer(bufPtr) return getOffer(header, acceptsLanguageOfferBasic, offers...) } // AcceptsLanguagesExtended checks if the specified language is acceptable using // RFC 4647 Extended Filtering. func (r *DefaultReq) AcceptsLanguagesExtended(offers ...string) string { - header := joinHeaderValues(r.c.fasthttp.Request.Header.PeekAll(HeaderAcceptLanguage)) + header, bufPtr := joinHeaderValues(r.c.fasthttp.Request.Header.PeekAll(HeaderAcceptLanguage)) + defer releaseHeaderValueBuffer(bufPtr) return getOffer(header, acceptsLanguageOfferExtended, offers...) } @@ -140,7 +269,6 @@ func (r *DefaultReq) Body() []byte { err error body, originalBody []byte headerEncoding string - encodingOrder = []string{"", "", ""} ) request := &r.c.fasthttp.Request @@ -153,9 +281,13 @@ func (r *DefaultReq) Body() []byte { return r.getBody() } + encodingOrderPtr := acquireEncodingOrder() + encodingOrder := getSplicedStrList(headerEncoding, *encodingOrderPtr) + *encodingOrderPtr = encodingOrder + defer releaseEncodingOrder(encodingOrderPtr) + // Split and get the encodings list, in order to attend the // rule defined at: https://www.rfc-editor.org/rfc/rfc9110#section-8.4-5 - encodingOrder = getSplicedStrList(headerEncoding, encodingOrder) for i := range encodingOrder { encodingOrder[i] = utils.ToLower(encodingOrder[i]) } @@ -375,7 +507,7 @@ func (r *DefaultReq) extractIPsFromHeader(header string) []string { // Avoid big allocation on big header maxEstimatedCount) - ipsFound := make([]string, 0, estimatedCount) + slicePtr, ipsFound := r.acquireIPSlices(estimatedCount) i := 0 j := -1 @@ -417,6 +549,10 @@ iploop: ipsFound = append(ipsFound, s) } + if slicePtr != nil { + *slicePtr = ipsFound + } + return ipsFound } @@ -908,6 +1044,10 @@ func (r *DefaultReq) IsFromLocal() bool { // Release is a method to reset Req fields when to use ReleaseCtx() func (r *DefaultReq) release() { + if r.ipSlicePtr != nil { + releaseIPSliceBuffer(r.ipSlicePtr) + r.ipSlicePtr = nil + } r.c = nil } diff --git a/req_interface_gen.go b/req_interface_gen.go index 77b5549719a..2a83581efa7 100644 --- a/req_interface_gen.go +++ b/req_interface_gen.go @@ -10,6 +10,7 @@ import ( // Req is an interface for request-related Ctx methods. type Req interface { + acquireIPSlices(size int) (*[]string, []string) // Accepts checks if the specified extensions or content types are acceptable. Accepts(offers ...string) string // AcceptsCharsets checks if the specified charset is acceptable. diff --git a/res.go b/res.go index b7f3176d255..38978f8e034 100644 --- a/res.go +++ b/res.go @@ -3,6 +3,7 @@ package fiber import ( "bufio" "bytes" + "errors" "fmt" "html/template" "io" @@ -12,6 +13,7 @@ import ( "path/filepath" "strconv" "strings" + "sync" "time" utils "github.com/gofiber/utils/v2" @@ -19,6 +21,114 @@ import ( "github.com/valyala/fasthttp" ) +const ( + resHeaderMapDefaultCap = 8 + resHeaderMapMaxEntries = 64 + resHeaderValuesDefaultCap = 1 + resHeaderValuesMaxLen = 16 + resHeaderScratchMaxKeep = 4 +) + +var ( + resHeaderValuesPool = sync.Pool{ //nolint:gochecknoglobals // shared buffer pool + New: func() any { + slice := make([]string, 0, resHeaderValuesDefaultCap) + return &slice + }, + } + + responseHeaderScratchPool = sync.Pool{ //nolint:gochecknoglobals // shared scratch pool + New: func() any { + return &responseHeaderScratch{} + }, + } +) + +type responseHeaderScratch struct { + headers map[string][]string + valuePtrs map[string]*[]string +} + +func acquireResponseHeaderScratch() *responseHeaderScratch { + scratchAny := responseHeaderScratchPool.Get() + scratch, ok := scratchAny.(*responseHeaderScratch) + if !ok { + panic(errors.New("failed to type-assert to *responseHeaderScratch")) + } + scratch.prepare() + return scratch +} + +func releaseResponseHeaderScratch(s *responseHeaderScratch) { + if s == nil { + return + } + s.prepare() + responseHeaderScratchPool.Put(s) +} + +func (s *responseHeaderScratch) prepare() { + if s.headers == nil { + s.headers = make(map[string][]string, resHeaderMapDefaultCap) + } else { + mapLen := len(s.headers) + if mapLen > 0 { + clear(s.headers) + } + if mapLen > resHeaderMapMaxEntries { + s.headers = make(map[string][]string, resHeaderMapDefaultCap) + } + } + + if s.valuePtrs == nil { + s.valuePtrs = make(map[string]*[]string, resHeaderMapDefaultCap) + return + } + + ptrCount := len(s.valuePtrs) + for key, ptr := range s.valuePtrs { + releaseResHeaderValues(ptr) + delete(s.valuePtrs, key) + } + if ptrCount > resHeaderMapMaxEntries { + s.valuePtrs = make(map[string]*[]string, resHeaderMapDefaultCap) + } +} + +func acquireResHeaderValues() *[]string { + sliceAny := resHeaderValuesPool.Get() + slicePtr, ok := sliceAny.(*[]string) + if !ok { + panic(errors.New("failed to type-assert to *[]string")) + } + slice := *slicePtr + if len(slice) > 0 { + slice = slice[:0] + } + *slicePtr = slice + return slicePtr +} + +func releaseResHeaderValues(slicePtr *[]string) { + if slicePtr == nil { + return + } + + slice := *slicePtr + if len(slice) > 0 { + clear(slice) + } + + if cap(slice) > resHeaderValuesMaxLen { + slice = make([]string, 0, resHeaderValuesDefaultCap) + } else { + slice = slice[:0] + } + + *slicePtr = slice + resHeaderValuesPool.Put(slicePtr) +} + // SendFile defines configuration options when to transfer file with SendFile. type SendFile struct { // FS is the file system to serve the static files from. @@ -123,7 +233,8 @@ type ResFmt struct { // //go:generate ifacemaker --file res.go --struct DefaultRes --iface Res --pkg fiber --output res_interface_gen.go --not-exported true --iface-comment "Res is an interface for response-related Ctx methods." type DefaultRes struct { - c *DefaultCtx + c *DefaultCtx + headerScratch []*responseHeaderScratch } // App returns the *App reference to the instance of the Fiber application @@ -415,13 +526,24 @@ func (r *DefaultRes) Get(key string, defaultValue ...string) string { // Returned value is only valid within the handler. Do not store any references. // Make copies or use the Immutable setting instead. func (r *DefaultRes) GetHeaders() map[string][]string { + scratch := acquireResponseHeaderScratch() app := r.c.app - headers := make(map[string][]string) - for k, v := range r.c.fasthttp.Response.Header.All() { - key := app.toString(k) - headers[key] = append(headers[key], app.toString(v)) - } - return headers + + r.c.fasthttp.Response.Header.VisitAll(func(key, value []byte) { + headerKey := app.toString(key) + valuesPtr, ok := scratch.valuePtrs[headerKey] + if !ok { + valuesPtr = acquireResHeaderValues() + scratch.valuePtrs[headerKey] = valuesPtr + } + + values := append(*valuesPtr, app.toString(value)) + *valuesPtr = values + scratch.headers[headerKey] = values + }) + + r.headerScratch = append(r.headerScratch, scratch) + return scratch.headers } // JSON converts any interface or string to JSON. @@ -972,6 +1094,17 @@ func (r *DefaultRes) WriteString(s string) (int, error) { // Release is a method to reset Res fields when to use ReleaseCtx() func (r *DefaultRes) release() { + for i := range r.headerScratch { + releaseResponseHeaderScratch(r.headerScratch[i]) + r.headerScratch[i] = nil + } + + if cap(r.headerScratch) > resHeaderScratchMaxKeep { + r.headerScratch = nil + } else { + r.headerScratch = r.headerScratch[:0] + } + r.c = nil } diff --git a/router.go b/router.go index f9ffcb3fb81..40ed24a7316 100644 --- a/router.go +++ b/router.go @@ -6,8 +6,10 @@ package fiber import ( "bytes" + "errors" "fmt" "slices" + "sync" "sync/atomic" utils "github.com/gofiber/utils/v2" @@ -60,6 +62,41 @@ type Route struct { root bool // Path equals '/' } +const ( + removedUseRouteDefaultCap = 8 + removedUseRouteMaxEntries = 64 +) + +var removedUseRoutePool = sync.Pool{ + New: func() any { + return make(map[string]struct{}, removedUseRouteDefaultCap) + }, +} + +func acquireRemovedUseRouteSet() map[string]struct{} { + set, ok := removedUseRoutePool.Get().(map[string]struct{}) + if !ok { + panic(errors.New("failed to type-assert to map[string]struct{}")) + } + if len(set) > 0 { + clear(set) + } + return set +} + +func releaseRemovedUseRouteSet(set map[string]struct{}) { + if set == nil { + return + } + used := len(set) + clear(set) + if used > removedUseRouteMaxEntries { + removedUseRoutePool.Put(make(map[string]struct{}, removedUseRouteDefaultCap)) + return + } + removedUseRoutePool.Put(set) +} + func (r *Route) match(detectionPath, path string, params *[maxParams]string) bool { // root detectionPath check if r.root && len(detectionPath) == 1 && detectionPath[0] == '/' { @@ -436,7 +473,8 @@ func (app *App) deleteRoute(methods []string, matchFunc func(r *Route) bool) { app.mutex.Lock() defer app.mutex.Unlock() - removedUseRoutes := make(map[string]struct{}) + removedUseRoutes := acquireRemovedUseRouteSet() + defer releaseRemovedUseRouteSet(removedUseRoutes) for _, method := range methods { // Uppercase HTTP methods diff --git a/services.go b/services.go index 6151a7e581a..782cef65b60 100644 --- a/services.go +++ b/services.go @@ -6,8 +6,64 @@ import ( "fmt" "io" "strings" + "sync" ) +const ( + serviceErrorSliceDefaultCap = 4 + serviceErrorSliceMaxCap = 64 +) + +var serviceErrorSlicePool = sync.Pool{ //nolint:gochecknoglobals // shared error slice pool + New: func() any { + slice := make([]error, 0, serviceErrorSliceDefaultCap) + return &slice + }, +} + +func acquireServiceErrorSlice(size int) (*[]error, []error) { + if size < serviceErrorSliceDefaultCap { + size = serviceErrorSliceDefaultCap + } + + sliceAny := serviceErrorSlicePool.Get() + slicePtr, ok := sliceAny.(*[]error) + if !ok || slicePtr == nil { + slice := make([]error, 0, size) + return &slice, slice + } + + slice := *slicePtr + if cap(slice) < size { + slice = make([]error, 0, size) + } else { + slice = slice[:0] + } + + *slicePtr = slice + return slicePtr, slice +} + +func releaseServiceErrorSlice(slicePtr *[]error) { + if slicePtr == nil { + return + } + + slice := *slicePtr + for i := range slice { + slice[i] = nil + } + + if cap(slice) > serviceErrorSliceMaxCap { + slice = make([]error, 0, serviceErrorSliceDefaultCap) + } else { + slice = slice[:0] + } + + *slicePtr = slice + serviceErrorSlicePool.Put(slicePtr) +} + // Service is an interface that defines the methods for a service. type Service interface { // Start starts the service, returning an error if it fails. @@ -69,7 +125,12 @@ func (app *App) startServices(ctx context.Context) error { return nil } - var errs []error + errsPtr, errs := acquireServiceErrorSlice(len(app.configured.Services)) + defer func() { + *errsPtr = errs + releaseServiceErrorSlice(errsPtr) + }() + for _, srv := range app.configured.Services { if err := ctx.Err(); err != nil { // Context is canceled, return an error the soonest possible, so that @@ -97,12 +158,18 @@ func (app *App) startServices(ctx context.Context) error { // Iterates over all the started services in reverse order and tries to terminate them, // returning an error if any error occurs. func (app *App) shutdownServices(ctx context.Context) error { - if app.state.ServicesLen() == 0 { + services := app.state.Services() + if len(services) == 0 { return nil } - var errs []error - for _, srv := range app.state.Services() { + errsPtr, errs := acquireServiceErrorSlice(len(services)) + defer func() { + *errsPtr = errs + releaseServiceErrorSlice(errsPtr) + }() + + for _, srv := range services { if err := ctx.Err(); err != nil { // Context is canceled, do a best effort to terminate the services. errs = append(errs, fmt.Errorf("service %s terminate: %w", srv.String(), err)) diff --git a/services_test.go b/services_test.go index 553d17d6f3c..ac5e5f51250 100644 --- a/services_test.go +++ b/services_test.go @@ -738,3 +738,29 @@ func Benchmark_ServicesMemory(b *testing.B) { }) }) } + +func TestAcquireReleaseServiceErrorSlice(t *testing.T) { + ptr, errs := acquireServiceErrorSlice(0) + require.NotNil(t, ptr) + require.Len(t, errs, 0) + require.GreaterOrEqual(t, cap(errs), serviceErrorSliceDefaultCap) + + errs = append(errs, errors.New("boom")) + *ptr = errs + releaseServiceErrorSlice(ptr) + + ptr2, errs2 := acquireServiceErrorSlice(0) + require.NotNil(t, ptr2) + require.Len(t, errs2, 0) + require.GreaterOrEqual(t, cap(errs2), serviceErrorSliceDefaultCap) + + oversized := make([]error, 0, serviceErrorSliceMaxCap*2) + *ptr2 = oversized + releaseServiceErrorSlice(ptr2) + + ptr3, errs3 := acquireServiceErrorSlice(0) + require.NotNil(t, ptr3) + require.Len(t, errs3, 0) + require.LessOrEqual(t, cap(errs3), serviceErrorSliceMaxCap) + releaseServiceErrorSlice(ptr3) +} diff --git a/state.go b/state.go index 786f98177fd..eacc004d0f2 100644 --- a/state.go +++ b/state.go @@ -16,6 +16,61 @@ func init() { servicesStatePrefixHash = hex.EncodeToString([]byte(servicesStatePrefix + uuid.New().String())) } +const ( + serviceKeySliceDefaultCap = 4 + serviceKeySliceMaxCap = 64 +) + +var serviceKeySlicePool = sync.Pool{ //nolint:gochecknoglobals // shared slice pool + New: func() any { + slice := make([]string, 0, serviceKeySliceDefaultCap) + return &slice + }, +} + +func acquireServiceKeySlice(size int) (*[]string, []string) { + if size < serviceKeySliceDefaultCap { + size = serviceKeySliceDefaultCap + } + + sliceAny := serviceKeySlicePool.Get() + slicePtr, ok := sliceAny.(*[]string) + if !ok || slicePtr == nil { + slice := make([]string, 0, size) + return &slice, slice + } + + slice := *slicePtr + if cap(slice) < size { + slice = make([]string, 0, size) + } else { + slice = slice[:0] + } + + *slicePtr = slice + return slicePtr, slice +} + +func releaseServiceKeySlice(slicePtr *[]string) { + if slicePtr == nil { + return + } + + slice := *slicePtr + for i := range slice { + slice[i] = "" + } + + if cap(slice) > serviceKeySliceMaxCap { + slice = make([]string, 0, serviceKeySliceDefaultCap) + } else { + slice = slice[:0] + } + + *slicePtr = slice + serviceKeySlicePool.Put(slicePtr) +} + // State is a key-value store for Fiber's app in order to be used as a global storage for the app's dependencies. // It's a thread-safe implementation of a map[string]any, using sync.Map. type State struct { @@ -254,8 +309,8 @@ func (s *State) deleteService(srv Service) { } // serviceKeys returns a slice containing all keys present for services in the application's State. -func (s *State) serviceKeys() []string { - keys := make([]string, 0) +func (s *State) serviceKeys() (*[]string, []string) { + keysPtr, keys := acquireServiceKeySlice(serviceKeySliceDefaultCap) s.dependencies.Range(func(key, _ any) bool { keyStr, ok := key.(string) if !ok { @@ -270,15 +325,19 @@ func (s *State) serviceKeys() []string { return true }) - return keys + *keysPtr = keys + return keysPtr, keys } // Services returns a map containing all services present in the State. // The key is the hash of the service String() value and the value is the service itself. func (s *State) Services() map[string]Service { - services := make(map[string]Service) + keysPtr, keys := s.serviceKeys() + defer releaseServiceKeySlice(keysPtr) + + services := make(map[string]Service, len(keys)) - for _, key := range s.serviceKeys() { + for _, key := range keys { services[key] = MustGetState[Service](s, key) } diff --git a/state_test.go b/state_test.go index 1f874f5408a..d1862f18dbc 100644 --- a/state_test.go +++ b/state_test.go @@ -564,7 +564,9 @@ func TestState_Service(t *testing.T) { st := newState() require.Equal(t, 0, st.Len()) - require.Empty(t, st.serviceKeys()) + keysPtr, keys := st.serviceKeys() + require.Empty(t, keys) + releaseServiceKeySlice(keysPtr) }) t.Run("with-services", func(t *testing.T) { @@ -602,7 +604,9 @@ func TestState_Service(t *testing.T) { st.Set("key1", "value1") st.Set("key2", "value2") - require.Empty(t, st.serviceKeys()) + keysPtr, keys := st.serviceKeys() + require.Empty(t, keys) + releaseServiceKeySlice(keysPtr) }) t.Run("with-services", func(t *testing.T) { @@ -612,10 +616,11 @@ func TestState_Service(t *testing.T) { st.setService(srv1) st.setService(srv2) - keys := st.serviceKeys() + keysPtr, keys := st.serviceKeys() require.Len(t, keys, 2) require.Contains(t, keys, st.serviceKey(srv1.String())) require.Contains(t, keys, st.serviceKey(srv2.String())) + releaseServiceKeySlice(keysPtr) }) t.Run("with-services/with-keys", func(t *testing.T) { @@ -627,12 +632,13 @@ func TestState_Service(t *testing.T) { st.Set("key1", "value1") st.Set("key2", "value2") - keys := st.serviceKeys() + keysPtr, keys := st.serviceKeys() require.Len(t, keys, 2) require.Contains(t, keys, st.serviceKey(srv1.String())) require.Contains(t, keys, st.serviceKey(srv2.String())) require.NotContains(t, keys, "key1") require.NotContains(t, keys, "key2") + releaseServiceKeySlice(keysPtr) }) }) From 42e60b4efe13a48630554e8993df591992548069 Mon Sep 17 00:00:00 2001 From: Juan Calderon-Perez Date: Wed, 1 Oct 2025 15:12:05 -0400 Subject: [PATCH 2/3] Fix CI issues --- client/request.go | 4 ++-- middleware/adaptor/adaptor.go | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/client/request.go b/client/request.go index 9596be7a0d7..7ebc64306b6 100644 --- a/client/request.go +++ b/client/request.go @@ -280,13 +280,13 @@ func (r *Request) Headers() iter.Seq2[string, []string] { defer releaseHeaderKeySlice(keysPtr) for _, key := range keys { - vals := r.header.PeekAll(key) + vals := r.header.PeekAll(utils.UnsafeString(key)) valsStr := make([]string, len(vals)) for i, v := range vals { valsStr[i] = utils.UnsafeString(v) } - if !yield(key, valsStr) { + if !yield(utils.UnsafeString(key), valsStr) { return } } diff --git a/middleware/adaptor/adaptor.go b/middleware/adaptor/adaptor.go index c3475ef4205..1afb2eeab7d 100644 --- a/middleware/adaptor/adaptor.go +++ b/middleware/adaptor/adaptor.go @@ -2,6 +2,7 @@ package adaptor import ( "errors" + "fmt" "io" "net" "net/http" From 98f436a64339b5542d773a85d17e9a86b7bf0a64 Mon Sep 17 00:00:00 2001 From: Juan Calderon-Perez Date: Wed, 1 Oct 2025 17:38:07 -0400 Subject: [PATCH 3/3] More lint fixes --- client/cookiejar.go | 9 +-------- client/cookiejar_test.go | 4 ++-- client/hooks.go | 15 ++++++++------- client/request_test.go | 4 ++-- client/response_test.go | 2 +- ctx_test.go | 2 +- helpers_test.go | 5 ++--- middleware/cache/cache_test.go | 2 +- middleware/idempotency/response.go | 4 +++- middleware/logger/tags.go | 9 ++++++--- path_test.go | 4 ++-- redirect_test.go | 4 ++-- services_test.go | 6 +++--- 13 files changed, 34 insertions(+), 36 deletions(-) diff --git a/client/cookiejar.go b/client/cookiejar.go index 4d4623aa27c..e1d1a76e2f7 100644 --- a/client/cookiejar.go +++ b/client/cookiejar.go @@ -133,8 +133,6 @@ func (cj *CookieJar) getCookiesByHost(host string) []*fasthttp.Cookie { } // cookiesForRequest returns cookies that match the given host, path and security settings. -// -//nolint:revive // secure is required to filter Secure cookies based on scheme func acquireCookieMatches() *[]*fasthttp.Cookie { sliceAny := cookieJarMatchPool.Get() matchesPtr, ok := sliceAny.(*[]*fasthttp.Cookie) @@ -170,18 +168,13 @@ func releaseCookieMatches(matchesPtr *[]*fasthttp.Cookie) { cookieJarMatchPool.Put(matchesPtr) } -func (cj *CookieJar) cookiesForRequest(host string, path []byte, secure bool) []*fasthttp.Cookie { - matches, _ := cj.collectCookiesForRequest(nil, host, path, secure) - return matches -} - func (cj *CookieJar) borrowCookiesForRequest(host string, path []byte, secure bool) ([]*fasthttp.Cookie, *[]*fasthttp.Cookie) { matchesPtr := acquireCookieMatches() matches, ptr := cj.collectCookiesForRequest(matchesPtr, host, path, secure) return matches, ptr } -func (cj *CookieJar) collectCookiesForRequest( +func (cj *CookieJar) collectCookiesForRequest( //nolint:revive // Accepting a bool param is fine here matchesPtr *[]*fasthttp.Cookie, host string, path []byte, diff --git a/client/cookiejar_test.go b/client/cookiejar_test.go index 06ed3ec1269..1aeda213456 100644 --- a/client/cookiejar_test.go +++ b/client/cookiejar_test.go @@ -371,7 +371,7 @@ func Test_releaseCookieMatchesShrinksOversizedSlices(t *testing.T) { pooledPtr := acquireCookieMatches() require.NotNil(t, pooledPtr) - require.Len(t, *pooledPtr, 0) + require.Empty(t, *pooledPtr) require.LessOrEqual(t, cap(*pooledPtr), cookieJarMatchMaxCap) releaseCookieMatches(pooledPtr) @@ -400,7 +400,7 @@ func Test_CookieJar_BorrowCookiesUsesPool(t *testing.T) { releaseCookieMatches(matchesPtr) pooledPtr := acquireCookieMatches() - require.Len(t, *pooledPtr, 0) + require.Empty(t, *pooledPtr) releaseCookieMatches(pooledPtr) fasthttp.ReleaseURI(uri) diff --git a/client/hooks.go b/client/hooks.go index 0fc902347d6..509743c26a5 100644 --- a/client/hooks.go +++ b/client/hooks.go @@ -34,23 +34,24 @@ var ( randBytePool = sync.Pool{ New: func() any { - return make([]byte, 0, randByteDefaultCap) + b := make([]byte, 0, randByteDefaultCap) + return &b }, } ) func acquireRandBytes(size int) []byte { bufAny := randBytePool.Get() - buf, ok := bufAny.([]byte) + buf, ok := bufAny.(*[]byte) if !ok { - panic(errors.New("failed to type-assert to []byte")) + panic(errors.New("failed to type-assert to *[]byte")) } - if cap(buf) < size { - buf = make([]byte, size) + if cap(*buf) < size { + *buf = make([]byte, size) } - return buf[:size] + return (*buf)[:size] } func releaseRandBytes(buf []byte) { @@ -66,7 +67,7 @@ func releaseRandBytes(buf []byte) { buf = buf[:0] } - randBytePool.Put(buf) + randBytePool.Put(&buf) } const ( diff --git a/client/request_test.go b/client/request_test.go index 6926190489c..84ef4df37fd 100644 --- a/client/request_test.go +++ b/client/request_test.go @@ -390,8 +390,8 @@ func Test_requestPairPoolResetAndShrink(t *testing.T) { releasePair(p) reused := acquirePair(1) - require.Zero(t, len(reused.k)) - require.Zero(t, len(reused.v)) + require.Empty(t, reused.k) + require.Empty(t, reused.v) releasePair(reused) oversized := acquirePair(pairSliceMaxCap + 32) diff --git a/client/response_test.go b/client/response_test.go index 0bd95a77ffd..9d6a53e6d0b 100644 --- a/client/response_test.go +++ b/client/response_test.go @@ -214,7 +214,7 @@ func Test_Response_Reset_ShrinksCookieSlice(t *testing.T) { resp.Reset() - require.Len(t, resp.cookie, 0) + require.Empty(t, resp.cookie) require.Equal(t, responseCookieSliceDefaultCap, cap(resp.cookie)) } diff --git a/ctx_test.go b/ctx_test.go index 23cf52f65c9..7eb7e5d4a2b 100644 --- a/ctx_test.go +++ b/ctx_test.go @@ -6501,7 +6501,7 @@ func Test_DefaultRes_GetHeaders_ReleasesScratch(t *testing.T) { app.ReleaseCtx(customCtx) - require.Zero(t, len(ctx.DefaultRes.headerScratch)) + require.Empty(t, ctx.DefaultRes.headerScratch) } func Benchmark_Ctx_GetRespHeaders(b *testing.B) { diff --git a/helpers_test.go b/helpers_test.go index 4ff9f046d26..b52f9a34e52 100644 --- a/helpers_test.go +++ b/helpers_test.go @@ -125,7 +125,6 @@ func Test_splitLanguageTags(t *testing.T) { } for _, tc := range testCases { - tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() ptr, tags := splitLanguageTags(tc.input) @@ -151,7 +150,7 @@ func Test_releaseHeaderParams(t *testing.T) { releaseHeaderParams(params) params = acquireHeaderParams() - require.Len(t, params.values, 0) + require.Empty(t, params.values) oldMapPtr := reflect.ValueOf(params.values).Pointer() for i := 0; i < headerParamsValuesMaxEntries+5; i++ { @@ -160,7 +159,7 @@ func Test_releaseHeaderParams(t *testing.T) { releaseHeaderParams(params) params = acquireHeaderParams() - require.Len(t, params.values, 0) + require.Empty(t, params.values) require.NotEqual(t, oldMapPtr, reflect.ValueOf(params.values).Pointer()) releaseHeaderParams(params) } diff --git a/middleware/cache/cache_test.go b/middleware/cache/cache_test.go index 4f7f8c0c433..ba77d43d44d 100644 --- a/middleware/cache/cache_test.go +++ b/middleware/cache/cache_test.go @@ -107,7 +107,7 @@ func TestManagerReleaseShrinksHeaderMap(t *testing.T) { mgr.release(item) reacquired := mgr.acquire() - require.Equal(t, 0, len(reacquired.headers)) + require.Empty(t, reacquired.headers) newPtr := reflect.ValueOf(reacquired.headers).Pointer() require.NotEqual(t, originalPtr, newPtr) diff --git a/middleware/idempotency/response.go b/middleware/idempotency/response.go index 4ba1f8accd7..8280e7abacf 100644 --- a/middleware/idempotency/response.go +++ b/middleware/idempotency/response.go @@ -1,6 +1,8 @@ package idempotency -import "sync" +import ( + "sync" +) // response is a struct that represents the response of a request. // generation tool `go install github.com/tinylib/msgp@latest` diff --git a/middleware/logger/tags.go b/middleware/logger/tags.go index 7d48983d836..7d53617dca7 100644 --- a/middleware/logger/tags.go +++ b/middleware/logger/tags.go @@ -175,7 +175,8 @@ func createTagMap(cfg *Config) map[string]LogFunc { for key, values := range headers { if !firstPair { - if err := output.WriteByte('&'); err != nil { + err := output.WriteByte('&') + if err != nil { return written, err } written++ @@ -188,14 +189,16 @@ func createTagMap(cfg *Config) map[string]LogFunc { return written, err } - if err := output.WriteByte('='); err != nil { + err = output.WriteByte('=') + if err != nil { return written, err } written++ for i, value := range values { if i > 0 { - if err := output.WriteByte(','); err != nil { + err = output.WriteByte(',') + if err != nil { return written, err } written++ diff --git a/path_test.go b/path_test.go index 3b35e2a291c..498442a1e3d 100644 --- a/path_test.go +++ b/path_test.go @@ -186,8 +186,8 @@ func TestRouteParserResetBounds(t *testing.T) { parser.reset() - require.Zero(t, len(parser.segs)) - require.Zero(t, len(parser.params)) + require.Empty(t, parser.segs) + require.Empty(t, parser.params) require.Equal(t, routeParserSegDefaultCap, cap(parser.segs)) require.Equal(t, routeParserParamDefaultCap, cap(parser.params)) require.Zero(t, parser.wildCardCount) diff --git a/redirect_test.go b/redirect_test.go index f50df7ad80c..d8d21275d21 100644 --- a/redirect_test.go +++ b/redirect_test.go @@ -67,13 +67,13 @@ func Test_redirectMsgBufferPool(t *testing.T) { bufPtr := acquireRedirectMsgBuffer() require.NotNil(t, bufPtr) buf := *bufPtr - require.Equal(t, 0, len(buf)) + require.Empty(t, buf) require.GreaterOrEqual(t, cap(buf), redirectMsgBufferDefaultCap) releaseRedirectMsgBuffer(bufPtr) bufPtr = acquireRedirectMsgBuffer() - buf = *bufPtr + // buf = *bufPtr // Inflate the buffer beyond the max cap and ensure it resets on release. big := make([]byte, redirectMsgBufferMaxCap+redirectMsgBufferDefaultCap) *bufPtr = big diff --git a/services_test.go b/services_test.go index ac5e5f51250..ff16d2fce0f 100644 --- a/services_test.go +++ b/services_test.go @@ -742,7 +742,7 @@ func Benchmark_ServicesMemory(b *testing.B) { func TestAcquireReleaseServiceErrorSlice(t *testing.T) { ptr, errs := acquireServiceErrorSlice(0) require.NotNil(t, ptr) - require.Len(t, errs, 0) + require.Empty(t, errs) require.GreaterOrEqual(t, cap(errs), serviceErrorSliceDefaultCap) errs = append(errs, errors.New("boom")) @@ -751,7 +751,7 @@ func TestAcquireReleaseServiceErrorSlice(t *testing.T) { ptr2, errs2 := acquireServiceErrorSlice(0) require.NotNil(t, ptr2) - require.Len(t, errs2, 0) + require.Empty(t, errs2) require.GreaterOrEqual(t, cap(errs2), serviceErrorSliceDefaultCap) oversized := make([]error, 0, serviceErrorSliceMaxCap*2) @@ -760,7 +760,7 @@ func TestAcquireReleaseServiceErrorSlice(t *testing.T) { ptr3, errs3 := acquireServiceErrorSlice(0) require.NotNil(t, ptr3) - require.Len(t, errs3, 0) + require.Empty(t, errs3) require.LessOrEqual(t, cap(errs3), serviceErrorSliceMaxCap) releaseServiceErrorSlice(ptr3) }