diff --git a/usb/transfer_stream.go b/usb/transfer_stream.go index 4d5d882..8c79c5e 100644 --- a/usb/transfer_stream.go +++ b/usb/transfer_stream.go @@ -31,20 +31,25 @@ type stream struct { current transferIntf // total/used are the number of all/used bytes in the current transfer. total, used int - // err is the first error encountered, returned to the user as soon - // as all remaining data was read. - err error + // delayedErr is the delayed error, returned to the user after all + // remaining data was read. + delayedErr error } -func (s *stream) cleanup() { - close(s.transfers) - for t := range s.transfers { - t.cancel() - t.wait() - t.free() +func (s *stream) setDelayedErr(err error) { + if s.delayedErr == nil { + s.delayedErr = err + close(s.transfers) } } +// ReadStream is a buffer that tries to prefetch data from the IN endpoint, +// reducing the latency between subsequent Read()s. +// ReadStream keeps prefetching data until Close() is called or until +// an error is encountered. After Close(), the buffer might still have +// data left from transfers that were initiated before Close. Read()ing +// from the ReadStream will keep returning available data. When no more +// data is left, io.EOF is returned. type ReadStream struct { s *stream } @@ -56,20 +61,27 @@ type ReadStream struct { // return io.ErrClosedPipe. func (r ReadStream) Read(p []byte) (int, error) { s := r.s + if s.transfers == nil { + return 0, io.ErrClosedPipe + } if s.current == nil { t, ok := <-s.transfers if !ok { // no more transfers in flight - retErr := io.ErrClosedPipe - if s.err != nil { - retErr = s.err - s.err = nil - } - return 0, retErr + s.transfers = nil + return 0, s.delayedErr } n, err := t.wait() if err != nil { - s.err = err + // wait error aborts immediately, all remaining data is invalid. + t.free() + for t := range s.transfers { + t.cancel() + t.wait() + t.free() + } + s.transfers = nil + return n, err } s.current = t s.total = n @@ -82,26 +94,20 @@ func (r ReadStream) Read(p []byte) (int, error) { copy(p, s.current.data()[s.used:s.used+use]) s.used += use if s.used == s.total { - if s.err == nil { + if s.delayedErr == nil { if err := s.current.submit(); err == nil { // guaranteed to not block, len(transfers) == number of allocated transfers s.transfers <- s.current } else { - s.err = err + s.setDelayedErr(err) } } - if s.err != nil { + if s.delayedErr != nil { s.current.free() } s.current = nil } - var retErr error - if s.current == nil && s.err != nil { - s.cleanup() - retErr = s.err - s.err = nil - } - return use, retErr + return use, nil } // Close signals that the transfer should stop. After Close is called, @@ -109,10 +115,40 @@ func (r ReadStream) Read(p []byte) (int, error) { // in progress before returning an io.EOF error, unless another error // was encountered earlier. func (r ReadStream) Close() { - s := r.s - if s.err != nil { - s.err = io.EOF + r.s.setDelayedErr(io.EOF) +} + +// WriteStream is a buffer that will send data asynchronously, reducing +// the latency between subsequent Write()s. +type WriteStream struct { + s *stream +} + +// Write sends the data to the endpoint. Write returning a nil error doesn't +// mean that data was written to the device, only that it was written to the +// buffer. Only a call to Flush() that returns nil error guarantees that +// all transfers have succeeded. +func (w WriteStream) Write(p []byte) (int, error) { + s := w.s + written := 0 + all := len(p) + for written < all { + if s.current == nil { + s.current = <-s.transfers + s.total = len(s.current.data()) + s.used = 0 + } + use := all - written + if use > s.total { + use = s.total + } + copy(s.current.data()[s.used:], p[written:written+use]) } + return 0, nil +} + +func (w WriteStream) Flush() error { + return nil } func newStream(tt []transferIntf, submit bool) *stream { @@ -120,15 +156,14 @@ func newStream(tt []transferIntf, submit bool) *stream { transfers: make(chan transferIntf, len(tt)), } for _, t := range tt { - s.transfers <- t - } - if submit { - for _, t := range tt { + if submit { if err := t.submit(); err != nil { - s.err = err + t.free() + s.setDelayedErr(err) break } } + s.transfers <- t } return s } diff --git a/usb/transfer_stream_test.go b/usb/transfer_stream_test.go index 570f055..352973e 100644 --- a/usb/transfer_stream_test.go +++ b/usb/transfer_stream_test.go @@ -15,8 +15,12 @@ package usb import ( + "bytes" "errors" + "fmt" "io" + "reflect" + "strconv" "testing" ) @@ -42,7 +46,7 @@ func (f *fakeStreamTransfer) submit() error { return errors.New("submit() called twice") } if len(f.res) == 0 { - return io.ErrUnexpectedEOF + return errors.New("submit() called but fake result missing") } f.inFlight = true res := f.res[0] @@ -61,11 +65,13 @@ func (f *fakeStreamTransfer) wait() (int, error) { return 0, errors.New("wait() called without submit()") } if len(f.res) == 0 { - return 0, io.ErrUnexpectedEOF + return 0, errors.New("wait() called but fake result missing") } f.inFlight = false res := f.res[0] - if res.waitErr != nil { + if res.waitErr == nil { + f.res = f.res[1:] + } else { f.res = nil } return res.n, res.waitErr @@ -84,49 +90,148 @@ func (f *fakeStreamTransfer) data() []byte { return fakeTransferBuf } var sentinelError = errors.New("sentinel error") +type readRes struct { + n int + err error +} + +func (r readRes) String() string { + var buf bytes.Buffer + fmt.Fprintf(&buf, "<%d bytes", r.n) + if r.err != nil { + fmt.Fprintf(&buf, ", error: %s", r.err.Error()) + } + buf.WriteString(">") + return buf.String() +} + func TestReadStream(t *testing.T) { - transfers := []*fakeStreamTransfer{ - {res: []fakeStreamResult{ - {n: 500}, - }}, - {res: []fakeStreamResult{ - {n: 500}, - }}, - {res: []fakeStreamResult{ - {n: 123, waitErr: sentinelError}, - }}, - {res: []fakeStreamResult{ - {n: 500}, - }}, - } - intfs := make([]transferIntf, len(transfers)) - for i := range transfers { - intfs[i] = transfers[i] - } - s := ReadStream{newStream(intfs, true)} - buf := make([]byte, 400) - for _, rs := range []struct { - want int - err error + for tcNum, tc := range []struct { + desc string + closeBefore int + // transfers is a list of allocated transfers, each transfers + // carries a list of results for subsequent submits/waits. + transfers [][]fakeStreamResult + want []readRes }{ - {400, nil}, - {100, nil}, - {400, nil}, - {100, nil}, - {123, sentinelError}, - {0, io.ErrClosedPipe}, + { + desc: "two transfers submitted, close, read returns both and EOF", + closeBefore: 1, + transfers: [][]fakeStreamResult{ + {{n: 400}}, + {{n: 400}}, + }, + want: []readRes{ + {n: 400}, + {n: 400}, + {err: io.EOF}, + {err: io.ErrClosedPipe}, + }, + }, + { + desc: "two transfers, two and a half cycles through transfer queue", + closeBefore: 4, + transfers: [][]fakeStreamResult{ + {{n: 400}, {n: 400}, {n: 400}, {waitErr: errors.New("fake wait error")}}, + {{n: 400}, {n: 400}, {waitErr: errors.New("fake wait error")}}, + }, + want: []readRes{ + {n: 400}, + {n: 400}, + {n: 400}, + {n: 400}, + {n: 400}, + {err: io.EOF}, + {err: io.ErrClosedPipe}, + }, + }, + { + desc: "4 transfers submitted, two return, third fails on wait", + transfers: [][]fakeStreamResult{ + {{n: 500}}, + {{n: 500}}, + {{n: 123, waitErr: sentinelError}}, + {{n: 500}}, + }, + want: []readRes{ + {n: 400}, + {n: 100}, + {n: 400}, + {n: 100}, + {n: 123, err: sentinelError}, + {err: io.ErrClosedPipe}, + }, + }, + { + desc: "2 transfers, second submit fails initialization but error overshadowed by wait error", + transfers: [][]fakeStreamResult{ + {{n: 123, waitErr: sentinelError}}, + {{submitErr: errors.New("fake submit error")}}, + }, + want: []readRes{ + {n: 123, err: sentinelError}, + {err: io.ErrClosedPipe}, + }, + }, + { + desc: "2 transfers, second submit fails during initialization", + transfers: [][]fakeStreamResult{ + {{n: 400}}, + {{submitErr: sentinelError}}, + }, + want: []readRes{ + {n: 400}, + {err: sentinelError}, + {err: io.ErrClosedPipe}, + }, + }, + { + desc: "2 transfers, 3rd submit fails during second round", + transfers: [][]fakeStreamResult{ + {{n: 400}, {submitErr: sentinelError}}, + {{n: 400}}, + }, + want: []readRes{ + {n: 400}, + {n: 400}, + {err: sentinelError}, + {err: io.ErrClosedPipe}, + }, + }, } { - n, err := s.Read(buf) - if n != rs.want { - t.Errorf("Read(): got %d bytes, want %d", n, rs.want) - } - if err != rs.err { - t.Errorf("Read(): got error %v, want %v", err, rs.err) - } - } - for i := range transfers { - if !transfers[i].released { - t.Errorf("Transfer #%d was not freed after stream completed", i) - } + t.Run(strconv.Itoa(tcNum), func(t *testing.T) { + t.Logf("Case %d: %s", tcNum, tc.desc) + ftt := make([]*fakeStreamTransfer, len(tc.transfers)) + tt := make([]transferIntf, len(tc.transfers)) + for i := range tc.transfers { + ftt[i] = &fakeStreamTransfer{ + res: tc.transfers[i], + } + tt[i] = ftt[i] + } + s := ReadStream{newStream(tt, true)} + buf := make([]byte, 400) + got := make([]readRes, len(tc.want)) + for i := range tc.want { + if i == tc.closeBefore-1 { + t.Logf("Close()", tcNum) + s.Close() + } + n, err := s.Read(buf) + t.Logf("Read(): got %d, %v", tcNum, n, err) + got[i] = readRes{ + n: n, + err: err, + } + } + if !reflect.DeepEqual(got, tc.want) { + t.Errorf("Got Read() results:\n%v\nwant Read() results:\n%v\n", got, tc.want) + } + for i := range ftt { + if !ftt[i].released { + t.Errorf("Transfer #%d was not freed after stream completed", i) + } + } + }) } }