Skip to content

Commit ddbb831

Browse files
committed
🔥 feat: Add StreamResponseBody support for the Client
1 parent 64a7113 commit ddbb831

File tree

6 files changed

+402
-38
lines changed

6 files changed

+402
-38
lines changed

client/client.go

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,10 @@ type Client struct {
5555
userResponseHooks []ResponseHook
5656
builtinResponseHooks []ResponseHook
5757

58-
timeout time.Duration
59-
mu sync.RWMutex
60-
debug bool
58+
timeout time.Duration
59+
mu sync.RWMutex
60+
debug bool
61+
streamResponseBody bool
6162
}
6263

6364
// R creates a new Request associated with the client.
@@ -435,6 +436,20 @@ func (c *Client) DisableDebug() *Client {
435436
return c
436437
}
437438

439+
// StreamResponseBody returns the current StreamResponseBody setting.
440+
func (c *Client) StreamResponseBody() bool {
441+
return c.streamResponseBody
442+
}
443+
444+
// SetStreamResponseBody enables or disables response body streaming.
445+
// When enabled, the response body can be read as a stream using BodyStream()
446+
// instead of being fully loaded into memory. This is useful for large responses
447+
// or server-sent events.
448+
func (c *Client) SetStreamResponseBody(enable bool) *Client {
449+
c.streamResponseBody = enable
450+
return c
451+
}
452+
438453
// SetCookieJar sets the cookie jar for the client.
439454
func (c *Client) SetCookieJar(cookieJar *CookieJar) *Client {
440455
c.cookieJar = cookieJar

client/client_test.go

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1733,3 +1733,151 @@ func Benchmark_Client_Request_Parallel(b *testing.B) {
17331733
require.NoError(b, err)
17341734
})
17351735
}
1736+
1737+
func Test_Client_StreamResponseBody(t *testing.T) {
1738+
t.Parallel()
1739+
client := New()
1740+
require.False(t, client.StreamResponseBody())
1741+
client.SetStreamResponseBody(true)
1742+
require.True(t, client.StreamResponseBody())
1743+
client.SetStreamResponseBody(false)
1744+
require.False(t, client.StreamResponseBody())
1745+
}
1746+
1747+
func Test_Client_StreamResponseBody_ServerSentEvents(t *testing.T) {
1748+
t.Parallel()
1749+
1750+
app, addr := startTestServerWithPort(t, func(app *fiber.App) {
1751+
app.Get("/sse", func(c fiber.Ctx) error {
1752+
c.Set("Content-Type", "text/event-stream")
1753+
c.Set("Cache-Control", "no-cache")
1754+
c.Set("Connection", "keep-alive")
1755+
1756+
messages := []string{
1757+
"data: message 1\n\n",
1758+
"data: message 2\n\n",
1759+
"data: message 3\n\n",
1760+
}
1761+
1762+
for _, msg := range messages {
1763+
if _, err := c.WriteString(msg); err != nil {
1764+
return err
1765+
}
1766+
}
1767+
1768+
return nil
1769+
})
1770+
})
1771+
defer func() { require.NoError(t, app.Shutdown()) }()
1772+
1773+
client := New().SetStreamResponseBody(true)
1774+
resp, err := client.Get("http://" + addr + "/sse")
1775+
require.NoError(t, err)
1776+
defer resp.Close()
1777+
1778+
bodyStream := resp.BodyStream()
1779+
require.NotNil(t, bodyStream)
1780+
1781+
buffer := make([]byte, 1024)
1782+
n, err := bodyStream.Read(buffer)
1783+
require.NoError(t, err)
1784+
require.Positive(t, n)
1785+
1786+
content := string(buffer[:n])
1787+
require.Contains(t, content, "data: message 1")
1788+
}
1789+
1790+
func Test_Client_StreamResponseBody_LargeResponse(t *testing.T) {
1791+
t.Parallel()
1792+
1793+
largeData := make([]byte, 1024*1024)
1794+
for i := range largeData {
1795+
largeData[i] = byte(i % 256)
1796+
}
1797+
1798+
app, addr := startTestServerWithPort(t, func(app *fiber.App) {
1799+
app.Get("/large", func(c fiber.Ctx) error {
1800+
return c.Send(largeData)
1801+
})
1802+
})
1803+
defer func() { require.NoError(t, app.Shutdown()) }()
1804+
client := New().SetStreamResponseBody(true)
1805+
resp, err := client.Get("http://" + addr + "/large")
1806+
require.NoError(t, err)
1807+
defer resp.Close()
1808+
bodyStream := resp.BodyStream()
1809+
require.NotNil(t, bodyStream)
1810+
streamedData, err := io.ReadAll(bodyStream)
1811+
require.NoError(t, err)
1812+
require.Equal(t, largeData, streamedData)
1813+
client2 := New()
1814+
resp2, err := client2.Get("http://" + addr + "/large")
1815+
require.NoError(t, err)
1816+
defer resp2.Close()
1817+
body := resp2.Body()
1818+
require.Equal(t, largeData, body)
1819+
}
1820+
1821+
func Test_Client_StreamResponseBody_Disabled_Default(t *testing.T) {
1822+
t.Parallel()
1823+
1824+
app, addr := startTestServerWithPort(t, func(app *fiber.App) {
1825+
app.Get("/test", func(c fiber.Ctx) error {
1826+
return c.SendString("Hello, World!")
1827+
})
1828+
})
1829+
defer func() { require.NoError(t, app.Shutdown()) }()
1830+
1831+
client := New()
1832+
resp, err := client.Get("http://" + addr + "/test")
1833+
require.NoError(t, err)
1834+
defer resp.Close()
1835+
1836+
body := resp.Body()
1837+
require.Equal(t, "Hello, World!", string(body))
1838+
1839+
bodyStream := resp.BodyStream()
1840+
require.NotNil(t, bodyStream)
1841+
}
1842+
1843+
func Test_Client_StreamResponseBody_ChainableMethods(t *testing.T) {
1844+
t.Parallel()
1845+
1846+
client := New().
1847+
SetStreamResponseBody(true).
1848+
SetTimeout(time.Second * 5).
1849+
SetStreamResponseBody(false)
1850+
1851+
require.False(t, client.StreamResponseBody())
1852+
}
1853+
1854+
func Test_Request_StreamResponseBody(t *testing.T) {
1855+
t.Parallel()
1856+
1857+
app, addr := startTestServerWithPort(t, func(app *fiber.App) {
1858+
app.Get("/test", func(c fiber.Ctx) error {
1859+
return c.SendString("Hello, World!")
1860+
})
1861+
})
1862+
defer func() { require.NoError(t, app.Shutdown()) }()
1863+
1864+
client := New().SetStreamResponseBody(false) // client has streaming disabled
1865+
req := client.R().SetStreamResponseBody(true)
1866+
require.True(t, req.StreamResponseBody())
1867+
1868+
resp, err := req.Get("http://" + addr + "/test")
1869+
require.NoError(t, err)
1870+
defer resp.Close()
1871+
bodyStream := resp.BodyStream()
1872+
require.NotNil(t, bodyStream)
1873+
req2 := client.R().SetStreamResponseBody(false)
1874+
require.False(t, req2.StreamResponseBody())
1875+
clientWithStreaming := New().SetStreamResponseBody(true)
1876+
req3 := clientWithStreaming.R()
1877+
require.True(t, req3.StreamResponseBody()) // Should inherit from client
1878+
req4 := client.R().
1879+
SetStreamResponseBody(true).
1880+
SetTimeout(time.Second * 5).
1881+
SetStreamResponseBody(false)
1882+
require.False(t, req4.StreamResponseBody())
1883+
}

client/core.go

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,43 @@ func (c *core) execFunc() (*Response, error) {
8181
c.req.RawRequest.CopyTo(reqv)
8282
cfg := c.getRetryConfig()
8383

84+
// Determine which client to use - create a new one if StreamResponseBody differs
85+
var fastHTTPClient *fasthttp.Client
86+
requestStreamResponseBody := c.req.StreamResponseBody()
87+
c.client.mu.RLock()
88+
clientStream := c.client.streamResponseBody
89+
original := c.client.fasthttp
90+
91+
if requestStreamResponseBody != clientStream {
92+
// Request setting differs from client setting, create a temporary client
93+
94+
fastHTTPClient = &fasthttp.Client{
95+
Dial: original.Dial,
96+
DialDualStack: original.DialDualStack,
97+
TLSConfig: original.TLSConfig,
98+
MaxConnsPerHost: original.MaxConnsPerHost,
99+
MaxIdleConnDuration: original.MaxIdleConnDuration,
100+
MaxConnDuration: original.MaxConnDuration,
101+
ReadTimeout: original.ReadTimeout,
102+
WriteTimeout: original.WriteTimeout,
103+
ReadBufferSize: original.ReadBufferSize,
104+
WriteBufferSize: original.WriteBufferSize,
105+
MaxResponseBodySize: original.MaxResponseBodySize,
106+
NoDefaultUserAgentHeader: original.NoDefaultUserAgentHeader,
107+
DisableHeaderNamesNormalizing: original.DisableHeaderNamesNormalizing,
108+
DisablePathNormalizing: original.DisablePathNormalizing,
109+
MaxIdemponentCallAttempts: original.MaxIdemponentCallAttempts,
110+
Name: original.Name,
111+
ConfigureClient: original.ConfigureClient,
112+
113+
// Request-specific override
114+
StreamResponseBody: requestStreamResponseBody,
115+
}
116+
} else {
117+
fastHTTPClient = original
118+
}
119+
c.client.mu.RUnlock()
120+
84121
var err error
85122
go func() {
86123
respv := fasthttp.AcquireResponse()
@@ -93,15 +130,15 @@ func (c *core) execFunc() (*Response, error) {
93130
// Use an exponential backoff retry strategy.
94131
err = retry.NewExponentialBackoff(*cfg).Retry(func() error {
95132
if c.req.maxRedirects > 0 && (string(reqv.Header.Method()) == fiber.MethodGet || string(reqv.Header.Method()) == fiber.MethodHead) {
96-
return c.client.fasthttp.DoRedirects(reqv, respv, c.req.maxRedirects)
133+
return fastHTTPClient.DoRedirects(reqv, respv, c.req.maxRedirects)
97134
}
98-
return c.client.fasthttp.Do(reqv, respv)
135+
return fastHTTPClient.Do(reqv, respv)
99136
})
100137
} else {
101138
if c.req.maxRedirects > 0 && (string(reqv.Header.Method()) == fiber.MethodGet || string(reqv.Header.Method()) == fiber.MethodHead) {
102-
err = c.client.fasthttp.DoRedirects(reqv, respv, c.req.maxRedirects)
139+
err = fastHTTPClient.DoRedirects(reqv, respv, c.req.maxRedirects)
103140
} else {
104-
err = c.client.fasthttp.Do(reqv, respv)
141+
err = fastHTTPClient.Do(reqv, respv)
105142
}
106143
}
107144

client/request.go

Lines changed: 39 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -44,30 +44,25 @@ var ErrClientNil = errors.New("client cannot be nil")
4444

4545
// Request contains all data related to an HTTP request.
4646
type Request struct {
47-
ctx context.Context //nolint:containedctx // Context is needed to be stored in the request.
48-
49-
body any
50-
header *Header
51-
params *QueryParam
52-
cookies *Cookie
53-
path *PathParam
54-
55-
client *Client
56-
57-
formData *FormData
58-
59-
RawRequest *fasthttp.Request
60-
url string
61-
method string
62-
userAgent string
63-
boundary string
64-
referer string
65-
files []*File
66-
67-
timeout time.Duration
68-
maxRedirects int
69-
70-
bodyType bodyType
47+
files []*File
48+
ctx context.Context //nolint:containedctx // Context is needed to be stored in the request.
49+
body any
50+
url string
51+
method string
52+
userAgent string
53+
boundary string
54+
referer string
55+
header *Header
56+
params *QueryParam
57+
cookies *Cookie
58+
path *PathParam
59+
client *Client
60+
formData *FormData
61+
RawRequest *fasthttp.Request
62+
streamResponseBody *bool // nil means use client setting
63+
timeout time.Duration
64+
maxRedirects int
65+
bodyType bodyType
7166
}
7267

7368
// Method returns the HTTP method set in the Request.
@@ -590,6 +585,25 @@ func (r *Request) SetMaxRedirects(count int) *Request {
590585
return r
591586
}
592587

588+
// StreamResponseBody returns the StreamResponseBody setting for this request.
589+
// Returns the client's setting if not explicitly set at the request level.
590+
func (r *Request) StreamResponseBody() bool {
591+
if r.streamResponseBody != nil {
592+
return *r.streamResponseBody
593+
}
594+
if r.client != nil {
595+
return r.client.streamResponseBody
596+
}
597+
return false
598+
}
599+
600+
// SetStreamResponseBody sets the StreamResponseBody option for this specific request,
601+
// overriding the client-level setting.
602+
func (r *Request) SetStreamResponseBody(enable bool) *Request {
603+
r.streamResponseBody = &enable
604+
return r
605+
}
606+
593607
// checkClient ensures that a Client is set. If none is set, it defaults to the global defaultClient.
594608
func (r *Request) checkClient() {
595609
if r.client == nil {
@@ -656,6 +670,7 @@ func (r *Request) Reset() {
656670
r.maxRedirects = 0
657671
r.bodyType = noBody
658672
r.boundary = boundary
673+
r.streamResponseBody = nil
659674

660675
for len(r.files) != 0 {
661676
t := r.files[0]

client/response.go

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,22 @@ func (r *Response) Body() []byte {
8989
return r.RawResponse.Body()
9090
}
9191

92+
// BodyStream returns the response body as a stream reader.
93+
// Note: When using BodyStream(), the response body is not copied to memory,
94+
// so calling Body() afterwards may return an empty slice.
95+
func (r *Response) BodyStream() io.Reader {
96+
if stream := r.RawResponse.BodyStream(); stream != nil {
97+
return stream
98+
}
99+
// If streaming is not enabled, return a bytes.Reader from the regular body
100+
return bytes.NewReader(r.RawResponse.Body())
101+
}
102+
103+
// IsStreaming returns true if the response body is being streamed.
104+
func (r *Response) IsStreaming() bool {
105+
return r.RawResponse.BodyStream() != nil
106+
}
107+
92108
// String returns the response body as a trimmed string.
93109
func (r *Response) String() string {
94110
return utils.Trim(string(r.Body()), ' ')
@@ -143,14 +159,17 @@ func (r *Response) Save(v any) error {
143159
return nil
144160

145161
case io.Writer:
146-
if _, err := io.Copy(p, bytes.NewReader(r.Body())); err != nil {
147-
return fmt.Errorf("failed to write response body to io.Writer: %w", err)
162+
var err error
163+
if r.IsStreaming() {
164+
_, err = io.Copy(p, r.BodyStream())
165+
} else {
166+
_, err = io.Copy(p, bytes.NewReader(r.Body()))
148167
}
149-
defer func() {
150-
if pc, ok := p.(io.WriteCloser); ok {
151-
_ = pc.Close() //nolint:errcheck // not needed
152-
}
153-
}()
168+
169+
if err != nil {
170+
return fmt.Errorf("failed to write response body to writer: %w", err)
171+
}
172+
154173
return nil
155174

156175
default:

0 commit comments

Comments
 (0)