Add context-aware read/write (#57)
Adds ReadContext/WriteContext methods to endpoints and to streams. Update rawread example tool to use ReadContext for implementation of "timeout" flag.
This commit is contained in:

committed by
GitHub

parent
593cfb67e9
commit
da849d96b5
41
endpoint.go
41
endpoint.go
@@ -16,6 +16,7 @@
|
||||
package gousb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -76,8 +77,6 @@ type endpoint struct {
|
||||
InterfaceSetting
|
||||
Desc EndpointDesc
|
||||
|
||||
Timeout time.Duration
|
||||
|
||||
ctx *Context
|
||||
}
|
||||
|
||||
@@ -86,12 +85,12 @@ func (e *endpoint) String() string {
|
||||
return e.Desc.String()
|
||||
}
|
||||
|
||||
func (e *endpoint) transfer(buf []byte) (int, error) {
|
||||
func (e *endpoint) transfer(ctx context.Context, buf []byte) (int, error) {
|
||||
if len(buf) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
t, err := newUSBTransfer(e.ctx, e.h, &e.Desc, len(buf), e.Timeout)
|
||||
t, err := newUSBTransfer(e.ctx, e.h, &e.Desc, len(buf))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -104,7 +103,7 @@ func (e *endpoint) transfer(buf []byte) (int, error) {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
n, err := t.wait()
|
||||
n, err := t.wait(ctx)
|
||||
if e.Desc.Direction == EndpointDirectionIn {
|
||||
copy(buf, t.data())
|
||||
}
|
||||
@@ -122,9 +121,21 @@ type InEndpoint struct {
|
||||
*endpoint
|
||||
}
|
||||
|
||||
// Read reads data from an IN endpoint.
|
||||
// Read reads data from an IN endpoint. Read returns number of bytes obtained
|
||||
// from the endpoint. Read may return non-zero length even if
|
||||
// the returned error is not nil (partial read).
|
||||
func (e *InEndpoint) Read(buf []byte) (int, error) {
|
||||
return e.transfer(buf)
|
||||
return e.transfer(context.Background(), buf)
|
||||
}
|
||||
|
||||
// ReadContext reads data from an IN endpoint. ReadContext returns number of
|
||||
// bytes obtained from the endpoint. ReadContext may return non-zero length
|
||||
// even if the returned error is not nil (partial read).
|
||||
// The passed context can be used to control the cancellation of the read. If
|
||||
// the context is cancelled, ReadContext will cancel the underlying transfers,
|
||||
// resulting in TransferCancelled error.
|
||||
func (e *InEndpoint) ReadContext(ctx context.Context, buf []byte) (int, error) {
|
||||
return e.transfer(ctx, buf)
|
||||
}
|
||||
|
||||
// OutEndpoint represents an OUT endpoint open for transfer.
|
||||
@@ -132,7 +143,19 @@ type OutEndpoint struct {
|
||||
*endpoint
|
||||
}
|
||||
|
||||
// Write writes data to an OUT endpoint.
|
||||
// Write writes data to an OUT endpoint. Write returns number of bytes comitted
|
||||
// to the endpoint. Write may return non-zero length even if the returned error
|
||||
// is not nil (partial write).
|
||||
func (e *OutEndpoint) Write(buf []byte) (int, error) {
|
||||
return e.transfer(buf)
|
||||
return e.transfer(context.Background(), buf)
|
||||
}
|
||||
|
||||
// WriteContext writes data to an OUT endpoint. WriteContext returns number of
|
||||
// bytes comitted to the endpoint. WriteContext may return non-zero length even
|
||||
// if the returned error is not nil (partial write).
|
||||
// The passed context can be used to control the cancellation of the write. If
|
||||
// the context is cancelled, WriteContext will cancel the underlying transfers,
|
||||
// resulting in TransferCancelled error.
|
||||
func (e *OutEndpoint) WriteContext(ctx context.Context, buf []byte) (int, error) {
|
||||
return e.transfer(ctx, buf)
|
||||
}
|
||||
|
@@ -17,7 +17,7 @@ package gousb
|
||||
func (e *endpoint) newStream(size, count int) (*stream, error) {
|
||||
var ts []transferIntf
|
||||
for i := 0; i < count; i++ {
|
||||
t, err := newUSBTransfer(e.ctx, e.h, &e.Desc, size, e.Timeout)
|
||||
t, err := newUSBTransfer(e.ctx, e.h, &e.Desc, size)
|
||||
if err != nil {
|
||||
for _, t := range ts {
|
||||
t.free()
|
||||
@@ -29,8 +29,8 @@ func (e *endpoint) newStream(size, count int) (*stream, error) {
|
||||
return newStream(ts), nil
|
||||
}
|
||||
|
||||
// NewStream prepares a new read stream that will keep reading data from the
|
||||
// endpoint until closed.
|
||||
// NewStream prepares a new read stream that will keep reading data from
|
||||
// the endpoint until closed or until an error or timeout is encountered.
|
||||
// Size defines a buffer size for a single read transaction and count
|
||||
// defines how many transactions should be active at any time.
|
||||
// By keeping multiple transfers active at the same time, a Stream reduces
|
||||
@@ -44,11 +44,11 @@ func (e *InEndpoint) NewStream(size, count int) (*ReadStream, error) {
|
||||
return &ReadStream{s: s}, nil
|
||||
}
|
||||
|
||||
// NewStream prepares a new write stream that will write data in the background.
|
||||
// Size defines a buffer size for a single write transaction and count
|
||||
// defines how many transactions may be active at any time.
|
||||
// By buffering the writes, a Stream reduces the latency between subsequent
|
||||
// transfers and increases writing throughput.
|
||||
// NewStream prepares a new write stream that will write data in the
|
||||
// background. Size defines a buffer size for a single write transaction and
|
||||
// count defines how many transactions may be active at any time. By buffering
|
||||
// the writes, a Stream reduces the latency between subsequent transfers and
|
||||
// increases writing throughput.
|
||||
func (e *OutEndpoint) NewStream(size, count int) (*WriteStream, error) {
|
||||
s, err := e.newStream(size, count)
|
||||
if err != nil {
|
||||
|
@@ -36,13 +36,11 @@ func TestEndpointReadStream(t *testing.T) {
|
||||
return
|
||||
}
|
||||
if num < goodTransfers {
|
||||
xfr.status = TransferCompleted
|
||||
xfr.length = len(xfr.buf)
|
||||
xfr.setData(make([]byte, len(xfr.buf)))
|
||||
xfr.setStatus(TransferCompleted)
|
||||
} else {
|
||||
xfr.status = TransferError
|
||||
xfr.length = 0
|
||||
xfr.setStatus(TransferError)
|
||||
}
|
||||
xfr.done <- struct{}{}
|
||||
num++
|
||||
}
|
||||
}()
|
||||
@@ -106,9 +104,8 @@ func TestEndpointWriteStream(t *testing.T) {
|
||||
if xfr == nil {
|
||||
return
|
||||
}
|
||||
xfr.length = len(xfr.buf)
|
||||
xfr.status = TransferCompleted
|
||||
xfr.done <- struct{}{}
|
||||
xfr.setData(make([]byte, len(xfr.buf)))
|
||||
xfr.setStatus(TransferCompleted)
|
||||
num++
|
||||
total += xfr.length
|
||||
}
|
||||
|
@@ -15,6 +15,7 @@
|
||||
package gousb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@@ -101,12 +102,11 @@ func TestEndpoint(t *testing.T) {
|
||||
if tc.wantSubmit {
|
||||
go func() {
|
||||
fakeT := lib.waitForSubmitted(nil)
|
||||
fakeT.length = tc.ret
|
||||
fakeT.status = tc.status
|
||||
close(fakeT.done)
|
||||
fakeT.setData(make([]byte, tc.ret))
|
||||
fakeT.setStatus(tc.status)
|
||||
}()
|
||||
}
|
||||
got, err := ep.transfer(tc.buf)
|
||||
got, err := ep.transfer(context.TODO(), tc.buf)
|
||||
if (err != nil) != tc.wantErr {
|
||||
t.Errorf("%s, %s: ep.transfer(...): got err: %v, err != nil is %v, want %v", epData.ei, tc.desc, err, err != nil, tc.wantErr)
|
||||
continue
|
||||
@@ -209,9 +209,8 @@ func TestEndpointInOut(t *testing.T) {
|
||||
dataTransferred := 100
|
||||
go func() {
|
||||
fakeT := lib.waitForSubmitted(nil)
|
||||
fakeT.length = dataTransferred
|
||||
fakeT.status = TransferCompleted
|
||||
close(fakeT.done)
|
||||
fakeT.setData(make([]byte, dataTransferred))
|
||||
fakeT.setStatus(TransferCompleted)
|
||||
}()
|
||||
buf := make([]byte, 512)
|
||||
got, err := iep.Read(buf)
|
||||
@@ -233,9 +232,8 @@ func TestEndpointInOut(t *testing.T) {
|
||||
}
|
||||
go func() {
|
||||
fakeT := lib.waitForSubmitted(nil)
|
||||
fakeT.length = dataTransferred
|
||||
fakeT.status = TransferCompleted
|
||||
close(fakeT.done)
|
||||
fakeT.setData(make([]byte, dataTransferred))
|
||||
fakeT.setStatus(TransferCompleted)
|
||||
}()
|
||||
got, err = oep.Write(buf)
|
||||
if err != nil {
|
||||
@@ -290,3 +288,71 @@ func TestSameEndpointNumberInOut(t *testing.T) {
|
||||
t.Errorf("%s.OutEndpoint(1): got error %v, want nil", intf, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
lib := newFakeLibusb()
|
||||
ctx := newContextWithImpl(lib)
|
||||
defer func() {
|
||||
if err := ctx.Close(); err != nil {
|
||||
t.Errorf("Context.Close(): %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
d, err := ctx.OpenDeviceWithVIDPID(0x9999, 0x0001)
|
||||
if err != nil {
|
||||
t.Fatalf("OpenDeviceWithVIDPID(0x9999, 0x0001): got error %v, want nil", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Errorf("%s.Close(): %v", d, err)
|
||||
}
|
||||
}()
|
||||
cfg, err := d.Config(1)
|
||||
if err != nil {
|
||||
t.Fatalf("%s.Config(1): %v", d, err)
|
||||
}
|
||||
defer func() {
|
||||
if err := cfg.Close(); err != nil {
|
||||
t.Errorf("%s.Close(): %v", cfg, err)
|
||||
}
|
||||
}()
|
||||
intf, err := cfg.Interface(0, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("%s.Interface(0, 0): %v", cfg, err)
|
||||
}
|
||||
defer intf.Close()
|
||||
iep, err := intf.InEndpoint(2)
|
||||
if err != nil {
|
||||
t.Fatalf("%s.InEndpoint(2): got error %v, want nil", intf, err)
|
||||
}
|
||||
buf := make([]byte, 512)
|
||||
|
||||
rCtx, done := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
ft := lib.waitForSubmitted(nil)
|
||||
ft.setData([]byte{1, 2, 3, 4, 5})
|
||||
done()
|
||||
}()
|
||||
if got, err := iep.ReadContext(rCtx, buf); err != TransferCancelled {
|
||||
t.Errorf("%s.Read: got error %v, want %v", iep, err, TransferCancelled)
|
||||
} else if want := 5; got != want {
|
||||
t.Errorf("%s.Read: got %d bytes, want %d (partial read success)", iep, got, want)
|
||||
}
|
||||
|
||||
oep, err := intf.OutEndpoint(1)
|
||||
if err != nil {
|
||||
t.Fatalf("%s.OutEndpoint(1): got error %v, want nil", intf, err)
|
||||
}
|
||||
wCtx, done := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
ft := lib.waitForSubmitted(nil)
|
||||
ft.setLength(5)
|
||||
done()
|
||||
}()
|
||||
if got, err := oep.WriteContext(wCtx, buf); err != TransferCancelled {
|
||||
t.Errorf("%s.Write: got error %v, want %v", oep, err, TransferCancelled)
|
||||
} else if want := 5; got != want {
|
||||
t.Errorf("%s.Write: got %d bytes, want %d (partial write success)", oep, got, want)
|
||||
}
|
||||
}
|
||||
|
@@ -22,10 +22,17 @@ import (
|
||||
)
|
||||
|
||||
type fakeTransfer struct {
|
||||
// done is the channel that needs to be closed when the transfer has finished.
|
||||
// done is the channel that needs to receive a signal when the transfer has
|
||||
// finished.
|
||||
// This is different from finished below - done is provided by the caller
|
||||
// and is used to signal the caller.
|
||||
done chan struct{}
|
||||
// mu protects transfer data and status.
|
||||
mu sync.Mutex
|
||||
// buf is the slice for reading/writing data between the submit() and wait() returning.
|
||||
buf []byte
|
||||
// finished is true after the transfer is no longer in flight
|
||||
finished bool
|
||||
// status will be returned by wait() on this transfer
|
||||
status TransferStatus
|
||||
// length is the number of bytes used from the buffer (write) or available
|
||||
@@ -33,6 +40,36 @@ type fakeTransfer struct {
|
||||
length int
|
||||
}
|
||||
|
||||
func (t *fakeTransfer) setData(d []byte) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
if t.finished {
|
||||
return
|
||||
}
|
||||
copy(t.buf, d)
|
||||
t.length = len(d)
|
||||
}
|
||||
|
||||
func (t *fakeTransfer) setLength(n int) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
if t.finished {
|
||||
return
|
||||
}
|
||||
t.length = n
|
||||
}
|
||||
|
||||
func (t *fakeTransfer) setStatus(st TransferStatus) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
if t.finished {
|
||||
return
|
||||
}
|
||||
t.status = st
|
||||
t.finished = true
|
||||
t.done <- struct{}{}
|
||||
}
|
||||
|
||||
// fakeLibusb implements a fake libusb stack that pretends to have a number of
|
||||
// devices connected to it (see fakeDevices variable for a list of devices).
|
||||
// fakeLibusb is expected to implement all the functions related to device
|
||||
@@ -162,18 +199,25 @@ func (f *fakeLibusb) setAlt(d *libusbDevHandle, intf, alt uint8) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeLibusb) alloc(_ *libusbDevHandle, _ *EndpointDesc, _ time.Duration, _ int, bufLen int, done chan struct{}) (*libusbTransfer, error) {
|
||||
func (f *fakeLibusb) alloc(_ *libusbDevHandle, _ *EndpointDesc, _ int, bufLen int, done chan struct{}) (*libusbTransfer, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
t := newFakeTransferPointer()
|
||||
f.ts[t] = &fakeTransfer{buf: make([]byte, bufLen), done: done}
|
||||
return t, nil
|
||||
}
|
||||
func (f *fakeLibusb) cancel(t *libusbTransfer) error { return errors.New("not implemented") }
|
||||
func (f *fakeLibusb) cancel(t *libusbTransfer) error {
|
||||
f.mu.Lock()
|
||||
ft := f.ts[t]
|
||||
f.mu.Unlock()
|
||||
ft.setStatus(TransferCancelled)
|
||||
return nil
|
||||
}
|
||||
func (f *fakeLibusb) submit(t *libusbTransfer) error {
|
||||
f.mu.Lock()
|
||||
ft := f.ts[t]
|
||||
f.mu.Unlock()
|
||||
ft.finished = false
|
||||
f.submitted <- ft
|
||||
return nil
|
||||
}
|
||||
|
@@ -159,7 +159,7 @@ type libusbIntf interface {
|
||||
setAlt(*libusbDevHandle, uint8, uint8) error
|
||||
|
||||
// transfer
|
||||
alloc(*libusbDevHandle, *EndpointDesc, time.Duration, int, int, chan struct{}) (*libusbTransfer, error)
|
||||
alloc(*libusbDevHandle, *EndpointDesc, int, int, chan struct{}) (*libusbTransfer, error)
|
||||
cancel(*libusbTransfer) error
|
||||
submit(*libusbTransfer) error
|
||||
buffer(*libusbTransfer) []byte
|
||||
@@ -422,7 +422,7 @@ func (libusbImpl) setAlt(d *libusbDevHandle, iface, setup uint8) error {
|
||||
return fromErrNo(C.libusb_set_interface_alt_setting((*C.libusb_device_handle)(d), C.int(iface), C.int(setup)))
|
||||
}
|
||||
|
||||
func (libusbImpl) alloc(d *libusbDevHandle, ep *EndpointDesc, timeout time.Duration, isoPackets int, bufLen int, done chan struct{}) (*libusbTransfer, error) {
|
||||
func (libusbImpl) alloc(d *libusbDevHandle, ep *EndpointDesc, isoPackets int, bufLen int, done chan struct{}) (*libusbTransfer, error) {
|
||||
xfer := C.gousb_alloc_transfer_and_buffer(C.int(bufLen), C.int(isoPackets))
|
||||
if xfer == nil {
|
||||
return nil, fmt.Errorf("gousb_alloc_transfer_and_buffer(%d, %d) failed", bufLen, isoPackets)
|
||||
@@ -432,7 +432,6 @@ func (libusbImpl) alloc(d *libusbDevHandle, ep *EndpointDesc, timeout time.Durat
|
||||
}
|
||||
xfer.dev_handle = (*C.libusb_device_handle)(d)
|
||||
xfer.endpoint = C.uchar(ep.Address)
|
||||
xfer.timeout = C.uint(timeout / time.Millisecond)
|
||||
xfer._type = C.uchar(ep.TransferType)
|
||||
xfer.num_iso_packets = C.int(isoPackets)
|
||||
ret := (*libusbTransfer)(xfer)
|
||||
|
@@ -17,9 +17,9 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"strconv"
|
||||
@@ -39,6 +39,7 @@ var (
|
||||
size = flag.Int("read_size", 1024, "Number of bytes of data to read in a single transaction.")
|
||||
bufSize = flag.Int("buffer_size", 0, "Number of buffer transfers, for data prefetching.")
|
||||
num = flag.Int("read_num", 0, "Number of read transactions to perform. 0 means infinite.")
|
||||
timeout = flag.Duration("timeout", 0, "Timeout for the command. 0 means infinite.")
|
||||
)
|
||||
|
||||
func parseVIDPID(vidPid string) (gousb.ID, gousb.ID, error) {
|
||||
@@ -73,6 +74,10 @@ func parseBusAddr(busAddr string) (int, int, error) {
|
||||
return int(bus), int(addr), nil
|
||||
}
|
||||
|
||||
type contextReader interface {
|
||||
ReadContext(context.Context, []byte) (int, error)
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
|
||||
@@ -160,7 +165,7 @@ func main() {
|
||||
log.Fatalf("dev.InEndpoint(): %s", err)
|
||||
}
|
||||
log.Printf("Found endpoint: %s", ep)
|
||||
var rdr io.Reader = ep
|
||||
var rdr contextReader = ep
|
||||
if *bufSize > 1 {
|
||||
log.Print("Creating buffer...")
|
||||
s, err := ep.NewStream(*size, *bufSize)
|
||||
@@ -170,11 +175,17 @@ func main() {
|
||||
defer s.Close()
|
||||
rdr = s
|
||||
}
|
||||
log.Print("Reading...")
|
||||
|
||||
opCtx := context.Background()
|
||||
if *timeout > 0 {
|
||||
var done func()
|
||||
opCtx, done = context.WithTimeout(opCtx, *timeout)
|
||||
defer done()
|
||||
}
|
||||
buf := make([]byte, *size)
|
||||
log.Print("Reading...")
|
||||
for i := 0; *num == 0 || i < *num; i++ {
|
||||
num, err := rdr.Read(buf)
|
||||
num, err := rdr.ReadContext(opCtx, buf)
|
||||
if err != nil {
|
||||
log.Fatalf("Reading from device failed: %v", err)
|
||||
}
|
||||
|
19
transfer.go
19
transfer.go
@@ -15,10 +15,10 @@
|
||||
package gousb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type usbTransfer struct {
|
||||
@@ -60,13 +60,20 @@ func (t *usbTransfer) submit() error {
|
||||
// via t.buf. The number returned by wait indicates how many bytes
|
||||
// of the buffer were read or written by libusb, and it can be
|
||||
// smaller than the length of t.buf.
|
||||
func (t *usbTransfer) wait() (n int, err error) {
|
||||
func (t *usbTransfer) wait(ctx context.Context) (n int, err error) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
if !t.submitted {
|
||||
return 0, nil
|
||||
}
|
||||
<-t.done
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.ctx.libusb.cancel(t.xfer)
|
||||
// after the transfer is cancelled, it will run a callback
|
||||
// that triggers the activation of t.done.
|
||||
<-t.done
|
||||
case <-t.done:
|
||||
}
|
||||
t.submitted = false
|
||||
n, status := t.ctx.libusb.data(t.xfer)
|
||||
if status != TransferCompleted {
|
||||
@@ -117,7 +124,7 @@ func (t *usbTransfer) data() []byte {
|
||||
|
||||
// newUSBTransfer allocates a new transfer structure and a new buffer for
|
||||
// communication with a given device/endpoint.
|
||||
func newUSBTransfer(ctx *Context, dev *libusbDevHandle, ei *EndpointDesc, bufLen int, timeout time.Duration) (*usbTransfer, error) {
|
||||
func newUSBTransfer(ctx *Context, dev *libusbDevHandle, ei *EndpointDesc, bufLen int) (*usbTransfer, error) {
|
||||
var isoPackets, isoPktSize int
|
||||
if ei.TransferType == TransferTypeIsochronous {
|
||||
isoPktSize = ei.MaxPacketSize
|
||||
@@ -129,7 +136,7 @@ func newUSBTransfer(ctx *Context, dev *libusbDevHandle, ei *EndpointDesc, bufLen
|
||||
}
|
||||
|
||||
done := make(chan struct{}, 1)
|
||||
xfer, err := ctx.libusb.alloc(dev, ei, timeout, isoPackets, bufLen, done)
|
||||
xfer, err := ctx.libusb.alloc(dev, ei, isoPackets, bufLen, done)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -146,7 +153,7 @@ func newUSBTransfer(ctx *Context, dev *libusbDevHandle, ei *EndpointDesc, bufLen
|
||||
}
|
||||
runtime.SetFinalizer(t, func(t *usbTransfer) {
|
||||
t.cancel()
|
||||
t.wait()
|
||||
t.wait(context.Background())
|
||||
t.free()
|
||||
})
|
||||
return t, nil
|
||||
|
@@ -14,12 +14,15 @@
|
||||
|
||||
package gousb
|
||||
|
||||
import "io"
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
)
|
||||
|
||||
type transferIntf interface {
|
||||
submit() error
|
||||
cancel() error
|
||||
wait() (int, error)
|
||||
wait(context.Context) (int, error)
|
||||
free() error
|
||||
data() []byte
|
||||
}
|
||||
@@ -68,7 +71,7 @@ func (s *stream) flushRemaining() {
|
||||
s.noMore()
|
||||
for t := range s.transfers {
|
||||
t.cancel()
|
||||
t.wait()
|
||||
t.wait(context.Background())
|
||||
t.free()
|
||||
}
|
||||
}
|
||||
@@ -99,8 +102,23 @@ type ReadStream struct {
|
||||
// might be smaller than the length of p.
|
||||
// After a non-nil error is returned, all subsequent attempts to read will
|
||||
// return io.ErrClosedPipe.
|
||||
// Read cannot be called concurrently with other Read or Close.
|
||||
// Read cannot be called concurrently with other Read, ReadContext
|
||||
// or Close.
|
||||
func (r *ReadStream) Read(p []byte) (int, error) {
|
||||
return r.ReadContext(context.Background(), p)
|
||||
}
|
||||
|
||||
// ReadContext reads data from the transfer stream.
|
||||
// The data will come from at most a single transfer, so the returned number
|
||||
// might be smaller than the length of p.
|
||||
// After a non-nil error is returned, all subsequent attempts to read will
|
||||
// return io.ErrClosedPipe.
|
||||
// ReadContext cannot be called concurrently with other Read, ReadContext
|
||||
// or Close.
|
||||
// The context passed controls the cancellation of this particular read
|
||||
// operation within the stream. The semantics is identical to
|
||||
// Endpoint.ReadContext.
|
||||
func (r *ReadStream) ReadContext(ctx context.Context, p []byte) (int, error) {
|
||||
if r.s.transfers == nil {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
@@ -111,7 +129,7 @@ func (r *ReadStream) Read(p []byte) (int, error) {
|
||||
r.s.transfers = nil
|
||||
return 0, r.s.err
|
||||
}
|
||||
n, err := t.wait()
|
||||
n, err := t.wait(ctx)
|
||||
if err != nil {
|
||||
// wait error aborts immediately, all remaining data is invalid.
|
||||
t.free()
|
||||
@@ -183,6 +201,25 @@ type WriteStream struct {
|
||||
// call after Close() has returned.
|
||||
// Write cannot be called concurrently with another Write, Written or Close.
|
||||
func (w *WriteStream) Write(p []byte) (int, error) {
|
||||
return w.WriteContext(context.Background(), p)
|
||||
}
|
||||
|
||||
// WriteContext sends the data to the endpoint. Write returning a nil error doesn't
|
||||
// mean that data was written to the device, only that it was written to the
|
||||
// buffer. Only a call to Close() that returns nil error guarantees that
|
||||
// all transfers have succeeded.
|
||||
// If the slice passed to WriteContext does not align exactly with the transfer
|
||||
// buffer size (as declared in a call to NewStream), the last USB transfer
|
||||
// of this Write will be sent with less data than the full buffer.
|
||||
// After a non-nil error is returned, all subsequent attempts to write will
|
||||
// return io.ErrClosedPipe.
|
||||
// If WriteContext encounters an error when preparing the transfer, the stream
|
||||
// will still try to complete any pending transfers. The total number
|
||||
// of bytes successfully written can be retrieved through a Written()
|
||||
// call after Close() has returned.
|
||||
// WriteContext cannot be called concurrently with another Write, WriteContext,
|
||||
// Written, Close or CloseContext.
|
||||
func (w *WriteStream) WriteContext(ctx context.Context, p []byte) (int, error) {
|
||||
if w.s.transfers == nil || w.s.err != nil {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
@@ -190,7 +227,7 @@ func (w *WriteStream) Write(p []byte) (int, error) {
|
||||
all := len(p)
|
||||
for written < all {
|
||||
t := <-w.s.transfers
|
||||
n, err := t.wait() // unsubmitted transfers will return 0 bytes and no error
|
||||
n, err := t.wait(ctx) // unsubmitted transfers will return 0 bytes and no error
|
||||
w.total += n
|
||||
if err != nil {
|
||||
t.free()
|
||||
@@ -229,12 +266,24 @@ func (w *WriteStream) Write(p []byte) (int, error) {
|
||||
// retrieved using Written().
|
||||
// Close may not be called concurrently with Write, Close or Written.
|
||||
func (w *WriteStream) Close() error {
|
||||
return w.CloseContext(context.Background())
|
||||
}
|
||||
|
||||
// Close signals end of data to write. Close blocks until all transfers
|
||||
// that were sent are finished. The error returned by Close is the first
|
||||
// error encountered during writing the entire stream (if any).
|
||||
// Close returning nil indicates all transfers completed successfully.
|
||||
// After Close, the total number of bytes successfully written can be
|
||||
// retrieved using Written().
|
||||
// Close may not be called concurrently with Write, Close or Written.
|
||||
// CloseContext
|
||||
func (w *WriteStream) CloseContext(ctx context.Context) error {
|
||||
if w.s.transfers == nil {
|
||||
return io.ErrClosedPipe
|
||||
}
|
||||
w.s.noMore()
|
||||
for t := range w.s.transfers {
|
||||
n, err := t.wait()
|
||||
n, err := t.wait(ctx)
|
||||
w.total += n
|
||||
t.free()
|
||||
if err != nil {
|
||||
@@ -248,7 +297,8 @@ func (w *WriteStream) Close() error {
|
||||
}
|
||||
|
||||
// Written returns the number of bytes successfully written by the stream.
|
||||
// Written may be called only after Close() has been called and returned.
|
||||
// Written may be called only after Close() or CloseContext()
|
||||
// has been called and returned.
|
||||
func (w *WriteStream) Written() int {
|
||||
return w.total
|
||||
}
|
||||
|
@@ -16,6 +16,7 @@ package gousb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -57,7 +58,7 @@ func (f *fakeStreamTransfer) submit() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeStreamTransfer) wait() (int, error) {
|
||||
func (f *fakeStreamTransfer) wait(ctx context.Context) (int, error) {
|
||||
if f.released {
|
||||
return 0, errors.New("wait() called on a free()d transfer")
|
||||
}
|
||||
|
@@ -15,8 +15,8 @@
|
||||
package gousb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNewTransfer(t *testing.T) {
|
||||
@@ -34,7 +34,6 @@ func TestNewTransfer(t *testing.T) {
|
||||
tt TransferType
|
||||
maxPkt int
|
||||
buf int
|
||||
timeout time.Duration
|
||||
wantIso int
|
||||
wantLength int
|
||||
wantTimeout int
|
||||
@@ -45,7 +44,6 @@ func TestNewTransfer(t *testing.T) {
|
||||
tt: TransferTypeBulk,
|
||||
maxPkt: 512,
|
||||
buf: 1024,
|
||||
timeout: time.Second,
|
||||
wantLength: 1024,
|
||||
},
|
||||
{
|
||||
@@ -62,7 +60,7 @@ func TestNewTransfer(t *testing.T) {
|
||||
Direction: tc.dir,
|
||||
TransferType: tc.tt,
|
||||
MaxPacketSize: tc.maxPkt,
|
||||
}, tc.buf, tc.timeout)
|
||||
}, tc.buf)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("newUSBTransfer(): %v", err)
|
||||
@@ -92,41 +90,37 @@ func TestTransferProtocol(t *testing.T) {
|
||||
Direction: EndpointDirectionIn,
|
||||
TransferType: TransferTypeBulk,
|
||||
MaxPacketSize: 512,
|
||||
}, 10240, time.Second)
|
||||
}, 10240)
|
||||
if err != nil {
|
||||
t.Fatalf("newUSBTransfer: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
partial := make(chan struct{})
|
||||
go func() {
|
||||
ft := f.waitForSubmitted(nil)
|
||||
ft.length = 5
|
||||
ft.status = TransferCompleted
|
||||
copy(ft.buf, []byte{1, 2, 3, 4, 5})
|
||||
ft.done <- struct{}{}
|
||||
ft.setData([]byte{1, 2, 3, 4, 5})
|
||||
ft.setStatus(TransferCompleted)
|
||||
|
||||
ft = f.waitForSubmitted(nil)
|
||||
ft.length = 99
|
||||
ft.status = TransferCompleted
|
||||
copy(ft.buf, []byte{12, 12, 12, 12, 12})
|
||||
ft.done <- struct{}{}
|
||||
ft.setData(make([]byte, 99))
|
||||
ft.setStatus(TransferCompleted)
|
||||
|
||||
ft = f.waitForSubmitted(nil)
|
||||
ft.length = 123
|
||||
ft.status = TransferCancelled
|
||||
ft.done <- struct{}{}
|
||||
ft.setData(make([]byte, 123))
|
||||
close(partial)
|
||||
}()
|
||||
|
||||
xfers[0].submit()
|
||||
xfers[1].submit()
|
||||
got, err := xfers[0].wait()
|
||||
got, err := xfers[0].wait(context.Background())
|
||||
if err != nil {
|
||||
t.Errorf("xfer#0.wait returned error %v, want nil", err)
|
||||
}
|
||||
if want := 5; got != want {
|
||||
t.Errorf("xfer#0.wait returned %d bytes, want %d", got, want)
|
||||
}
|
||||
got, err = xfers[1].wait()
|
||||
got, err = xfers[1].wait(context.Background())
|
||||
if err != nil {
|
||||
t.Errorf("xfer#0.wait returned error %v, want nil", err)
|
||||
}
|
||||
@@ -135,8 +129,9 @@ func TestTransferProtocol(t *testing.T) {
|
||||
}
|
||||
|
||||
xfers[1].submit()
|
||||
<-partial
|
||||
xfers[1].cancel()
|
||||
got, err = xfers[1].wait()
|
||||
got, err = xfers[1].wait(context.Background())
|
||||
if err == nil {
|
||||
t.Error("xfer#1(resubmitted).wait returned error nil, want non-nil")
|
||||
}
|
||||
@@ -146,7 +141,7 @@ func TestTransferProtocol(t *testing.T) {
|
||||
|
||||
for _, x := range xfers {
|
||||
x.cancel()
|
||||
x.wait()
|
||||
x.wait(context.Background())
|
||||
x.free()
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user