Skip to content

Commit aa2f7c5

Browse files
authored
🐛 bug: Fix enforcement of Immutable config for some edge cases (#3835)
1 parent fa098a7 commit aa2f7c5

File tree

4 files changed

+196
-17
lines changed

4 files changed

+196
-17
lines changed

client.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -800,9 +800,7 @@ func printDebugInfo(req *Request, resp *Response, w io.Writer) {
800800
func (a *Agent) String() (int, string, []error) {
801801
defer a.release()
802802
code, body, errs := a.bytes()
803-
// TODO: There might be a data race here on body. Maybe use utils.CopyBytes on it?
804-
805-
return code, utils.UnsafeString(body), errs
803+
return code, string(body), errs
806804
}
807805

808806
// Struct returns the status code, bytes body and errors of URL.

client_test.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"path/filepath"
1616
"regexp"
1717
"strings"
18+
"sync/atomic"
1819
"testing"
1920
"time"
2021

@@ -989,6 +990,44 @@ func Test_Client_Agent_Reuse(t *testing.T) {
989990
utils.AssertEqual(t, 0, len(errs))
990991
}
991992

993+
func Test_Client_Agent_StringCopiesBody(t *testing.T) {
994+
t.Parallel()
995+
996+
ln := fasthttputil.NewInmemoryListener()
997+
998+
app := New(Config{DisableStartupMessage: true})
999+
1000+
var hits int32
1001+
app.Get("/", func(c *Ctx) error {
1002+
current := atomic.AddInt32(&hits, 1)
1003+
return c.SendString(fmt.Sprintf("body-%d", current))
1004+
})
1005+
1006+
go func() { utils.AssertEqual(t, nil, app.Listener(ln)) }()
1007+
1008+
a := Get("http://example.com").
1009+
Reuse()
1010+
1011+
a.HostClient.Dial = func(addr string) (net.Conn, error) { return ln.Dial() }
1012+
1013+
code, firstBody, errs := a.String()
1014+
1015+
utils.AssertEqual(t, StatusOK, code)
1016+
utils.AssertEqual(t, "body-1", firstBody)
1017+
utils.AssertEqual(t, 0, len(errs))
1018+
1019+
code, secondBody, errs := a.String()
1020+
1021+
utils.AssertEqual(t, StatusOK, code)
1022+
utils.AssertEqual(t, "body-2", secondBody)
1023+
utils.AssertEqual(t, 0, len(errs))
1024+
1025+
utils.AssertEqual(t, "body-1", firstBody)
1026+
utils.AssertEqual(t, "body-2", secondBody)
1027+
1028+
ReleaseAgent(a)
1029+
}
1030+
9921031
func Test_Client_Agent_InsecureSkipVerify(t *testing.T) {
9931032
t.Parallel()
9941033

ctx.go

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ func (c *Ctx) Body() []byte {
351351
c.fasthttp.Request.SetBodyRaw(originalBody)
352352
}
353353
if err != nil {
354-
return []byte(err.Error())
354+
return c.app.getBytes(err.Error())
355355
}
356356

357357
if c.app.config.Immutable {
@@ -423,7 +423,20 @@ func (c *Ctx) BodyParser(out interface{}) error {
423423

424424
data := make(map[string][]string)
425425
for key, values := range multipartForm.Value {
426-
err = formatParserData(out, data, bodyTag, key, values, c.app.config.EnableSplittingOnParsers, true)
426+
processedKey := key
427+
processedValues := values
428+
if c.app.config.Immutable {
429+
processedKey = c.app.getString([]byte(key))
430+
if len(values) > 0 {
431+
copied := make([]string, len(values))
432+
for i, val := range values {
433+
copied[i] = c.app.getString([]byte(val))
434+
}
435+
processedValues = copied
436+
}
437+
}
438+
439+
err = formatParserData(out, data, bodyTag, processedKey, processedValues, c.app.config.EnableSplittingOnParsers, true)
427440
if err != nil {
428441
return err
429442
}
@@ -714,17 +727,16 @@ func (c *Ctx) GetRespHeaders() map[string][]string {
714727
}
715728

716729
// Hostname contains the hostname derived from the X-Forwarded-Host or Host HTTP header.
717-
// Returned value is only valid within the handler. Do not store any references.
718-
// Make copies or use the Immutable setting instead.
730+
// Returned value is only valid within the handler. Do not store any references unless
731+
// Config.Immutable is enabled, in which case the value is copied before it is returned.
719732
// Please use Config.EnableTrustedProxyCheck to prevent header spoofing, in case when your app is behind the proxy.
720733
func (c *Ctx) Hostname() string {
721734
if c.IsProxyTrusted() {
722-
if host := c.Get(HeaderXForwardedHost); len(host) > 0 {
723-
commaPos := strings.Index(host, ",")
724-
if commaPos != -1 {
725-
return host[:commaPos]
735+
if hostBytes := c.fasthttp.Request.Header.Peek(HeaderXForwardedHost); len(hostBytes) > 0 {
736+
if commaPos := bytes.IndexByte(hostBytes, ','); commaPos != -1 {
737+
hostBytes = hostBytes[:commaPos]
726738
}
727-
return host
739+
return c.app.getString(hostBytes)
728740
}
729741
}
730742
return c.app.getString(c.fasthttp.Request.URI().Host())
@@ -1050,8 +1062,8 @@ func (c *Ctx) OriginalURL() string {
10501062
// Params is used to get the route parameters.
10511063
// Defaults to empty string "" if the param doesn't exist.
10521064
// If a default value is given, it will return that value if the param doesn't exist.
1053-
// Returned value is only valid within the handler. Do not store any references.
1054-
// Make copies or use the Immutable setting to use the value outside the Handler.
1065+
// Returned value is only valid within the handler. Do not store any references unless
1066+
// Config.Immutable is enabled, in which case the value is copied before it is returned.
10551067
func (c *Ctx) Params(key string, defaultValue ...string) string {
10561068
if key == "*" || key == "+" {
10571069
key += "1"
@@ -1065,7 +1077,11 @@ func (c *Ctx) Params(key string, defaultValue ...string) string {
10651077
if len(c.values) <= i || len(c.values[i]) == 0 {
10661078
break
10671079
}
1068-
return c.values[i]
1080+
value := c.values[i]
1081+
if c.app.config.Immutable {
1082+
return c.app.getString([]byte(value))
1083+
}
1084+
return value
10691085
}
10701086
}
10711087
return defaultString("", defaultValue)
@@ -1818,6 +1834,11 @@ func (c *Ctx) Subdomains(offset ...int) []string {
18181834
l = len(subdomains)
18191835
}
18201836
subdomains = subdomains[:l]
1837+
if c.app.config.Immutable {
1838+
for i, subdomain := range subdomains {
1839+
subdomains[i] = c.app.getString([]byte(subdomain))
1840+
}
1841+
}
18211842
return subdomains
18221843
}
18231844

@@ -1859,9 +1880,9 @@ func (c *Ctx) String() string {
18591880
buf.WriteString(" - ")
18601881

18611882
// Add method and URI
1862-
buf.Write(c.fasthttp.Request.Header.Method())
1883+
buf.WriteString(c.app.getString(c.fasthttp.Request.Header.Method()))
18631884
buf.WriteByte(' ')
1864-
buf.Write(c.fasthttp.URI().FullURI())
1885+
buf.WriteString(c.app.getString(c.fasthttp.URI().FullURI()))
18651886

18661887
// Allocate string
18671888
str := buf.String()

ctx_test.go

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2959,6 +2959,127 @@ func Test_Ctx_Subdomains(t *testing.T) {
29592959
utils.AssertEqual(t, []string{"localhost:3000"}, c.Subdomains())
29602960
}
29612961

2962+
// go test -run Test_Ctx_Immutable_AfterHandler
2963+
func Test_Ctx_Immutable_AfterHandler(t *testing.T) {
2964+
t.Parallel()
2965+
2966+
app := New(Config{
2967+
Immutable: true,
2968+
ProxyHeader: HeaderXForwardedFor,
2969+
EnableIPValidation: true,
2970+
})
2971+
2972+
type ctxSnapshot struct {
2973+
method string
2974+
path string
2975+
originalURL string
2976+
baseURL string
2977+
protocol string
2978+
hostname string
2979+
paramName string
2980+
queryFoo string
2981+
cookieSession string
2982+
headerUserAgent string
2983+
ip string
2984+
ips []string
2985+
subdomains []string
2986+
body []byte
2987+
routePath string
2988+
}
2989+
2990+
snapshots := make([]ctxSnapshot, 0, 2)
2991+
2992+
app.All("/v1/:name", func(c *Ctx) error {
2993+
snapshots = append(snapshots, ctxSnapshot{
2994+
method: c.Method(),
2995+
path: c.Path(),
2996+
originalURL: c.OriginalURL(),
2997+
baseURL: c.BaseURL(),
2998+
protocol: c.Protocol(),
2999+
hostname: c.Hostname(),
3000+
paramName: c.Params("name"),
3001+
queryFoo: c.Query("foo"),
3002+
cookieSession: c.Cookies("session"),
3003+
headerUserAgent: c.Get(HeaderUserAgent),
3004+
ip: c.IP(),
3005+
ips: c.IPs(),
3006+
subdomains: c.Subdomains(),
3007+
body: c.Body(),
3008+
routePath: c.Route().Path,
3009+
})
3010+
3011+
return c.SendString("ok")
3012+
})
3013+
3014+
req := httptest.NewRequest(MethodPost, "https://initial.invalid/v1/alpha?foo=bar", strings.NewReader("body-one"))
3015+
req.Header.Set(HeaderXForwardedHost, "p1.api.example.com")
3016+
req.Header.Set(HeaderXForwardedProto, "https")
3017+
req.Header.Set(HeaderXForwardedFor, "10.0.0.1, 10.0.0.2")
3018+
req.Header.Set(HeaderUserAgent, "agent-one")
3019+
req.Header.Set(HeaderCookie, "session=alpha")
3020+
3021+
originalFirst := req.URL.String()
3022+
3023+
resp, err := app.Test(req)
3024+
utils.AssertEqual(t, nil, err)
3025+
utils.AssertEqual(t, StatusOK, resp.StatusCode)
3026+
utils.AssertEqual(t, nil, resp.Body.Close())
3027+
3028+
utils.AssertEqual(t, 1, len(snapshots))
3029+
first := snapshots[0]
3030+
3031+
follow := httptest.NewRequest(MethodPatch, "http://secondary.invalid/v1/beta?foo=qux", strings.NewReader("body-two"))
3032+
follow.Header.Set(HeaderXForwardedHost, "edge.stage.example.org")
3033+
follow.Header.Set(HeaderXForwardedProto, "http")
3034+
follow.Header.Set(HeaderXForwardedFor, "192.168.1.50")
3035+
follow.Header.Set(HeaderUserAgent, "agent-two")
3036+
follow.Header.Set(HeaderCookie, "session=beta")
3037+
3038+
originalSecond := follow.URL.String()
3039+
3040+
resp, err = app.Test(follow)
3041+
utils.AssertEqual(t, nil, err)
3042+
utils.AssertEqual(t, StatusOK, resp.StatusCode)
3043+
utils.AssertEqual(t, nil, resp.Body.Close())
3044+
3045+
utils.AssertEqual(t, 2, len(snapshots))
3046+
second := snapshots[1]
3047+
3048+
// Ensure the first request's state stays valid after the context is released.
3049+
utils.AssertEqual(t, MethodPost, first.method)
3050+
utils.AssertEqual(t, "/v1/alpha", first.path)
3051+
utils.AssertEqual(t, originalFirst, first.originalURL)
3052+
utils.AssertEqual(t, "https://p1.api.example.com", first.baseURL)
3053+
utils.AssertEqual(t, "https", first.protocol)
3054+
utils.AssertEqual(t, "p1.api.example.com", first.hostname)
3055+
utils.AssertEqual(t, "alpha", first.paramName)
3056+
utils.AssertEqual(t, "bar", first.queryFoo)
3057+
utils.AssertEqual(t, "alpha", first.cookieSession)
3058+
utils.AssertEqual(t, "agent-one", first.headerUserAgent)
3059+
utils.AssertEqual(t, "10.0.0.1", first.ip)
3060+
utils.AssertEqual(t, []string{"10.0.0.1", "10.0.0.2"}, first.ips)
3061+
utils.AssertEqual(t, []string{"p1", "api"}, first.subdomains)
3062+
utils.AssertEqual(t, "body-one", string(first.body))
3063+
utils.AssertEqual(t, "/v1/:name", first.routePath)
3064+
3065+
// Verify the second request collected distinct immutable copies.
3066+
utils.AssertEqual(t, MethodPatch, second.method)
3067+
utils.AssertEqual(t, "/v1/beta", second.path)
3068+
utils.AssertEqual(t, originalSecond, second.originalURL)
3069+
utils.AssertEqual(t, "http://edge.stage.example.org", second.baseURL)
3070+
utils.AssertEqual(t, "http", second.protocol)
3071+
utils.AssertEqual(t, "edge.stage.example.org", second.hostname)
3072+
utils.AssertEqual(t, "beta", second.paramName)
3073+
utils.AssertEqual(t, "qux", second.queryFoo)
3074+
utils.AssertEqual(t, "beta", second.cookieSession)
3075+
utils.AssertEqual(t, "agent-two", second.headerUserAgent)
3076+
utils.AssertEqual(t, "192.168.1.50", second.ip)
3077+
utils.AssertEqual(t, []string{"192.168.1.50"}, second.ips)
3078+
utils.AssertEqual(t, []string{"edge", "stage"}, second.subdomains)
3079+
utils.AssertEqual(t, "body-two", string(second.body))
3080+
utils.AssertEqual(t, "/v1/:name", second.routePath)
3081+
}
3082+
29623083
// go test -v -run=^$ -bench=Benchmark_Ctx_Subdomains -benchmem -count=4
29633084
func Benchmark_Ctx_Subdomains(b *testing.B) {
29643085
app := New()

0 commit comments

Comments
 (0)