Add context-aware read/write (#57)

Adds ReadContext/WriteContext methods to endpoints and to streams. Update rawread example tool to use ReadContext for implementation of "timeout" flag.
This commit is contained in:
Sebastian Zagrodzki
2018-11-05 15:33:31 +01:00
committed by GitHub
parent 593cfb67e9
commit da849d96b5
11 changed files with 273 additions and 80 deletions

View File

@@ -16,6 +16,7 @@
package gousb package gousb
import ( import (
"context"
"fmt" "fmt"
"strings" "strings"
"time" "time"
@@ -76,8 +77,6 @@ type endpoint struct {
InterfaceSetting InterfaceSetting
Desc EndpointDesc Desc EndpointDesc
Timeout time.Duration
ctx *Context ctx *Context
} }
@@ -86,12 +85,12 @@ func (e *endpoint) String() string {
return e.Desc.String() return e.Desc.String()
} }
func (e *endpoint) transfer(buf []byte) (int, error) { func (e *endpoint) transfer(ctx context.Context, buf []byte) (int, error) {
if len(buf) == 0 { if len(buf) == 0 {
return 0, nil return 0, nil
} }
t, err := newUSBTransfer(e.ctx, e.h, &e.Desc, len(buf), e.Timeout) t, err := newUSBTransfer(e.ctx, e.h, &e.Desc, len(buf))
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -104,7 +103,7 @@ func (e *endpoint) transfer(buf []byte) (int, error) {
return 0, err return 0, err
} }
n, err := t.wait() n, err := t.wait(ctx)
if e.Desc.Direction == EndpointDirectionIn { if e.Desc.Direction == EndpointDirectionIn {
copy(buf, t.data()) copy(buf, t.data())
} }
@@ -122,9 +121,21 @@ type InEndpoint struct {
*endpoint *endpoint
} }
// Read reads data from an IN endpoint. // Read reads data from an IN endpoint. Read returns number of bytes obtained
// from the endpoint. Read may return non-zero length even if
// the returned error is not nil (partial read).
func (e *InEndpoint) Read(buf []byte) (int, error) { func (e *InEndpoint) Read(buf []byte) (int, error) {
return e.transfer(buf) return e.transfer(context.Background(), buf)
}
// ReadContext reads data from an IN endpoint. ReadContext returns number of
// bytes obtained from the endpoint. ReadContext may return non-zero length
// even if the returned error is not nil (partial read).
// The passed context can be used to control the cancellation of the read. If
// the context is cancelled, ReadContext will cancel the underlying transfers,
// resulting in TransferCancelled error.
func (e *InEndpoint) ReadContext(ctx context.Context, buf []byte) (int, error) {
return e.transfer(ctx, buf)
} }
// OutEndpoint represents an OUT endpoint open for transfer. // OutEndpoint represents an OUT endpoint open for transfer.
@@ -132,7 +143,19 @@ type OutEndpoint struct {
*endpoint *endpoint
} }
// Write writes data to an OUT endpoint. // Write writes data to an OUT endpoint. Write returns number of bytes comitted
// to the endpoint. Write may return non-zero length even if the returned error
// is not nil (partial write).
func (e *OutEndpoint) Write(buf []byte) (int, error) { func (e *OutEndpoint) Write(buf []byte) (int, error) {
return e.transfer(buf) return e.transfer(context.Background(), buf)
}
// WriteContext writes data to an OUT endpoint. WriteContext returns number of
// bytes comitted to the endpoint. WriteContext may return non-zero length even
// if the returned error is not nil (partial write).
// The passed context can be used to control the cancellation of the write. If
// the context is cancelled, WriteContext will cancel the underlying transfers,
// resulting in TransferCancelled error.
func (e *OutEndpoint) WriteContext(ctx context.Context, buf []byte) (int, error) {
return e.transfer(ctx, buf)
} }

View File

@@ -17,7 +17,7 @@ package gousb
func (e *endpoint) newStream(size, count int) (*stream, error) { func (e *endpoint) newStream(size, count int) (*stream, error) {
var ts []transferIntf var ts []transferIntf
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
t, err := newUSBTransfer(e.ctx, e.h, &e.Desc, size, e.Timeout) t, err := newUSBTransfer(e.ctx, e.h, &e.Desc, size)
if err != nil { if err != nil {
for _, t := range ts { for _, t := range ts {
t.free() t.free()
@@ -29,8 +29,8 @@ func (e *endpoint) newStream(size, count int) (*stream, error) {
return newStream(ts), nil return newStream(ts), nil
} }
// NewStream prepares a new read stream that will keep reading data from the // NewStream prepares a new read stream that will keep reading data from
// endpoint until closed. // the endpoint until closed or until an error or timeout is encountered.
// Size defines a buffer size for a single read transaction and count // Size defines a buffer size for a single read transaction and count
// defines how many transactions should be active at any time. // defines how many transactions should be active at any time.
// By keeping multiple transfers active at the same time, a Stream reduces // By keeping multiple transfers active at the same time, a Stream reduces
@@ -44,11 +44,11 @@ func (e *InEndpoint) NewStream(size, count int) (*ReadStream, error) {
return &ReadStream{s: s}, nil return &ReadStream{s: s}, nil
} }
// NewStream prepares a new write stream that will write data in the background. // NewStream prepares a new write stream that will write data in the
// Size defines a buffer size for a single write transaction and count // background. Size defines a buffer size for a single write transaction and
// defines how many transactions may be active at any time. // count defines how many transactions may be active at any time. By buffering
// By buffering the writes, a Stream reduces the latency between subsequent // the writes, a Stream reduces the latency between subsequent transfers and
// transfers and increases writing throughput. // increases writing throughput.
func (e *OutEndpoint) NewStream(size, count int) (*WriteStream, error) { func (e *OutEndpoint) NewStream(size, count int) (*WriteStream, error) {
s, err := e.newStream(size, count) s, err := e.newStream(size, count)
if err != nil { if err != nil {

View File

@@ -36,13 +36,11 @@ func TestEndpointReadStream(t *testing.T) {
return return
} }
if num < goodTransfers { if num < goodTransfers {
xfr.status = TransferCompleted xfr.setData(make([]byte, len(xfr.buf)))
xfr.length = len(xfr.buf) xfr.setStatus(TransferCompleted)
} else { } else {
xfr.status = TransferError xfr.setStatus(TransferError)
xfr.length = 0
} }
xfr.done <- struct{}{}
num++ num++
} }
}() }()
@@ -106,9 +104,8 @@ func TestEndpointWriteStream(t *testing.T) {
if xfr == nil { if xfr == nil {
return return
} }
xfr.length = len(xfr.buf) xfr.setData(make([]byte, len(xfr.buf)))
xfr.status = TransferCompleted xfr.setStatus(TransferCompleted)
xfr.done <- struct{}{}
num++ num++
total += xfr.length total += xfr.length
} }

View File

@@ -15,6 +15,7 @@
package gousb package gousb
import ( import (
"context"
"testing" "testing"
"time" "time"
) )
@@ -101,12 +102,11 @@ func TestEndpoint(t *testing.T) {
if tc.wantSubmit { if tc.wantSubmit {
go func() { go func() {
fakeT := lib.waitForSubmitted(nil) fakeT := lib.waitForSubmitted(nil)
fakeT.length = tc.ret fakeT.setData(make([]byte, tc.ret))
fakeT.status = tc.status fakeT.setStatus(tc.status)
close(fakeT.done)
}() }()
} }
got, err := ep.transfer(tc.buf) got, err := ep.transfer(context.TODO(), tc.buf)
if (err != nil) != tc.wantErr { if (err != nil) != tc.wantErr {
t.Errorf("%s, %s: ep.transfer(...): got err: %v, err != nil is %v, want %v", epData.ei, tc.desc, err, err != nil, tc.wantErr) t.Errorf("%s, %s: ep.transfer(...): got err: %v, err != nil is %v, want %v", epData.ei, tc.desc, err, err != nil, tc.wantErr)
continue continue
@@ -209,9 +209,8 @@ func TestEndpointInOut(t *testing.T) {
dataTransferred := 100 dataTransferred := 100
go func() { go func() {
fakeT := lib.waitForSubmitted(nil) fakeT := lib.waitForSubmitted(nil)
fakeT.length = dataTransferred fakeT.setData(make([]byte, dataTransferred))
fakeT.status = TransferCompleted fakeT.setStatus(TransferCompleted)
close(fakeT.done)
}() }()
buf := make([]byte, 512) buf := make([]byte, 512)
got, err := iep.Read(buf) got, err := iep.Read(buf)
@@ -233,9 +232,8 @@ func TestEndpointInOut(t *testing.T) {
} }
go func() { go func() {
fakeT := lib.waitForSubmitted(nil) fakeT := lib.waitForSubmitted(nil)
fakeT.length = dataTransferred fakeT.setData(make([]byte, dataTransferred))
fakeT.status = TransferCompleted fakeT.setStatus(TransferCompleted)
close(fakeT.done)
}() }()
got, err = oep.Write(buf) got, err = oep.Write(buf)
if err != nil { if err != nil {
@@ -290,3 +288,71 @@ func TestSameEndpointNumberInOut(t *testing.T) {
t.Errorf("%s.OutEndpoint(1): got error %v, want nil", intf, err) t.Errorf("%s.OutEndpoint(1): got error %v, want nil", intf, err)
} }
} }
func TestReadContext(t *testing.T) {
t.Parallel()
lib := newFakeLibusb()
ctx := newContextWithImpl(lib)
defer func() {
if err := ctx.Close(); err != nil {
t.Errorf("Context.Close(): %v", err)
}
}()
d, err := ctx.OpenDeviceWithVIDPID(0x9999, 0x0001)
if err != nil {
t.Fatalf("OpenDeviceWithVIDPID(0x9999, 0x0001): got error %v, want nil", err)
}
defer func() {
if err := d.Close(); err != nil {
t.Errorf("%s.Close(): %v", d, err)
}
}()
cfg, err := d.Config(1)
if err != nil {
t.Fatalf("%s.Config(1): %v", d, err)
}
defer func() {
if err := cfg.Close(); err != nil {
t.Errorf("%s.Close(): %v", cfg, err)
}
}()
intf, err := cfg.Interface(0, 0)
if err != nil {
t.Fatalf("%s.Interface(0, 0): %v", cfg, err)
}
defer intf.Close()
iep, err := intf.InEndpoint(2)
if err != nil {
t.Fatalf("%s.InEndpoint(2): got error %v, want nil", intf, err)
}
buf := make([]byte, 512)
rCtx, done := context.WithCancel(context.Background())
go func() {
ft := lib.waitForSubmitted(nil)
ft.setData([]byte{1, 2, 3, 4, 5})
done()
}()
if got, err := iep.ReadContext(rCtx, buf); err != TransferCancelled {
t.Errorf("%s.Read: got error %v, want %v", iep, err, TransferCancelled)
} else if want := 5; got != want {
t.Errorf("%s.Read: got %d bytes, want %d (partial read success)", iep, got, want)
}
oep, err := intf.OutEndpoint(1)
if err != nil {
t.Fatalf("%s.OutEndpoint(1): got error %v, want nil", intf, err)
}
wCtx, done := context.WithCancel(context.Background())
go func() {
ft := lib.waitForSubmitted(nil)
ft.setLength(5)
done()
}()
if got, err := oep.WriteContext(wCtx, buf); err != TransferCancelled {
t.Errorf("%s.Write: got error %v, want %v", oep, err, TransferCancelled)
} else if want := 5; got != want {
t.Errorf("%s.Write: got %d bytes, want %d (partial write success)", oep, got, want)
}
}

View File

@@ -22,10 +22,17 @@ import (
) )
type fakeTransfer struct { type fakeTransfer struct {
// done is the channel that needs to be closed when the transfer has finished. // done is the channel that needs to receive a signal when the transfer has
// finished.
// This is different from finished below - done is provided by the caller
// and is used to signal the caller.
done chan struct{} done chan struct{}
// mu protects transfer data and status.
mu sync.Mutex
// buf is the slice for reading/writing data between the submit() and wait() returning. // buf is the slice for reading/writing data between the submit() and wait() returning.
buf []byte buf []byte
// finished is true after the transfer is no longer in flight
finished bool
// status will be returned by wait() on this transfer // status will be returned by wait() on this transfer
status TransferStatus status TransferStatus
// length is the number of bytes used from the buffer (write) or available // length is the number of bytes used from the buffer (write) or available
@@ -33,6 +40,36 @@ type fakeTransfer struct {
length int length int
} }
func (t *fakeTransfer) setData(d []byte) {
t.mu.Lock()
defer t.mu.Unlock()
if t.finished {
return
}
copy(t.buf, d)
t.length = len(d)
}
func (t *fakeTransfer) setLength(n int) {
t.mu.Lock()
defer t.mu.Unlock()
if t.finished {
return
}
t.length = n
}
func (t *fakeTransfer) setStatus(st TransferStatus) {
t.mu.Lock()
defer t.mu.Unlock()
if t.finished {
return
}
t.status = st
t.finished = true
t.done <- struct{}{}
}
// fakeLibusb implements a fake libusb stack that pretends to have a number of // fakeLibusb implements a fake libusb stack that pretends to have a number of
// devices connected to it (see fakeDevices variable for a list of devices). // devices connected to it (see fakeDevices variable for a list of devices).
// fakeLibusb is expected to implement all the functions related to device // fakeLibusb is expected to implement all the functions related to device
@@ -162,18 +199,25 @@ func (f *fakeLibusb) setAlt(d *libusbDevHandle, intf, alt uint8) error {
return nil return nil
} }
func (f *fakeLibusb) alloc(_ *libusbDevHandle, _ *EndpointDesc, _ time.Duration, _ int, bufLen int, done chan struct{}) (*libusbTransfer, error) { func (f *fakeLibusb) alloc(_ *libusbDevHandle, _ *EndpointDesc, _ int, bufLen int, done chan struct{}) (*libusbTransfer, error) {
f.mu.Lock() f.mu.Lock()
defer f.mu.Unlock() defer f.mu.Unlock()
t := newFakeTransferPointer() t := newFakeTransferPointer()
f.ts[t] = &fakeTransfer{buf: make([]byte, bufLen), done: done} f.ts[t] = &fakeTransfer{buf: make([]byte, bufLen), done: done}
return t, nil return t, nil
} }
func (f *fakeLibusb) cancel(t *libusbTransfer) error { return errors.New("not implemented") } func (f *fakeLibusb) cancel(t *libusbTransfer) error {
f.mu.Lock()
ft := f.ts[t]
f.mu.Unlock()
ft.setStatus(TransferCancelled)
return nil
}
func (f *fakeLibusb) submit(t *libusbTransfer) error { func (f *fakeLibusb) submit(t *libusbTransfer) error {
f.mu.Lock() f.mu.Lock()
ft := f.ts[t] ft := f.ts[t]
f.mu.Unlock() f.mu.Unlock()
ft.finished = false
f.submitted <- ft f.submitted <- ft
return nil return nil
} }

View File

@@ -159,7 +159,7 @@ type libusbIntf interface {
setAlt(*libusbDevHandle, uint8, uint8) error setAlt(*libusbDevHandle, uint8, uint8) error
// transfer // transfer
alloc(*libusbDevHandle, *EndpointDesc, time.Duration, int, int, chan struct{}) (*libusbTransfer, error) alloc(*libusbDevHandle, *EndpointDesc, int, int, chan struct{}) (*libusbTransfer, error)
cancel(*libusbTransfer) error cancel(*libusbTransfer) error
submit(*libusbTransfer) error submit(*libusbTransfer) error
buffer(*libusbTransfer) []byte buffer(*libusbTransfer) []byte
@@ -422,7 +422,7 @@ func (libusbImpl) setAlt(d *libusbDevHandle, iface, setup uint8) error {
return fromErrNo(C.libusb_set_interface_alt_setting((*C.libusb_device_handle)(d), C.int(iface), C.int(setup))) return fromErrNo(C.libusb_set_interface_alt_setting((*C.libusb_device_handle)(d), C.int(iface), C.int(setup)))
} }
func (libusbImpl) alloc(d *libusbDevHandle, ep *EndpointDesc, timeout time.Duration, isoPackets int, bufLen int, done chan struct{}) (*libusbTransfer, error) { func (libusbImpl) alloc(d *libusbDevHandle, ep *EndpointDesc, isoPackets int, bufLen int, done chan struct{}) (*libusbTransfer, error) {
xfer := C.gousb_alloc_transfer_and_buffer(C.int(bufLen), C.int(isoPackets)) xfer := C.gousb_alloc_transfer_and_buffer(C.int(bufLen), C.int(isoPackets))
if xfer == nil { if xfer == nil {
return nil, fmt.Errorf("gousb_alloc_transfer_and_buffer(%d, %d) failed", bufLen, isoPackets) return nil, fmt.Errorf("gousb_alloc_transfer_and_buffer(%d, %d) failed", bufLen, isoPackets)
@@ -432,7 +432,6 @@ func (libusbImpl) alloc(d *libusbDevHandle, ep *EndpointDesc, timeout time.Durat
} }
xfer.dev_handle = (*C.libusb_device_handle)(d) xfer.dev_handle = (*C.libusb_device_handle)(d)
xfer.endpoint = C.uchar(ep.Address) xfer.endpoint = C.uchar(ep.Address)
xfer.timeout = C.uint(timeout / time.Millisecond)
xfer._type = C.uchar(ep.TransferType) xfer._type = C.uchar(ep.TransferType)
xfer.num_iso_packets = C.int(isoPackets) xfer.num_iso_packets = C.int(isoPackets)
ret := (*libusbTransfer)(xfer) ret := (*libusbTransfer)(xfer)

View File

@@ -17,9 +17,9 @@
package main package main
import ( import (
"context"
"flag" "flag"
"fmt" "fmt"
"io"
"log" "log"
"os" "os"
"strconv" "strconv"
@@ -39,6 +39,7 @@ var (
size = flag.Int("read_size", 1024, "Number of bytes of data to read in a single transaction.") size = flag.Int("read_size", 1024, "Number of bytes of data to read in a single transaction.")
bufSize = flag.Int("buffer_size", 0, "Number of buffer transfers, for data prefetching.") bufSize = flag.Int("buffer_size", 0, "Number of buffer transfers, for data prefetching.")
num = flag.Int("read_num", 0, "Number of read transactions to perform. 0 means infinite.") num = flag.Int("read_num", 0, "Number of read transactions to perform. 0 means infinite.")
timeout = flag.Duration("timeout", 0, "Timeout for the command. 0 means infinite.")
) )
func parseVIDPID(vidPid string) (gousb.ID, gousb.ID, error) { func parseVIDPID(vidPid string) (gousb.ID, gousb.ID, error) {
@@ -73,6 +74,10 @@ func parseBusAddr(busAddr string) (int, int, error) {
return int(bus), int(addr), nil return int(bus), int(addr), nil
} }
type contextReader interface {
ReadContext(context.Context, []byte) (int, error)
}
func main() { func main() {
flag.Parse() flag.Parse()
@@ -160,7 +165,7 @@ func main() {
log.Fatalf("dev.InEndpoint(): %s", err) log.Fatalf("dev.InEndpoint(): %s", err)
} }
log.Printf("Found endpoint: %s", ep) log.Printf("Found endpoint: %s", ep)
var rdr io.Reader = ep var rdr contextReader = ep
if *bufSize > 1 { if *bufSize > 1 {
log.Print("Creating buffer...") log.Print("Creating buffer...")
s, err := ep.NewStream(*size, *bufSize) s, err := ep.NewStream(*size, *bufSize)
@@ -170,11 +175,17 @@ func main() {
defer s.Close() defer s.Close()
rdr = s rdr = s
} }
log.Print("Reading...")
opCtx := context.Background()
if *timeout > 0 {
var done func()
opCtx, done = context.WithTimeout(opCtx, *timeout)
defer done()
}
buf := make([]byte, *size) buf := make([]byte, *size)
log.Print("Reading...")
for i := 0; *num == 0 || i < *num; i++ { for i := 0; *num == 0 || i < *num; i++ {
num, err := rdr.Read(buf) num, err := rdr.ReadContext(opCtx, buf)
if err != nil { if err != nil {
log.Fatalf("Reading from device failed: %v", err) log.Fatalf("Reading from device failed: %v", err)
} }

View File

@@ -15,10 +15,10 @@
package gousb package gousb
import ( import (
"context"
"errors" "errors"
"runtime" "runtime"
"sync" "sync"
"time"
) )
type usbTransfer struct { type usbTransfer struct {
@@ -60,13 +60,20 @@ func (t *usbTransfer) submit() error {
// via t.buf. The number returned by wait indicates how many bytes // via t.buf. The number returned by wait indicates how many bytes
// of the buffer were read or written by libusb, and it can be // of the buffer were read or written by libusb, and it can be
// smaller than the length of t.buf. // smaller than the length of t.buf.
func (t *usbTransfer) wait() (n int, err error) { func (t *usbTransfer) wait(ctx context.Context) (n int, err error) {
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock() defer t.mu.Unlock()
if !t.submitted { if !t.submitted {
return 0, nil return 0, nil
} }
<-t.done select {
case <-ctx.Done():
t.ctx.libusb.cancel(t.xfer)
// after the transfer is cancelled, it will run a callback
// that triggers the activation of t.done.
<-t.done
case <-t.done:
}
t.submitted = false t.submitted = false
n, status := t.ctx.libusb.data(t.xfer) n, status := t.ctx.libusb.data(t.xfer)
if status != TransferCompleted { if status != TransferCompleted {
@@ -117,7 +124,7 @@ func (t *usbTransfer) data() []byte {
// newUSBTransfer allocates a new transfer structure and a new buffer for // newUSBTransfer allocates a new transfer structure and a new buffer for
// communication with a given device/endpoint. // communication with a given device/endpoint.
func newUSBTransfer(ctx *Context, dev *libusbDevHandle, ei *EndpointDesc, bufLen int, timeout time.Duration) (*usbTransfer, error) { func newUSBTransfer(ctx *Context, dev *libusbDevHandle, ei *EndpointDesc, bufLen int) (*usbTransfer, error) {
var isoPackets, isoPktSize int var isoPackets, isoPktSize int
if ei.TransferType == TransferTypeIsochronous { if ei.TransferType == TransferTypeIsochronous {
isoPktSize = ei.MaxPacketSize isoPktSize = ei.MaxPacketSize
@@ -129,7 +136,7 @@ func newUSBTransfer(ctx *Context, dev *libusbDevHandle, ei *EndpointDesc, bufLen
} }
done := make(chan struct{}, 1) done := make(chan struct{}, 1)
xfer, err := ctx.libusb.alloc(dev, ei, timeout, isoPackets, bufLen, done) xfer, err := ctx.libusb.alloc(dev, ei, isoPackets, bufLen, done)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -146,7 +153,7 @@ func newUSBTransfer(ctx *Context, dev *libusbDevHandle, ei *EndpointDesc, bufLen
} }
runtime.SetFinalizer(t, func(t *usbTransfer) { runtime.SetFinalizer(t, func(t *usbTransfer) {
t.cancel() t.cancel()
t.wait() t.wait(context.Background())
t.free() t.free()
}) })
return t, nil return t, nil

View File

@@ -14,12 +14,15 @@
package gousb package gousb
import "io" import (
"context"
"io"
)
type transferIntf interface { type transferIntf interface {
submit() error submit() error
cancel() error cancel() error
wait() (int, error) wait(context.Context) (int, error)
free() error free() error
data() []byte data() []byte
} }
@@ -68,7 +71,7 @@ func (s *stream) flushRemaining() {
s.noMore() s.noMore()
for t := range s.transfers { for t := range s.transfers {
t.cancel() t.cancel()
t.wait() t.wait(context.Background())
t.free() t.free()
} }
} }
@@ -99,8 +102,23 @@ type ReadStream struct {
// might be smaller than the length of p. // might be smaller than the length of p.
// After a non-nil error is returned, all subsequent attempts to read will // After a non-nil error is returned, all subsequent attempts to read will
// return io.ErrClosedPipe. // return io.ErrClosedPipe.
// Read cannot be called concurrently with other Read or Close. // Read cannot be called concurrently with other Read, ReadContext
// or Close.
func (r *ReadStream) Read(p []byte) (int, error) { func (r *ReadStream) Read(p []byte) (int, error) {
return r.ReadContext(context.Background(), p)
}
// ReadContext reads data from the transfer stream.
// The data will come from at most a single transfer, so the returned number
// might be smaller than the length of p.
// After a non-nil error is returned, all subsequent attempts to read will
// return io.ErrClosedPipe.
// ReadContext cannot be called concurrently with other Read, ReadContext
// or Close.
// The context passed controls the cancellation of this particular read
// operation within the stream. The semantics is identical to
// Endpoint.ReadContext.
func (r *ReadStream) ReadContext(ctx context.Context, p []byte) (int, error) {
if r.s.transfers == nil { if r.s.transfers == nil {
return 0, io.ErrClosedPipe return 0, io.ErrClosedPipe
} }
@@ -111,7 +129,7 @@ func (r *ReadStream) Read(p []byte) (int, error) {
r.s.transfers = nil r.s.transfers = nil
return 0, r.s.err return 0, r.s.err
} }
n, err := t.wait() n, err := t.wait(ctx)
if err != nil { if err != nil {
// wait error aborts immediately, all remaining data is invalid. // wait error aborts immediately, all remaining data is invalid.
t.free() t.free()
@@ -183,6 +201,25 @@ type WriteStream struct {
// call after Close() has returned. // call after Close() has returned.
// Write cannot be called concurrently with another Write, Written or Close. // Write cannot be called concurrently with another Write, Written or Close.
func (w *WriteStream) Write(p []byte) (int, error) { func (w *WriteStream) Write(p []byte) (int, error) {
return w.WriteContext(context.Background(), p)
}
// WriteContext sends the data to the endpoint. Write returning a nil error doesn't
// mean that data was written to the device, only that it was written to the
// buffer. Only a call to Close() that returns nil error guarantees that
// all transfers have succeeded.
// If the slice passed to WriteContext does not align exactly with the transfer
// buffer size (as declared in a call to NewStream), the last USB transfer
// of this Write will be sent with less data than the full buffer.
// After a non-nil error is returned, all subsequent attempts to write will
// return io.ErrClosedPipe.
// If WriteContext encounters an error when preparing the transfer, the stream
// will still try to complete any pending transfers. The total number
// of bytes successfully written can be retrieved through a Written()
// call after Close() has returned.
// WriteContext cannot be called concurrently with another Write, WriteContext,
// Written, Close or CloseContext.
func (w *WriteStream) WriteContext(ctx context.Context, p []byte) (int, error) {
if w.s.transfers == nil || w.s.err != nil { if w.s.transfers == nil || w.s.err != nil {
return 0, io.ErrClosedPipe return 0, io.ErrClosedPipe
} }
@@ -190,7 +227,7 @@ func (w *WriteStream) Write(p []byte) (int, error) {
all := len(p) all := len(p)
for written < all { for written < all {
t := <-w.s.transfers t := <-w.s.transfers
n, err := t.wait() // unsubmitted transfers will return 0 bytes and no error n, err := t.wait(ctx) // unsubmitted transfers will return 0 bytes and no error
w.total += n w.total += n
if err != nil { if err != nil {
t.free() t.free()
@@ -229,12 +266,24 @@ func (w *WriteStream) Write(p []byte) (int, error) {
// retrieved using Written(). // retrieved using Written().
// Close may not be called concurrently with Write, Close or Written. // Close may not be called concurrently with Write, Close or Written.
func (w *WriteStream) Close() error { func (w *WriteStream) Close() error {
return w.CloseContext(context.Background())
}
// Close signals end of data to write. Close blocks until all transfers
// that were sent are finished. The error returned by Close is the first
// error encountered during writing the entire stream (if any).
// Close returning nil indicates all transfers completed successfully.
// After Close, the total number of bytes successfully written can be
// retrieved using Written().
// Close may not be called concurrently with Write, Close or Written.
// CloseContext
func (w *WriteStream) CloseContext(ctx context.Context) error {
if w.s.transfers == nil { if w.s.transfers == nil {
return io.ErrClosedPipe return io.ErrClosedPipe
} }
w.s.noMore() w.s.noMore()
for t := range w.s.transfers { for t := range w.s.transfers {
n, err := t.wait() n, err := t.wait(ctx)
w.total += n w.total += n
t.free() t.free()
if err != nil { if err != nil {
@@ -248,7 +297,8 @@ func (w *WriteStream) Close() error {
} }
// Written returns the number of bytes successfully written by the stream. // Written returns the number of bytes successfully written by the stream.
// Written may be called only after Close() has been called and returned. // Written may be called only after Close() or CloseContext()
// has been called and returned.
func (w *WriteStream) Written() int { func (w *WriteStream) Written() int {
return w.total return w.total
} }

View File

@@ -16,6 +16,7 @@ package gousb
import ( import (
"bytes" "bytes"
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@@ -57,7 +58,7 @@ func (f *fakeStreamTransfer) submit() error {
return nil return nil
} }
func (f *fakeStreamTransfer) wait() (int, error) { func (f *fakeStreamTransfer) wait(ctx context.Context) (int, error) {
if f.released { if f.released {
return 0, errors.New("wait() called on a free()d transfer") return 0, errors.New("wait() called on a free()d transfer")
} }

View File

@@ -15,8 +15,8 @@
package gousb package gousb
import ( import (
"context"
"testing" "testing"
"time"
) )
func TestNewTransfer(t *testing.T) { func TestNewTransfer(t *testing.T) {
@@ -34,7 +34,6 @@ func TestNewTransfer(t *testing.T) {
tt TransferType tt TransferType
maxPkt int maxPkt int
buf int buf int
timeout time.Duration
wantIso int wantIso int
wantLength int wantLength int
wantTimeout int wantTimeout int
@@ -45,7 +44,6 @@ func TestNewTransfer(t *testing.T) {
tt: TransferTypeBulk, tt: TransferTypeBulk,
maxPkt: 512, maxPkt: 512,
buf: 1024, buf: 1024,
timeout: time.Second,
wantLength: 1024, wantLength: 1024,
}, },
{ {
@@ -62,7 +60,7 @@ func TestNewTransfer(t *testing.T) {
Direction: tc.dir, Direction: tc.dir,
TransferType: tc.tt, TransferType: tc.tt,
MaxPacketSize: tc.maxPkt, MaxPacketSize: tc.maxPkt,
}, tc.buf, tc.timeout) }, tc.buf)
if err != nil { if err != nil {
t.Fatalf("newUSBTransfer(): %v", err) t.Fatalf("newUSBTransfer(): %v", err)
@@ -92,41 +90,37 @@ func TestTransferProtocol(t *testing.T) {
Direction: EndpointDirectionIn, Direction: EndpointDirectionIn,
TransferType: TransferTypeBulk, TransferType: TransferTypeBulk,
MaxPacketSize: 512, MaxPacketSize: 512,
}, 10240, time.Second) }, 10240)
if err != nil { if err != nil {
t.Fatalf("newUSBTransfer: %v", err) t.Fatalf("newUSBTransfer: %v", err)
} }
} }
partial := make(chan struct{})
go func() { go func() {
ft := f.waitForSubmitted(nil) ft := f.waitForSubmitted(nil)
ft.length = 5 ft.setData([]byte{1, 2, 3, 4, 5})
ft.status = TransferCompleted ft.setStatus(TransferCompleted)
copy(ft.buf, []byte{1, 2, 3, 4, 5})
ft.done <- struct{}{}
ft = f.waitForSubmitted(nil) ft = f.waitForSubmitted(nil)
ft.length = 99 ft.setData(make([]byte, 99))
ft.status = TransferCompleted ft.setStatus(TransferCompleted)
copy(ft.buf, []byte{12, 12, 12, 12, 12})
ft.done <- struct{}{}
ft = f.waitForSubmitted(nil) ft = f.waitForSubmitted(nil)
ft.length = 123 ft.setData(make([]byte, 123))
ft.status = TransferCancelled close(partial)
ft.done <- struct{}{}
}() }()
xfers[0].submit() xfers[0].submit()
xfers[1].submit() xfers[1].submit()
got, err := xfers[0].wait() got, err := xfers[0].wait(context.Background())
if err != nil { if err != nil {
t.Errorf("xfer#0.wait returned error %v, want nil", err) t.Errorf("xfer#0.wait returned error %v, want nil", err)
} }
if want := 5; got != want { if want := 5; got != want {
t.Errorf("xfer#0.wait returned %d bytes, want %d", got, want) t.Errorf("xfer#0.wait returned %d bytes, want %d", got, want)
} }
got, err = xfers[1].wait() got, err = xfers[1].wait(context.Background())
if err != nil { if err != nil {
t.Errorf("xfer#0.wait returned error %v, want nil", err) t.Errorf("xfer#0.wait returned error %v, want nil", err)
} }
@@ -135,8 +129,9 @@ func TestTransferProtocol(t *testing.T) {
} }
xfers[1].submit() xfers[1].submit()
<-partial
xfers[1].cancel() xfers[1].cancel()
got, err = xfers[1].wait() got, err = xfers[1].wait(context.Background())
if err == nil { if err == nil {
t.Error("xfer#1(resubmitted).wait returned error nil, want non-nil") t.Error("xfer#1(resubmitted).wait returned error nil, want non-nil")
} }
@@ -146,7 +141,7 @@ func TestTransferProtocol(t *testing.T) {
for _, x := range xfers { for _, x := range xfers {
x.cancel() x.cancel()
x.wait() x.wait(context.Background())
x.free() x.free()
} }
} }