diff --git a/stream.go b/stream.go index 5970877..1ebefc4 100644 --- a/stream.go +++ b/stream.go @@ -7,7 +7,6 @@ import ( "io/ioutil" "log" "net/http" - "sync" "time" ) @@ -29,9 +28,7 @@ type Stream struct { // Logger is a logger that, when set, will be used for logging debug messages Logger *log.Logger // isClosed is a marker that the stream is/should be closed - isClosed bool - // isClosedMutex is a mutex protecting concurrent read/write access of isClosed - isClosedMutex sync.RWMutex + done chan int } type SubscriptionError struct { @@ -69,9 +66,9 @@ func SubscribeWith(lastEventId string, client *http.Client, request *http.Reques retry: time.Millisecond * 3000, Events: make(chan Event), Errors: make(chan error), + done: make(chan int), } stream.c.CheckRedirect = checkRedirect - r, err := stream.connect() if err != nil { return nil, err @@ -82,25 +79,12 @@ func SubscribeWith(lastEventId string, client *http.Client, request *http.Reques // Close will close the stream. It is safe for concurrent access and can be called multiple times. func (stream *Stream) Close() { - if stream.isStreamClosed() { + select { + case <-stream.done: return + default: + close(stream.done) } - - stream.markStreamClosed() - close(stream.Errors) - close(stream.Events) -} - -func (stream *Stream) isStreamClosed() bool { - stream.isClosedMutex.RLock() - defer stream.isClosedMutex.RUnlock() - return stream.isClosed -} - -func (stream *Stream) markStreamClosed() { - stream.isClosedMutex.Lock() - defer stream.isClosedMutex.Unlock() - stream.isClosed = true } // Go's http package doesn't copy headers across when it encounters @@ -133,64 +117,75 @@ func (stream *Stream) connect() (r io.ReadCloser, err error) { Code: resp.StatusCode, Message: string(message), } + resp.Body.Close() } r = resp.Body return } func (stream *Stream) stream(r io.ReadCloser) { - defer r.Close() - - // receives events until an error is encountered - stream.receiveEvents(r) - - // tries to reconnect and start the stream again - stream.retryRestartStream() -} - -func (stream *Stream) receiveEvents(r io.ReadCloser) { - dec := NewDecoder(r) + defer stream.shutdown() + defer func() { + if r != nil { + r.Close() + } + }() + var err error + backoff := stream.retry for { - ev, err := dec.Decode() - if stream.isStreamClosed() { + select { + case <-stream.done: return + default: + if r != nil { + // Stream is closed by Close method, if err equals to nil + if err := stream.receiveEvents(r); err == nil { + return + } + } + if stream.Logger != nil { + stream.Logger.Printf("Reconnecting in %0.4f secs\n", backoff.Seconds()) + } + r, err = stream.connect() + if err != nil { + stream.Errors <- err + time.Sleep(backoff) + backoff *= 2 + continue + } + // reset backoff after successfully setting up the connection + backoff = stream.retry } - if err != nil { - stream.Errors <- err - return - } - - pub := ev.(*publication) - if pub.Retry() > 0 { - stream.retry = time.Duration(pub.Retry()) * time.Millisecond - } - if len(pub.Id()) > 0 { - stream.lastEventId = pub.Id() - } - stream.Events <- ev } } -func (stream *Stream) retryRestartStream() { - backoff := stream.retry +func (stream *Stream) shutdown() { + close(stream.Errors) + close(stream.Events) +} + +func (stream *Stream) receiveEvents(r io.ReadCloser) error { + dec := NewDecoder(r) + for { - if stream.Logger != nil { - stream.Logger.Printf("Reconnecting in %0.4f secs\n", backoff.Seconds()) - } - time.Sleep(backoff) - if stream.isStreamClosed() { - return - } - // NOTE: because of the defer we're opening the new connection - // before closing the old one. Shouldn't be a problem in practice, - // but something to be aware of. - r, err := stream.connect() - if err == nil { - go stream.stream(r) - return + select { + case <-stream.done: + return nil + default: + ev, err := dec.Decode() + if err != nil { + stream.Errors <- err + return err + } + pub := ev.(*publication) + if pub.Retry() > 0 { + stream.retry = time.Duration(pub.Retry()) * time.Millisecond + } + if len(pub.Id()) > 0 { + stream.lastEventId = pub.Id() + } + stream.Events <- ev } - stream.Errors <- err - backoff *= 2 } }