aboutsummaryrefslogtreecommitdiff
path: root/fasthttputil
diff options
context:
space:
mode:
authorGravatar Kazushi Kitaya <kazushikitaya@gmail.com> 2019-09-05 00:57:51 +0900
committerGravatar Erik Dubbelboer <erik@dubbelboer.com> 2019-09-04 17:57:51 +0200
commit8713335f548d2c18effe36c9686f5a88b65eefce (patch)
tree8c2615c54814c6b4ad4026507ea0c752869253dc /fasthttputil
parentPropagating custom SkipBody value to allow explicitly skip body reading for r... (diff)
downloadfasthttp-8713335f548d2c18effe36c9686f5a88b65eefce.tar.gz
fasthttp-8713335f548d2c18effe36c9686f5a88b65eefce.tar.bz2
fasthttp-8713335f548d2c18effe36c9686f5a88b65eefce.zip
Fix data race in fasthttputil.pipeConn (#645)
* add tests for fasthttputil.InmemoryListener * fix data race in pipeConn * update use of readDeadlineChLock
Diffstat (limited to 'fasthttputil')
-rw-r--r--fasthttputil/inmemory_listener_test.go92
-rw-r--r--fasthttputil/pipeconns.go12
2 files changed, 102 insertions, 2 deletions
diff --git a/fasthttputil/inmemory_listener_test.go b/fasthttputil/inmemory_listener_test.go
index 86aab68..19cec0c 100644
--- a/fasthttputil/inmemory_listener_test.go
+++ b/fasthttputil/inmemory_listener_test.go
@@ -2,7 +2,13 @@ package fasthttputil
import (
"bytes"
+ "context"
"fmt"
+ "io"
+ "io/ioutil"
+ "net"
+ "net/http"
+ "sync"
"testing"
"time"
)
@@ -90,3 +96,89 @@ func TestInmemoryListener(t *testing.T) {
t.Fatalf("timeout")
}
}
+
+// echoServerHandler implements http.Handler.
+type echoServerHandler struct {
+ t *testing.T
+}
+
+func (s *echoServerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(200)
+ time.Sleep(time.Millisecond * 100)
+ if _, err := io.Copy(w, r.Body); err != nil {
+ s.t.Fatalf("unexpected error: %s", err)
+ }
+}
+
+func testInmemoryListenerHTTP(t *testing.T, f func(t *testing.T, client *http.Client)) {
+ ln := NewInmemoryListener()
+ defer ln.Close()
+
+ client := &http.Client{
+ Transport: &http.Transport{
+ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
+ return ln.Dial()
+ },
+ },
+ Timeout: time.Second,
+ }
+
+ server := &http.Server{
+ Handler: &echoServerHandler{t},
+ }
+
+ go func() {
+ if err := server.Serve(ln); err != nil && err != http.ErrServerClosed {
+ t.Fatalf("unexpected error: %s", err)
+ }
+ }()
+
+ f(t, client)
+
+ ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
+ defer cancel()
+ server.Shutdown(ctx)
+}
+
+func testInmemoryListenerHTTPSingle(t *testing.T, client *http.Client, content string) {
+ res, err := client.Post("http://...", "text/plain", bytes.NewBufferString(content))
+ if err != nil {
+ t.Fatalf("unexpected error: %s", err)
+ }
+ b, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("unexpected error: %s", err)
+ }
+ s := string(b)
+ if string(b) != content {
+ t.Fatalf("unexpected response %s, expecting %s", s, content)
+ }
+}
+
+func TestInmemoryListenerHTTPSingle(t *testing.T) {
+ testInmemoryListenerHTTP(t, func(t *testing.T, client *http.Client) {
+ testInmemoryListenerHTTPSingle(t, client, "request")
+ })
+}
+
+func TestInmemoryListenerHTTPSerial(t *testing.T) {
+ testInmemoryListenerHTTP(t, func(t *testing.T, client *http.Client) {
+ for i := 0; i < 10; i++ {
+ testInmemoryListenerHTTPSingle(t, client, fmt.Sprintf("request_%d", i))
+ }
+ })
+}
+
+func TestInmemoryListenerHTTPConcurrent(t *testing.T) {
+ testInmemoryListenerHTTP(t, func(t *testing.T, client *http.Client) {
+ var wg sync.WaitGroup
+ for i := 0; i < 10; i++ {
+ wg.Add(1)
+ go func(i int) {
+ defer wg.Done()
+ testInmemoryListenerHTTPSingle(t, client, fmt.Sprintf("request_%d", i))
+ }(i)
+ }
+ wg.Wait()
+ })
+}
diff --git a/fasthttputil/pipeconns.go b/fasthttputil/pipeconns.go
index aa92b6f..c6ca39d 100644
--- a/fasthttputil/pipeconns.go
+++ b/fasthttputil/pipeconns.go
@@ -87,6 +87,8 @@ type pipeConn struct {
readDeadlineCh <-chan time.Time
writeDeadlineCh <-chan time.Time
+
+ readDeadlineChLock sync.Mutex
}
func (c *pipeConn) Write(p []byte) (int, error) {
@@ -158,9 +160,12 @@ func (c *pipeConn) readNextByteBuffer(mayBlock bool) error {
if !mayBlock {
return errWouldBlock
}
+ c.readDeadlineChLock.Lock()
+ readDeadlineCh := c.readDeadlineCh
+ c.readDeadlineChLock.Unlock()
select {
case c.b = <-c.rCh:
- case <-c.readDeadlineCh:
+ case <-readDeadlineCh:
c.readDeadlineCh = closedDeadlineCh
// rCh may contain data when deadline is reached.
// Read the data before returning ErrTimeout.
@@ -214,7 +219,10 @@ func (c *pipeConn) SetReadDeadline(deadline time.Time) error {
if c.readDeadlineTimer == nil {
c.readDeadlineTimer = time.NewTimer(time.Hour)
}
- c.readDeadlineCh = updateTimer(c.readDeadlineTimer, deadline)
+ readDeadlineCh := updateTimer(c.readDeadlineTimer, deadline)
+ c.readDeadlineChLock.Lock()
+ c.readDeadlineCh = readDeadlineCh
+ c.readDeadlineChLock.Unlock()
return nil
}