diff --git a/endpoint_stream.go b/endpoint_stream.go index b9a2717..138aebc 100644 --- a/endpoint_stream.go +++ b/endpoint_stream.go @@ -14,7 +14,7 @@ package gousb -func (e *endpoint) newStream(size, count int, submit bool) (*stream, error) { +func (e *endpoint) newStream(size, count int) (*stream, error) { var ts []transferIntf for i := 0; i < count; i++ { t, err := newUSBTransfer(e.ctx, e.h, &e.Desc, size, e.Timeout) @@ -26,7 +26,7 @@ func (e *endpoint) newStream(size, count int, submit bool) (*stream, error) { } ts = append(ts, t) } - return newStream(ts, submit), nil + return newStream(ts), nil } // NewStream prepares a new read stream that will keep reading data from the @@ -36,9 +36,23 @@ func (e *endpoint) newStream(size, count int, submit bool) (*stream, error) { // By keeping multiple transfers active at the same time, a Stream reduces // the latency between subsequent transfers and increases reading throughput. func (e *InEndpoint) NewStream(size, count int) (*ReadStream, error) { - s, err := e.newStream(size, count, true) + s, err := e.newStream(size, count) if err != nil { return nil, err } - return &ReadStream{s}, nil + s.submitAll() + return &ReadStream{s: s}, nil +} + +// NewStream prepares a new write stream that will write data in the background. +// Size defines a buffer size for a single write transaction and count +// defines how many transactions may be active at any time. +// By buffering the writes, a Stream reduces the latency between subsequent +// transfers and increases writing throughput. +func (e *OutEndpoint) NewStream(size, count int) (*WriteStream, error) { + s, err := e.newStream(size, count) + if err != nil { + return nil, err + } + return &WriteStream{s: s}, nil } diff --git a/endpoint_stream_test.go b/endpoint_stream_test.go index 51e14b5..c7418be 100644 --- a/endpoint_stream_test.go +++ b/endpoint_stream_test.go @@ -27,10 +27,11 @@ func TestEndpointReadStream(t *testing.T) { }() goodTransfers := 7 + done := make(chan struct{}) go func() { var num int for { - xfr := lib.waitForSubmitted() + xfr := lib.waitForSubmitted(done) if xfr == nil { return } @@ -83,4 +84,80 @@ func TestEndpointReadStream(t *testing.T) { if got != want { t.Errorf("stream.Read(): read %d bytes, want %d", got, want) } + close(done) +} + +func TestEndpointWriteStream(t *testing.T) { + t.Parallel() + lib := newFakeLibusb() + ctx := newContextWithImpl(lib) + defer func() { + if err := ctx.Close(); err != nil { + t.Errorf("Context.Close: %v", err) + } + }() + + done := make(chan struct{}) + total := 0 + num := 0 + go func() { + for { + xfr := lib.waitForSubmitted(done) + if xfr == nil { + return + } + xfr.length = len(xfr.buf) + xfr.status = TransferCompleted + xfr.done <- struct{}{} + num++ + total += xfr.length + } + }() + + dev, err := ctx.OpenDeviceWithVIDPID(0x9999, 0x0001) + if err != nil { + t.Fatalf("OpenDeviceWithVIDPID(9999, 0001): %v", err) + } + defer dev.Close() + cfg, err := dev.Config(1) + if err != nil { + t.Fatalf("%s.Config(1): %v", dev, err) + } + defer cfg.Close() + intf, err := cfg.Interface(0, 0) + if err != nil { + t.Fatalf("%s.Interface(0, 0): %v", cfg, err) + } + defer intf.Close() + ep, err := intf.OutEndpoint(1) + if err != nil { + t.Fatalf("%s.Endpoint(1): %v", intf, err) + } + pktSize := 1024 + stream, err := ep.NewStream(pktSize, 5) + if err != nil { + t.Fatalf("%s.NewStream(%d, 5): %v", ep, pktSize, err) + } + defer stream.Close() + for i := 0; i < 5; i++ { + if n, err := stream.Write(make([]byte, pktSize*2)); err != nil { + t.Fatalf("stream.Write: got error %v", err) + } else if n != pktSize*2 { + t.Fatalf("stream.Write: %d, want %d", n, pktSize*2) + } + } + want := pktSize * 10 + if err := stream.Close(); err != nil { + t.Fatalf("stream.Close: got error %v", err) + } + if got := stream.Written(); got != want { + t.Errorf("stream.Written: got %d, want %d", got, want) + } + done <- struct{}{} + if num != 10 { + t.Errorf("received transfers: got %d, want %d", num, 10) + } + if total != want { + t.Errorf("received data: got %d, want %d", total, want) + } } diff --git a/endpoint_test.go b/endpoint_test.go index 7c1def0..4841121 100644 --- a/endpoint_test.go +++ b/endpoint_test.go @@ -100,7 +100,7 @@ func TestEndpoint(t *testing.T) { ep := &endpoint{h: nil, ctx: ctx, InterfaceSetting: epData.intf, Desc: epData.ei} if tc.wantSubmit { go func() { - fakeT := lib.waitForSubmitted() + fakeT := lib.waitForSubmitted(nil) fakeT.length = tc.ret fakeT.status = tc.status close(fakeT.done) @@ -208,7 +208,7 @@ func TestEndpointInOut(t *testing.T) { } dataTransferred := 100 go func() { - fakeT := lib.waitForSubmitted() + fakeT := lib.waitForSubmitted(nil) fakeT.length = dataTransferred fakeT.status = TransferCompleted close(fakeT.done) @@ -232,7 +232,7 @@ func TestEndpointInOut(t *testing.T) { t.Fatalf("%s.OutEndpoint(1): got error %v, want nil", intf, err) } go func() { - fakeT := lib.waitForSubmitted() + fakeT := lib.waitForSubmitted(nil) fakeT.length = dataTransferred fakeT.status = TransferCompleted close(fakeT.done) diff --git a/fakelibusb_test.go b/fakelibusb_test.go index 338d3f2..5774d90 100644 --- a/fakelibusb_test.go +++ b/fakelibusb_test.go @@ -192,8 +192,16 @@ func (f *fakeLibusb) setIsoPacketLengths(*libusbTransfer, uint32) {} // waitForSubmitted can be used by tests to define custom behavior of the transfers submitted on the USB bus. // TODO(sebek): add fields in fakeTransfer to differentiate between different devices/endpoints used concurrently. -func (f *fakeLibusb) waitForSubmitted() *fakeTransfer { - return <-f.submitted +func (f *fakeLibusb) waitForSubmitted(done <-chan struct{}) *fakeTransfer { + select { + case t, ok := <-f.submitted: + if !ok { + return nil + } + return t + case <-done: + return nil + } } // empty can be used to confirm that all transfers were cleaned up. diff --git a/transfer_stream.go b/transfer_stream.go index 6e294fb..b28f1d2 100644 --- a/transfer_stream.go +++ b/transfer_stream.go @@ -27,18 +27,54 @@ type transferIntf interface { type stream struct { // a fifo of USB transfers. transfers chan transferIntf - // current holds the last transfer to return. - current transferIntf - // total/used are the number of all/used bytes in the current transfer. - total, used int - // delayedErr is the delayed error, returned to the user after all - // remaining data was read. - delayedErr error + // err is the first encountered error, returned to the user. + err error + // finished is true if transfers has been already closed. + finished bool } -func (s *stream) setDelayedErr(err error) { - if s.delayedErr == nil { - s.delayedErr = err +func (s *stream) gotError(err error) { + if s.err == nil { + s.err = err + } +} + +func (s *stream) noMore() { + if !s.finished { + close(s.transfers) + s.finished = true + } +} + +func (s *stream) submitAll() { + count := len(s.transfers) + var all []transferIntf + for i := 0; i < count; i++ { + all = append(all, <-s.transfers) + } + for _, t := range all { + if err := t.submit(); err != nil { + t.free() + s.gotError(err) + s.noMore() + return + } + s.transfers <- t + } + return +} + +func (s *stream) flushRemaining() { + s.noMore() + for t := range s.transfers { + t.cancel() + t.wait() + t.free() + } +} + +func (s *stream) done() { + if s.err == nil { close(s.transfers) } } @@ -52,6 +88,10 @@ func (s *stream) setDelayedErr(err error) { // data is left, io.EOF is returned. type ReadStream struct { s *stream + // current holds the last transfer to return. + current transferIntf + // total/used are the number of all/used bytes in the current transfer. + total, used int } // Read reads data from the transfer stream. @@ -60,56 +100,49 @@ type ReadStream struct { // After a non-nil error is returned, all subsequent attempts to read will // return io.ErrClosedPipe. // Read cannot be called concurrently with other Read or Close. -func (r ReadStream) Read(p []byte) (int, error) { - s := r.s - if s.transfers == nil { +func (r *ReadStream) Read(p []byte) (int, error) { + if r.s.transfers == nil { return 0, io.ErrClosedPipe } - if s.current == nil { - t, ok := <-s.transfers + if r.current == nil { + t, ok := <-r.s.transfers if !ok { // no more transfers in flight - s.transfers = nil - return 0, s.delayedErr + r.s.transfers = nil + return 0, r.s.err } n, err := t.wait() if err != nil { // wait error aborts immediately, all remaining data is invalid. t.free() - if s.delayedErr == nil { - close(s.transfers) - } - for t := range s.transfers { - t.cancel() - t.wait() - t.free() - } - s.transfers = nil + r.s.flushRemaining() + r.s.transfers = nil return n, err } - s.current = t - s.total = n - s.used = 0 + r.current = t + r.total = n + r.used = 0 } - use := s.total - s.used + use := r.total - r.used if use > len(p) { use = len(p) } - copy(p, s.current.data()[s.used:s.used+use]) - s.used += use - if s.used == s.total { - if s.delayedErr == nil { - if err := s.current.submit(); err == nil { + copy(p, r.current.data()[r.used:r.used+use]) + r.used += use + if r.used == r.total { + if r.s.err == nil { + if err := r.current.submit(); err == nil { // guaranteed to not block, len(transfers) == number of allocated transfers - s.transfers <- s.current + r.s.transfers <- r.current } else { - s.setDelayedErr(err) + r.s.gotError(err) + r.s.noMore() } } - if s.delayedErr != nil { - s.current.free() + if r.s.err != nil { + r.current.free() } - s.current = nil + r.current = nil } return use, nil } @@ -119,64 +152,112 @@ func (r ReadStream) Read(p []byte) (int, error) { // in progress before returning an io.EOF error, unless another error // was encountered earlier. // Close cannot be called concurrently with Read. -func (r ReadStream) Close() error { +func (r *ReadStream) Close() error { if r.s.transfers == nil { return nil } - r.s.setDelayedErr(io.EOF) + r.s.gotError(io.EOF) + r.s.noMore() return nil } // WriteStream is a buffer that will send data asynchronously, reducing // the latency between subsequent Write()s. -/* type WriteStream struct { - s *stream + s *stream + total int } -*/ // Write sends the data to the endpoint. Write returning a nil error doesn't // mean that data was written to the device, only that it was written to the -// buffer. Only a call to Flush() that returns nil error guarantees that +// buffer. Only a call to Close() that returns nil error guarantees that // all transfers have succeeded. -// TODO(sebek): not implemented and tested yet -/* -func (w WriteStream) Write(p []byte) (int, error) { - s := w.s +// If the slice passed to Write does not align exactly with the transfer +// buffer size (as declared in a call to NewStream), the last USB transfer +// of this Write will be sent with less data than the full buffer. +// After a non-nil error is returned, all subsequent attempts to write will +// return io.ErrClosedPipe. +// If Write encounters an error when preparing the transfer, the stream +// will still try to complete any pending transfers. The total number +// of bytes successfully written can be retrieved through a Written() +// call after Close() has returned. +// Write cannot be called concurrently with another Write, Written or Close. +func (w *WriteStream) Write(p []byte) (int, error) { + if w.s.transfers == nil || w.s.err != nil { + return 0, io.ErrClosedPipe + } written := 0 all := len(p) for written < all { - if s.current == nil { - s.current = <-s.transfers - s.total = len(s.current.data()) - s.used = 0 + t := <-w.s.transfers + n, err := t.wait() // unsubmitted transfers will return 0 bytes and no error + w.total += n + if err != nil { + t.free() + w.s.gotError(err) + // This branch is used only after all the transfers were set in flight. + // That means all transfers left in the queue are in flight. + // They must be ignored, since this wait() failed. + w.s.flushRemaining() + return written, err } use := all - written - if use > s.total { - use = s.total + if max := len(t.data()); use > max { + use = max } - copy(s.current.data()[s.used:], p[written:written+use]) + copy(t.data(), p[written:written+use]) + if err := t.submit(); err != nil { + t.free() + w.s.gotError(err) + // Even though this submit failed, all the transfers in flight are still valid. + // Don't flush remaining transfers. + // We won't submit any more transfers. + w.s.noMore() + return written, err + } + written += use + w.s.transfers <- t // guaranteed non blocking } - return 0, nil + return written, nil } -func (w WriteStream) Flush() error { - return nil +// Close signals end of data to write. Close blocks until all transfers +// that were sent are finished. The error returned by Close is the first +// error encountered during writing the entire stream (if any). +// Close returning nil indicates all transfers completed successfuly. +// After Close, the total number of bytes successfuly written can be +// retrieved using Written(). +// Close may not be called concurrently with Write, Close or Written. +func (w *WriteStream) Close() error { + if w.s.transfers == nil { + return io.ErrClosedPipe + } + w.s.noMore() + for t := range w.s.transfers { + n, err := t.wait() + w.total += n + t.free() + if err != nil { + w.s.gotError(err) + w.s.flushRemaining() + } + t.free() + } + w.s.transfers = nil + return w.s.err } -*/ -func newStream(tt []transferIntf, submit bool) *stream { +// Written returns the number of bytes successfuly written by the stream. +// Written may be called only after Close() has been called and returned. +func (w *WriteStream) Written() int { + return w.total +} + +func newStream(tt []transferIntf) *stream { s := &stream{ transfers: make(chan transferIntf, len(tt)), } for _, t := range tt { - if submit { - if err := t.submit(); err != nil { - t.free() - s.setDelayedErr(err) - break - } - } s.transfers <- t } return s diff --git a/transfer_stream_test.go b/transfer_stream_test.go index b159ca8..38614f7 100644 --- a/transfer_stream_test.go +++ b/transfer_stream_test.go @@ -62,7 +62,7 @@ func (f *fakeStreamTransfer) wait() (int, error) { return 0, errors.New("wait() called on a free()d transfer") } if !f.inFlight { - return 0, errors.New("wait() called without submit()") + return 0, nil } if len(f.res) == 0 { return 0, errors.New("wait() called but fake result missing") @@ -224,7 +224,8 @@ func TestTransferReadStream(t *testing.T) { } tt[i] = ftt[i] } - s := ReadStream{newStream(tt, true)} + s := ReadStream{s: newStream(tt)} + s.s.submitAll() buf := make([]byte, 400) got := make([]readRes, len(tc.want)) for i := range tc.want { @@ -250,3 +251,106 @@ func TestTransferReadStream(t *testing.T) { }) } } + +func TestTransferWriteStream(t *testing.T) { + t.Parallel() + for _, tc := range []struct { + desc string + transfers [][]fakeStreamResult + writes []int + want []int + total int + err error + }{ + { + desc: "successful two transfers", + transfers: [][]fakeStreamResult{ + {{n: 1500}}, + {{n: 1500}}, + {{n: 1500}}, + }, + writes: []int{3000}, + want: []int{3000}, + total: 3000, + }, + { + desc: "submit failed on second transfer", + transfers: [][]fakeStreamResult{ + {{n: 1500}}, + {{submitErr: errSentinel}}, + {{n: 1500}}, + }, + writes: []int{3000}, + want: []int{1500}, + total: 1500, + err: errSentinel, + }, + { + desc: "wait failed on second transfer", + transfers: [][]fakeStreamResult{ + {{n: 1500}}, + {{waitErr: errSentinel}}, + {{n: 1500}}, + }, + writes: []int{3000, 1500}, + want: []int{3000, 1500}, + total: 1500, + err: errSentinel, + }, + { + desc: "reused transfer", + transfers: [][]fakeStreamResult{ + {{n: 1500}, {n: 1500}}, + {{n: 1500}, {n: 1500}}, + {{n: 1500}, {n: 500}}, + }, + writes: []int{3000, 3000, 2000}, + want: []int{3000, 3000, 2000}, + total: 8000, + }, + { + desc: "wait failed on reused transfer", + transfers: [][]fakeStreamResult{ + {{n: 1500}, {n: 1500}}, + {{waitErr: errSentinel}, {n: 1500}}, + {{n: 1500}, {n: 1500}}, + }, + writes: []int{1500, 1500, 1500, 1500, 1500}, + want: []int{1500, 1500, 1500, 1500, 0}, + total: 1500, + err: errSentinel, + }, + } { + tc := tc + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + ftt := make([]*fakeStreamTransfer, len(tc.transfers)) + tt := make([]transferIntf, len(tc.transfers)) + for i := range tc.transfers { + ftt[i] = &fakeStreamTransfer{ + res: tc.transfers[i], + } + tt[i] = ftt[i] + } + s := WriteStream{s: newStream(tt)} + for i, w := range tc.writes { + got, err := s.Write(make([]byte, w)) + if want := tc.want[i]; got != want { + t.Errorf("WriteStream.Write #%d: got %d, want %d", i, got, want) + } + if err != nil && err != tc.err { + t.Errorf("WriteStream.Write: got error %v, want %v", err, tc.err) + } + } + if err := s.Close(); err != tc.err { + t.Fatalf("WriteStream.Close: got %v, want %v", err, tc.err) + } + if err := s.Close(); err != io.ErrClosedPipe { + t.Fatalf("second WriteStream.Close: got %v, want %v", err, io.ErrClosedPipe) + } + if got := s.Written(); got != tc.total { + t.Fatalf("WriteStream.Written: got %d, want %d", got, tc.total) + } + }) + } +} diff --git a/transfer_test.go b/transfer_test.go index 4c8a1de..372a27c 100644 --- a/transfer_test.go +++ b/transfer_test.go @@ -99,19 +99,19 @@ func TestTransferProtocol(t *testing.T) { } go func() { - ft := f.waitForSubmitted() + ft := f.waitForSubmitted(nil) ft.length = 5 ft.status = TransferCompleted copy(ft.buf, []byte{1, 2, 3, 4, 5}) ft.done <- struct{}{} - ft = f.waitForSubmitted() + ft = f.waitForSubmitted(nil) ft.length = 99 ft.status = TransferCompleted copy(ft.buf, []byte{12, 12, 12, 12, 12}) ft.done <- struct{}{} - ft = f.waitForSubmitted() + ft = f.waitForSubmitted(nil) ft.length = 123 ft.status = TransferCancelled ft.done <- struct{}{}