From da849d96b5b1c68d3c61d81bbedd033246d584d0 Mon Sep 17 00:00:00 2001 From: Sebastian Zagrodzki Date: Mon, 5 Nov 2018 15:33:31 +0100 Subject: [PATCH] Add context-aware read/write (#57) Adds ReadContext/WriteContext methods to endpoints and to streams. Update rawread example tool to use ReadContext for implementation of "timeout" flag. --- endpoint.go | 41 +++++++++++++++----- endpoint_stream.go | 16 ++++---- endpoint_stream_test.go | 13 +++---- endpoint_test.go | 86 ++++++++++++++++++++++++++++++++++++----- fakelibusb_test.go | 50 ++++++++++++++++++++++-- libusb.go | 5 +-- rawread/main.go | 19 +++++++-- transfer.go | 19 ++++++--- transfer_stream.go | 66 +++++++++++++++++++++++++++---- transfer_stream_test.go | 3 +- transfer_test.go | 35 +++++++---------- 11 files changed, 273 insertions(+), 80 deletions(-) diff --git a/endpoint.go b/endpoint.go index 07903ed..fdb43fb 100644 --- a/endpoint.go +++ b/endpoint.go @@ -16,6 +16,7 @@ package gousb import ( + "context" "fmt" "strings" "time" @@ -76,8 +77,6 @@ type endpoint struct { InterfaceSetting Desc EndpointDesc - Timeout time.Duration - ctx *Context } @@ -86,12 +85,12 @@ func (e *endpoint) String() string { return e.Desc.String() } -func (e *endpoint) transfer(buf []byte) (int, error) { +func (e *endpoint) transfer(ctx context.Context, buf []byte) (int, error) { if len(buf) == 0 { return 0, nil } - t, err := newUSBTransfer(e.ctx, e.h, &e.Desc, len(buf), e.Timeout) + t, err := newUSBTransfer(e.ctx, e.h, &e.Desc, len(buf)) if err != nil { return 0, err } @@ -104,7 +103,7 @@ func (e *endpoint) transfer(buf []byte) (int, error) { return 0, err } - n, err := t.wait() + n, err := t.wait(ctx) if e.Desc.Direction == EndpointDirectionIn { copy(buf, t.data()) } @@ -122,9 +121,21 @@ type InEndpoint struct { *endpoint } -// Read reads data from an IN endpoint. +// Read reads data from an IN endpoint. Read returns number of bytes obtained +// from the endpoint. Read may return non-zero length even if +// the returned error is not nil (partial read). func (e *InEndpoint) Read(buf []byte) (int, error) { - return e.transfer(buf) + return e.transfer(context.Background(), buf) +} + +// ReadContext reads data from an IN endpoint. ReadContext returns number of +// bytes obtained from the endpoint. ReadContext may return non-zero length +// even if the returned error is not nil (partial read). +// The passed context can be used to control the cancellation of the read. If +// the context is cancelled, ReadContext will cancel the underlying transfers, +// resulting in TransferCancelled error. +func (e *InEndpoint) ReadContext(ctx context.Context, buf []byte) (int, error) { + return e.transfer(ctx, buf) } // OutEndpoint represents an OUT endpoint open for transfer. @@ -132,7 +143,19 @@ type OutEndpoint struct { *endpoint } -// Write writes data to an OUT endpoint. +// Write writes data to an OUT endpoint. Write returns number of bytes comitted +// to the endpoint. Write may return non-zero length even if the returned error +// is not nil (partial write). func (e *OutEndpoint) Write(buf []byte) (int, error) { - return e.transfer(buf) + return e.transfer(context.Background(), buf) +} + +// WriteContext writes data to an OUT endpoint. WriteContext returns number of +// bytes comitted to the endpoint. WriteContext may return non-zero length even +// if the returned error is not nil (partial write). +// The passed context can be used to control the cancellation of the write. If +// the context is cancelled, WriteContext will cancel the underlying transfers, +// resulting in TransferCancelled error. +func (e *OutEndpoint) WriteContext(ctx context.Context, buf []byte) (int, error) { + return e.transfer(ctx, buf) } diff --git a/endpoint_stream.go b/endpoint_stream.go index 138aebc..b4f2780 100644 --- a/endpoint_stream.go +++ b/endpoint_stream.go @@ -17,7 +17,7 @@ package gousb func (e *endpoint) newStream(size, count int) (*stream, error) { var ts []transferIntf for i := 0; i < count; i++ { - t, err := newUSBTransfer(e.ctx, e.h, &e.Desc, size, e.Timeout) + t, err := newUSBTransfer(e.ctx, e.h, &e.Desc, size) if err != nil { for _, t := range ts { t.free() @@ -29,8 +29,8 @@ func (e *endpoint) newStream(size, count int) (*stream, error) { return newStream(ts), nil } -// NewStream prepares a new read stream that will keep reading data from the -// endpoint until closed. +// NewStream prepares a new read stream that will keep reading data from +// the endpoint until closed or until an error or timeout is encountered. // Size defines a buffer size for a single read transaction and count // defines how many transactions should be active at any time. // By keeping multiple transfers active at the same time, a Stream reduces @@ -44,11 +44,11 @@ func (e *InEndpoint) NewStream(size, count int) (*ReadStream, error) { return &ReadStream{s: s}, nil } -// NewStream prepares a new write stream that will write data in the background. -// Size defines a buffer size for a single write transaction and count -// defines how many transactions may be active at any time. -// By buffering the writes, a Stream reduces the latency between subsequent -// transfers and increases writing throughput. +// NewStream prepares a new write stream that will write data in the +// background. Size defines a buffer size for a single write transaction and +// count defines how many transactions may be active at any time. By buffering +// the writes, a Stream reduces the latency between subsequent transfers and +// increases writing throughput. func (e *OutEndpoint) NewStream(size, count int) (*WriteStream, error) { s, err := e.newStream(size, count) if err != nil { diff --git a/endpoint_stream_test.go b/endpoint_stream_test.go index c7418be..588f21a 100644 --- a/endpoint_stream_test.go +++ b/endpoint_stream_test.go @@ -36,13 +36,11 @@ func TestEndpointReadStream(t *testing.T) { return } if num < goodTransfers { - xfr.status = TransferCompleted - xfr.length = len(xfr.buf) + xfr.setData(make([]byte, len(xfr.buf))) + xfr.setStatus(TransferCompleted) } else { - xfr.status = TransferError - xfr.length = 0 + xfr.setStatus(TransferError) } - xfr.done <- struct{}{} num++ } }() @@ -106,9 +104,8 @@ func TestEndpointWriteStream(t *testing.T) { if xfr == nil { return } - xfr.length = len(xfr.buf) - xfr.status = TransferCompleted - xfr.done <- struct{}{} + xfr.setData(make([]byte, len(xfr.buf))) + xfr.setStatus(TransferCompleted) num++ total += xfr.length } diff --git a/endpoint_test.go b/endpoint_test.go index 4841121..89d6b9d 100644 --- a/endpoint_test.go +++ b/endpoint_test.go @@ -15,6 +15,7 @@ package gousb import ( + "context" "testing" "time" ) @@ -101,12 +102,11 @@ func TestEndpoint(t *testing.T) { if tc.wantSubmit { go func() { fakeT := lib.waitForSubmitted(nil) - fakeT.length = tc.ret - fakeT.status = tc.status - close(fakeT.done) + fakeT.setData(make([]byte, tc.ret)) + fakeT.setStatus(tc.status) }() } - got, err := ep.transfer(tc.buf) + got, err := ep.transfer(context.TODO(), tc.buf) if (err != nil) != tc.wantErr { t.Errorf("%s, %s: ep.transfer(...): got err: %v, err != nil is %v, want %v", epData.ei, tc.desc, err, err != nil, tc.wantErr) continue @@ -209,9 +209,8 @@ func TestEndpointInOut(t *testing.T) { dataTransferred := 100 go func() { fakeT := lib.waitForSubmitted(nil) - fakeT.length = dataTransferred - fakeT.status = TransferCompleted - close(fakeT.done) + fakeT.setData(make([]byte, dataTransferred)) + fakeT.setStatus(TransferCompleted) }() buf := make([]byte, 512) got, err := iep.Read(buf) @@ -233,9 +232,8 @@ func TestEndpointInOut(t *testing.T) { } go func() { fakeT := lib.waitForSubmitted(nil) - fakeT.length = dataTransferred - fakeT.status = TransferCompleted - close(fakeT.done) + fakeT.setData(make([]byte, dataTransferred)) + fakeT.setStatus(TransferCompleted) }() got, err = oep.Write(buf) if err != nil { @@ -290,3 +288,71 @@ func TestSameEndpointNumberInOut(t *testing.T) { t.Errorf("%s.OutEndpoint(1): got error %v, want nil", intf, err) } } + +func TestReadContext(t *testing.T) { + t.Parallel() + lib := newFakeLibusb() + ctx := newContextWithImpl(lib) + defer func() { + if err := ctx.Close(); err != nil { + t.Errorf("Context.Close(): %v", err) + } + }() + + d, err := ctx.OpenDeviceWithVIDPID(0x9999, 0x0001) + if err != nil { + t.Fatalf("OpenDeviceWithVIDPID(0x9999, 0x0001): got error %v, want nil", err) + } + defer func() { + if err := d.Close(); err != nil { + t.Errorf("%s.Close(): %v", d, err) + } + }() + cfg, err := d.Config(1) + if err != nil { + t.Fatalf("%s.Config(1): %v", d, err) + } + defer func() { + if err := cfg.Close(); err != nil { + t.Errorf("%s.Close(): %v", cfg, err) + } + }() + intf, err := cfg.Interface(0, 0) + if err != nil { + t.Fatalf("%s.Interface(0, 0): %v", cfg, err) + } + defer intf.Close() + iep, err := intf.InEndpoint(2) + if err != nil { + t.Fatalf("%s.InEndpoint(2): got error %v, want nil", intf, err) + } + buf := make([]byte, 512) + + rCtx, done := context.WithCancel(context.Background()) + go func() { + ft := lib.waitForSubmitted(nil) + ft.setData([]byte{1, 2, 3, 4, 5}) + done() + }() + if got, err := iep.ReadContext(rCtx, buf); err != TransferCancelled { + t.Errorf("%s.Read: got error %v, want %v", iep, err, TransferCancelled) + } else if want := 5; got != want { + t.Errorf("%s.Read: got %d bytes, want %d (partial read success)", iep, got, want) + } + + oep, err := intf.OutEndpoint(1) + if err != nil { + t.Fatalf("%s.OutEndpoint(1): got error %v, want nil", intf, err) + } + wCtx, done := context.WithCancel(context.Background()) + go func() { + ft := lib.waitForSubmitted(nil) + ft.setLength(5) + done() + }() + if got, err := oep.WriteContext(wCtx, buf); err != TransferCancelled { + t.Errorf("%s.Write: got error %v, want %v", oep, err, TransferCancelled) + } else if want := 5; got != want { + t.Errorf("%s.Write: got %d bytes, want %d (partial write success)", oep, got, want) + } +} diff --git a/fakelibusb_test.go b/fakelibusb_test.go index 5774d90..f1c9321 100644 --- a/fakelibusb_test.go +++ b/fakelibusb_test.go @@ -22,10 +22,17 @@ import ( ) type fakeTransfer struct { - // done is the channel that needs to be closed when the transfer has finished. + // done is the channel that needs to receive a signal when the transfer has + // finished. + // This is different from finished below - done is provided by the caller + // and is used to signal the caller. done chan struct{} + // mu protects transfer data and status. + mu sync.Mutex // buf is the slice for reading/writing data between the submit() and wait() returning. buf []byte + // finished is true after the transfer is no longer in flight + finished bool // status will be returned by wait() on this transfer status TransferStatus // length is the number of bytes used from the buffer (write) or available @@ -33,6 +40,36 @@ type fakeTransfer struct { length int } +func (t *fakeTransfer) setData(d []byte) { + t.mu.Lock() + defer t.mu.Unlock() + if t.finished { + return + } + copy(t.buf, d) + t.length = len(d) +} + +func (t *fakeTransfer) setLength(n int) { + t.mu.Lock() + defer t.mu.Unlock() + if t.finished { + return + } + t.length = n +} + +func (t *fakeTransfer) setStatus(st TransferStatus) { + t.mu.Lock() + defer t.mu.Unlock() + if t.finished { + return + } + t.status = st + t.finished = true + t.done <- struct{}{} +} + // fakeLibusb implements a fake libusb stack that pretends to have a number of // devices connected to it (see fakeDevices variable for a list of devices). // fakeLibusb is expected to implement all the functions related to device @@ -162,18 +199,25 @@ func (f *fakeLibusb) setAlt(d *libusbDevHandle, intf, alt uint8) error { return nil } -func (f *fakeLibusb) alloc(_ *libusbDevHandle, _ *EndpointDesc, _ time.Duration, _ int, bufLen int, done chan struct{}) (*libusbTransfer, error) { +func (f *fakeLibusb) alloc(_ *libusbDevHandle, _ *EndpointDesc, _ int, bufLen int, done chan struct{}) (*libusbTransfer, error) { f.mu.Lock() defer f.mu.Unlock() t := newFakeTransferPointer() f.ts[t] = &fakeTransfer{buf: make([]byte, bufLen), done: done} return t, nil } -func (f *fakeLibusb) cancel(t *libusbTransfer) error { return errors.New("not implemented") } +func (f *fakeLibusb) cancel(t *libusbTransfer) error { + f.mu.Lock() + ft := f.ts[t] + f.mu.Unlock() + ft.setStatus(TransferCancelled) + return nil +} func (f *fakeLibusb) submit(t *libusbTransfer) error { f.mu.Lock() ft := f.ts[t] f.mu.Unlock() + ft.finished = false f.submitted <- ft return nil } diff --git a/libusb.go b/libusb.go index b24c64d..2af2dac 100644 --- a/libusb.go +++ b/libusb.go @@ -159,7 +159,7 @@ type libusbIntf interface { setAlt(*libusbDevHandle, uint8, uint8) error // transfer - alloc(*libusbDevHandle, *EndpointDesc, time.Duration, int, int, chan struct{}) (*libusbTransfer, error) + alloc(*libusbDevHandle, *EndpointDesc, int, int, chan struct{}) (*libusbTransfer, error) cancel(*libusbTransfer) error submit(*libusbTransfer) error buffer(*libusbTransfer) []byte @@ -422,7 +422,7 @@ func (libusbImpl) setAlt(d *libusbDevHandle, iface, setup uint8) error { return fromErrNo(C.libusb_set_interface_alt_setting((*C.libusb_device_handle)(d), C.int(iface), C.int(setup))) } -func (libusbImpl) alloc(d *libusbDevHandle, ep *EndpointDesc, timeout time.Duration, isoPackets int, bufLen int, done chan struct{}) (*libusbTransfer, error) { +func (libusbImpl) alloc(d *libusbDevHandle, ep *EndpointDesc, isoPackets int, bufLen int, done chan struct{}) (*libusbTransfer, error) { xfer := C.gousb_alloc_transfer_and_buffer(C.int(bufLen), C.int(isoPackets)) if xfer == nil { return nil, fmt.Errorf("gousb_alloc_transfer_and_buffer(%d, %d) failed", bufLen, isoPackets) @@ -432,7 +432,6 @@ func (libusbImpl) alloc(d *libusbDevHandle, ep *EndpointDesc, timeout time.Durat } xfer.dev_handle = (*C.libusb_device_handle)(d) xfer.endpoint = C.uchar(ep.Address) - xfer.timeout = C.uint(timeout / time.Millisecond) xfer._type = C.uchar(ep.TransferType) xfer.num_iso_packets = C.int(isoPackets) ret := (*libusbTransfer)(xfer) diff --git a/rawread/main.go b/rawread/main.go index e1cf4a6..efe58f3 100644 --- a/rawread/main.go +++ b/rawread/main.go @@ -17,9 +17,9 @@ package main import ( + "context" "flag" "fmt" - "io" "log" "os" "strconv" @@ -39,6 +39,7 @@ var ( size = flag.Int("read_size", 1024, "Number of bytes of data to read in a single transaction.") bufSize = flag.Int("buffer_size", 0, "Number of buffer transfers, for data prefetching.") num = flag.Int("read_num", 0, "Number of read transactions to perform. 0 means infinite.") + timeout = flag.Duration("timeout", 0, "Timeout for the command. 0 means infinite.") ) func parseVIDPID(vidPid string) (gousb.ID, gousb.ID, error) { @@ -73,6 +74,10 @@ func parseBusAddr(busAddr string) (int, int, error) { return int(bus), int(addr), nil } +type contextReader interface { + ReadContext(context.Context, []byte) (int, error) +} + func main() { flag.Parse() @@ -160,7 +165,7 @@ func main() { log.Fatalf("dev.InEndpoint(): %s", err) } log.Printf("Found endpoint: %s", ep) - var rdr io.Reader = ep + var rdr contextReader = ep if *bufSize > 1 { log.Print("Creating buffer...") s, err := ep.NewStream(*size, *bufSize) @@ -170,11 +175,17 @@ func main() { defer s.Close() rdr = s } - log.Print("Reading...") + opCtx := context.Background() + if *timeout > 0 { + var done func() + opCtx, done = context.WithTimeout(opCtx, *timeout) + defer done() + } buf := make([]byte, *size) + log.Print("Reading...") for i := 0; *num == 0 || i < *num; i++ { - num, err := rdr.Read(buf) + num, err := rdr.ReadContext(opCtx, buf) if err != nil { log.Fatalf("Reading from device failed: %v", err) } diff --git a/transfer.go b/transfer.go index 4f9fd23..45abf73 100644 --- a/transfer.go +++ b/transfer.go @@ -15,10 +15,10 @@ package gousb import ( + "context" "errors" "runtime" "sync" - "time" ) type usbTransfer struct { @@ -60,13 +60,20 @@ func (t *usbTransfer) submit() error { // via t.buf. The number returned by wait indicates how many bytes // of the buffer were read or written by libusb, and it can be // smaller than the length of t.buf. -func (t *usbTransfer) wait() (n int, err error) { +func (t *usbTransfer) wait(ctx context.Context) (n int, err error) { t.mu.Lock() defer t.mu.Unlock() if !t.submitted { return 0, nil } - <-t.done + select { + case <-ctx.Done(): + t.ctx.libusb.cancel(t.xfer) + // after the transfer is cancelled, it will run a callback + // that triggers the activation of t.done. + <-t.done + case <-t.done: + } t.submitted = false n, status := t.ctx.libusb.data(t.xfer) if status != TransferCompleted { @@ -117,7 +124,7 @@ func (t *usbTransfer) data() []byte { // newUSBTransfer allocates a new transfer structure and a new buffer for // communication with a given device/endpoint. -func newUSBTransfer(ctx *Context, dev *libusbDevHandle, ei *EndpointDesc, bufLen int, timeout time.Duration) (*usbTransfer, error) { +func newUSBTransfer(ctx *Context, dev *libusbDevHandle, ei *EndpointDesc, bufLen int) (*usbTransfer, error) { var isoPackets, isoPktSize int if ei.TransferType == TransferTypeIsochronous { isoPktSize = ei.MaxPacketSize @@ -129,7 +136,7 @@ func newUSBTransfer(ctx *Context, dev *libusbDevHandle, ei *EndpointDesc, bufLen } done := make(chan struct{}, 1) - xfer, err := ctx.libusb.alloc(dev, ei, timeout, isoPackets, bufLen, done) + xfer, err := ctx.libusb.alloc(dev, ei, isoPackets, bufLen, done) if err != nil { return nil, err } @@ -146,7 +153,7 @@ func newUSBTransfer(ctx *Context, dev *libusbDevHandle, ei *EndpointDesc, bufLen } runtime.SetFinalizer(t, func(t *usbTransfer) { t.cancel() - t.wait() + t.wait(context.Background()) t.free() }) return t, nil diff --git a/transfer_stream.go b/transfer_stream.go index bfd8404..a32662c 100644 --- a/transfer_stream.go +++ b/transfer_stream.go @@ -14,12 +14,15 @@ package gousb -import "io" +import ( + "context" + "io" +) type transferIntf interface { submit() error cancel() error - wait() (int, error) + wait(context.Context) (int, error) free() error data() []byte } @@ -68,7 +71,7 @@ func (s *stream) flushRemaining() { s.noMore() for t := range s.transfers { t.cancel() - t.wait() + t.wait(context.Background()) t.free() } } @@ -99,8 +102,23 @@ type ReadStream struct { // might be smaller than the length of p. // After a non-nil error is returned, all subsequent attempts to read will // return io.ErrClosedPipe. -// Read cannot be called concurrently with other Read or Close. +// Read cannot be called concurrently with other Read, ReadContext +// or Close. func (r *ReadStream) Read(p []byte) (int, error) { + return r.ReadContext(context.Background(), p) +} + +// ReadContext reads data from the transfer stream. +// The data will come from at most a single transfer, so the returned number +// might be smaller than the length of p. +// After a non-nil error is returned, all subsequent attempts to read will +// return io.ErrClosedPipe. +// ReadContext cannot be called concurrently with other Read, ReadContext +// or Close. +// The context passed controls the cancellation of this particular read +// operation within the stream. The semantics is identical to +// Endpoint.ReadContext. +func (r *ReadStream) ReadContext(ctx context.Context, p []byte) (int, error) { if r.s.transfers == nil { return 0, io.ErrClosedPipe } @@ -111,7 +129,7 @@ func (r *ReadStream) Read(p []byte) (int, error) { r.s.transfers = nil return 0, r.s.err } - n, err := t.wait() + n, err := t.wait(ctx) if err != nil { // wait error aborts immediately, all remaining data is invalid. t.free() @@ -183,6 +201,25 @@ type WriteStream struct { // call after Close() has returned. // Write cannot be called concurrently with another Write, Written or Close. func (w *WriteStream) Write(p []byte) (int, error) { + return w.WriteContext(context.Background(), p) +} + +// WriteContext sends the data to the endpoint. Write returning a nil error doesn't +// mean that data was written to the device, only that it was written to the +// buffer. Only a call to Close() that returns nil error guarantees that +// all transfers have succeeded. +// If the slice passed to WriteContext does not align exactly with the transfer +// buffer size (as declared in a call to NewStream), the last USB transfer +// of this Write will be sent with less data than the full buffer. +// After a non-nil error is returned, all subsequent attempts to write will +// return io.ErrClosedPipe. +// If WriteContext encounters an error when preparing the transfer, the stream +// will still try to complete any pending transfers. The total number +// of bytes successfully written can be retrieved through a Written() +// call after Close() has returned. +// WriteContext cannot be called concurrently with another Write, WriteContext, +// Written, Close or CloseContext. +func (w *WriteStream) WriteContext(ctx context.Context, p []byte) (int, error) { if w.s.transfers == nil || w.s.err != nil { return 0, io.ErrClosedPipe } @@ -190,7 +227,7 @@ func (w *WriteStream) Write(p []byte) (int, error) { all := len(p) for written < all { t := <-w.s.transfers - n, err := t.wait() // unsubmitted transfers will return 0 bytes and no error + n, err := t.wait(ctx) // unsubmitted transfers will return 0 bytes and no error w.total += n if err != nil { t.free() @@ -229,12 +266,24 @@ func (w *WriteStream) Write(p []byte) (int, error) { // retrieved using Written(). // Close may not be called concurrently with Write, Close or Written. func (w *WriteStream) Close() error { + return w.CloseContext(context.Background()) +} + +// Close signals end of data to write. Close blocks until all transfers +// that were sent are finished. The error returned by Close is the first +// error encountered during writing the entire stream (if any). +// Close returning nil indicates all transfers completed successfully. +// After Close, the total number of bytes successfully written can be +// retrieved using Written(). +// Close may not be called concurrently with Write, Close or Written. +// CloseContext +func (w *WriteStream) CloseContext(ctx context.Context) error { if w.s.transfers == nil { return io.ErrClosedPipe } w.s.noMore() for t := range w.s.transfers { - n, err := t.wait() + n, err := t.wait(ctx) w.total += n t.free() if err != nil { @@ -248,7 +297,8 @@ func (w *WriteStream) Close() error { } // Written returns the number of bytes successfully written by the stream. -// Written may be called only after Close() has been called and returned. +// Written may be called only after Close() or CloseContext() +// has been called and returned. func (w *WriteStream) Written() int { return w.total } diff --git a/transfer_stream_test.go b/transfer_stream_test.go index 38614f7..f5b85e7 100644 --- a/transfer_stream_test.go +++ b/transfer_stream_test.go @@ -16,6 +16,7 @@ package gousb import ( "bytes" + "context" "errors" "fmt" "io" @@ -57,7 +58,7 @@ func (f *fakeStreamTransfer) submit() error { return nil } -func (f *fakeStreamTransfer) wait() (int, error) { +func (f *fakeStreamTransfer) wait(ctx context.Context) (int, error) { if f.released { return 0, errors.New("wait() called on a free()d transfer") } diff --git a/transfer_test.go b/transfer_test.go index 372a27c..2a9b943 100644 --- a/transfer_test.go +++ b/transfer_test.go @@ -15,8 +15,8 @@ package gousb import ( + "context" "testing" - "time" ) func TestNewTransfer(t *testing.T) { @@ -34,7 +34,6 @@ func TestNewTransfer(t *testing.T) { tt TransferType maxPkt int buf int - timeout time.Duration wantIso int wantLength int wantTimeout int @@ -45,7 +44,6 @@ func TestNewTransfer(t *testing.T) { tt: TransferTypeBulk, maxPkt: 512, buf: 1024, - timeout: time.Second, wantLength: 1024, }, { @@ -62,7 +60,7 @@ func TestNewTransfer(t *testing.T) { Direction: tc.dir, TransferType: tc.tt, MaxPacketSize: tc.maxPkt, - }, tc.buf, tc.timeout) + }, tc.buf) if err != nil { t.Fatalf("newUSBTransfer(): %v", err) @@ -92,41 +90,37 @@ func TestTransferProtocol(t *testing.T) { Direction: EndpointDirectionIn, TransferType: TransferTypeBulk, MaxPacketSize: 512, - }, 10240, time.Second) + }, 10240) if err != nil { t.Fatalf("newUSBTransfer: %v", err) } } + partial := make(chan struct{}) go func() { ft := f.waitForSubmitted(nil) - ft.length = 5 - ft.status = TransferCompleted - copy(ft.buf, []byte{1, 2, 3, 4, 5}) - ft.done <- struct{}{} + ft.setData([]byte{1, 2, 3, 4, 5}) + ft.setStatus(TransferCompleted) ft = f.waitForSubmitted(nil) - ft.length = 99 - ft.status = TransferCompleted - copy(ft.buf, []byte{12, 12, 12, 12, 12}) - ft.done <- struct{}{} + ft.setData(make([]byte, 99)) + ft.setStatus(TransferCompleted) ft = f.waitForSubmitted(nil) - ft.length = 123 - ft.status = TransferCancelled - ft.done <- struct{}{} + ft.setData(make([]byte, 123)) + close(partial) }() xfers[0].submit() xfers[1].submit() - got, err := xfers[0].wait() + got, err := xfers[0].wait(context.Background()) if err != nil { t.Errorf("xfer#0.wait returned error %v, want nil", err) } if want := 5; got != want { t.Errorf("xfer#0.wait returned %d bytes, want %d", got, want) } - got, err = xfers[1].wait() + got, err = xfers[1].wait(context.Background()) if err != nil { t.Errorf("xfer#0.wait returned error %v, want nil", err) } @@ -135,8 +129,9 @@ func TestTransferProtocol(t *testing.T) { } xfers[1].submit() + <-partial xfers[1].cancel() - got, err = xfers[1].wait() + got, err = xfers[1].wait(context.Background()) if err == nil { t.Error("xfer#1(resubmitted).wait returned error nil, want non-nil") } @@ -146,7 +141,7 @@ func TestTransferProtocol(t *testing.T) { for _, x := range xfers { x.cancel() - x.wait() + x.wait(context.Background()) x.free() } }