From 57b10f0dd3af969b7b95ad97f57c0758147ca62d Mon Sep 17 00:00:00 2001 From: Sebastian Zagrodzki Date: Mon, 10 Apr 2017 01:00:53 +0200 Subject: [PATCH] Remove OpenEndpoint, add InEndpoint for Read transfers. --- rawread/main.go | 4 +- usb/device.go | 16 ++++-- usb/device_test.go | 6 +-- usb/endpoint.go | 56 ++++++++++--------- usb/endpoint_test.go | 126 ++++++++++++++++--------------------------- 5 files changed, 90 insertions(+), 118 deletions(-) diff --git a/rawread/main.go b/rawread/main.go index 40afdaf..ad2daec 100644 --- a/rawread/main.go +++ b/rawread/main.go @@ -33,7 +33,7 @@ var ( config = flag.Uint("config", 1, "Configuration number to use with the device.") iface = flag.Uint("interface", 0, "Interface to use on the device.") setup = flag.Uint("setup", 0, "Alternate setting to use on the interface.") - endpoint = flag.Uint("endpoint", 1, "Endpoint number to which to connect.") + endpoint = flag.Uint("endpoint", 1, "Endpoint number to which to connect (without the leading 0x8).") debug = flag.Int("debug", 3, "Debug level for libusb.") size = flag.Uint("read_size", 1024, "Number of bytes of data to read in a single transaction.") num = flag.Uint("read_num", 0, "Number of read transactions to perform. 0 means infinite.") @@ -139,7 +139,7 @@ func main() { dev := devs[0] log.Printf("Connecting to endpoint %d...", *endpoint) - ep, err := dev.OpenEndpoint(uint8(*config), uint8(*iface), uint8(*setup), uint8(*endpoint)) + ep, err := dev.InEndpoint(uint8(*config), uint8(*iface), uint8(*setup), uint8(*endpoint)) if err != nil { log.Fatalf("open: %s", err) } diff --git a/usb/device.go b/usb/device.go index f59ec66..aaa07bd 100644 --- a/usb/device.go +++ b/usb/device.go @@ -100,8 +100,7 @@ func (d *Device) Close() error { return nil } -// OpenEndpoint prepares a device endpoint for data transfer. -func (d *Device) OpenEndpoint(epAddr, cfgNum, ifNum, setNum uint8) (*Endpoint, error) { +func (d *Device) openEndpoint(cfgNum, ifNum, setNum, epAddr uint8) (*endpoint, error) { var cfg *ConfigInfo for _, c := range d.Configs { if c.Config == cfgNum { @@ -150,7 +149,7 @@ func (d *Device) OpenEndpoint(epAddr, cfgNum, ifNum, setNum uint8) (*Endpoint, e return nil, fmt.Errorf("usb: didn't find endpoint address 0x%02x", epAddr) } - end := newEndpoint(d.handle, *ifs, *ep, d.ReadTimeout, d.WriteTimeout) + end := newEndpoint(d.handle, *ifs, *ep) // Set the configuration activeConf, err := libusb.getConfig(d.handle) @@ -183,6 +182,17 @@ func (d *Device) OpenEndpoint(epAddr, cfgNum, ifNum, setNum uint8) (*Endpoint, e return end, nil } +func (d *Device) InEndpoint(cfgNum, ifNum, setNum, epNum uint8) (*InEndpoint, error) { + ep, err := d.openEndpoint(cfgNum, ifNum, setNum, endpointAddr(epNum, EndpointDirectionIn)) + if err != nil { + return nil, err + } + return &InEndpoint{ + endpoint: ep, + timeout: d.ReadTimeout, + }, nil +} + // GetStringDescriptor returns a device string descriptor with the given index // number. The first supported language is always used and the returned // descriptor string is converted to ASCII (non-ASCII characters are replaced diff --git a/usb/device_test.go b/usb/device_test.go index 9b0d4eb..9224d0c 100644 --- a/usb/device_test.go +++ b/usb/device_test.go @@ -33,11 +33,11 @@ func TestOpenEndpoint(t *testing.T) { if err != nil { t.Fatalf("OpenDeviceWithVidPid(0x8888, 0x0002): got error %v, want nil", err) } - got, err := dev.OpenEndpoint(0x86, 1, 1, 2) + got, err := dev.InEndpoint(1, 1, 2, 6) if err != nil { - t.Fatalf("OpenEndpoint(cfg=1, if=1, alt=2, ep=0x86): got error %v, want nil", err) + t.Fatalf("InEndpoint(cfg=1, if=1, alt=2, ep=6IN): got error %v, want nil", err) } if want := fakeDevices[1].Configs[0].Interfaces[1].AltSettings[2].Endpoints[1]; !reflect.DeepEqual(got.Info, want) { - t.Errorf("OpenEndpoint(cfg=1, if=1, alt=2, ep=0x86): got %+v, want %+v", got, want) + t.Errorf("InEndpoint(cfg=1, if=1, alt=2, ep=6IN): got %+v, want %+v", got, want) } } diff --git a/usb/endpoint.go b/usb/endpoint.go index 269bccd..7ccd141 100644 --- a/usb/endpoint.go +++ b/usb/endpoint.go @@ -68,41 +68,19 @@ func (e EndpointInfo) String() string { return strings.Join(ret, " ") } -// Endpoint identifies a USB endpoint opened for transfer. -type Endpoint struct { +type endpoint struct { h *libusbDevHandle InterfaceSetting Info EndpointInfo - - readTimeout time.Duration - writeTimeout time.Duration } // String returns a human-readable description of the endpoint. -func (e *Endpoint) String() string { +func (e *endpoint) String() string { return e.Info.String() } -// Read reads data from an IN endpoint. -func (e *Endpoint) Read(buf []byte) (int, error) { - if e.Info.Direction != EndpointDirectionIn { - return 0, fmt.Errorf("usb: read: not an IN endpoint") - } - - return e.transfer(buf, e.readTimeout) -} - -// Write writes data to an OUT endpoint. -func (e *Endpoint) Write(buf []byte) (int, error) { - if e.Info.Direction != EndpointDirectionOut { - return 0, fmt.Errorf("usb: write: not an OUT endpoint") - } - - return e.transfer(buf, e.writeTimeout) -} - -func (e *Endpoint) transfer(buf []byte, timeout time.Duration) (int, error) { +func (e *endpoint) transfer(buf []byte, timeout time.Duration) (int, error) { if len(buf) == 0 { return 0, nil } @@ -124,12 +102,32 @@ func (e *Endpoint) transfer(buf []byte, timeout time.Duration) (int, error) { return n, nil } -func newEndpoint(h *libusbDevHandle, s InterfaceSetting, e EndpointInfo, rt, wt time.Duration) *Endpoint { - return &Endpoint{ +func newEndpoint(h *libusbDevHandle, s InterfaceSetting, e EndpointInfo) *endpoint { + return &endpoint{ InterfaceSetting: s, Info: e, h: h, - readTimeout: rt, - writeTimeout: wt, } } + +// OutEndpoint represents an IN endpoint open for transfer. +type InEndpoint struct { + *endpoint + timeout time.Duration +} + +// Read reads data from an IN endpoint. +func (e *InEndpoint) Read(buf []byte) (int, error) { + return e.transfer(buf, e.timeout) +} + +// OutEndpoint represents an OUT endpoint open for transfer. +type OutEndpoint struct { + *endpoint + timeout time.Duration +} + +// Write writes data to an OUT endpoint. +func (e *OutEndpoint) Write(buf []byte) (int, error) { + return e.transfer(buf, e.timeout) +} diff --git a/usb/endpoint_test.go b/usb/endpoint_test.go index eee2896..198923e 100644 --- a/usb/endpoint_test.go +++ b/usb/endpoint_test.go @@ -15,7 +15,6 @@ package usb import ( - "reflect" "testing" "time" ) @@ -51,90 +50,55 @@ var testIsoOutSetting = InterfaceSetting{ Endpoints: []EndpointInfo{testIsoOutEP}, } -func TestEndpoint(t *testing.T) { +func TestInEndpoint(t *testing.T) { defer func(i libusbIntf) { libusb = i }(libusb) - - for _, epCfg := range []struct { - method string - InterfaceSetting - EndpointInfo + for _, tc := range []struct { + desc string + buf []byte + ret int + status TransferStatus + want int + wantErr bool }{ - {"Read", testBulkInSetting, testBulkInEP}, - {"Write", testIsoOutSetting, testIsoOutEP}, + { + desc: "empty buffer", + buf: nil, + ret: 10, + want: 0, + }, + { + desc: "128B buffer, 60 transferred", + buf: make([]byte, 128), + ret: 60, + want: 60, + }, + { + desc: "128B buffer, 10 transferred and then error", + buf: make([]byte, 128), + ret: 10, + status: TransferError, + want: 10, + wantErr: true, + }, } { - t.Run(epCfg.method, func(t *testing.T) { - for _, tc := range []struct { - desc string - buf []byte - ret int - status TransferStatus - want int - wantErr bool - }{ - { - desc: "empty buffer", - buf: nil, - ret: 10, - want: 0, - }, - { - desc: "128B buffer, 60 transferred", - buf: make([]byte, 128), - ret: 60, - want: 60, - }, - { - desc: "128B buffer, 10 transferred and then error", - buf: make([]byte, 128), - ret: 10, - status: TransferError, - want: 10, - wantErr: true, - }, - } { - lib := newFakeLibusb() - libusb = lib - ep := newEndpoint(nil, epCfg.InterfaceSetting, epCfg.EndpointInfo, time.Second, time.Second) - op, ok := reflect.TypeOf(ep).MethodByName(epCfg.method) - if !ok { - t.Fatalf("method %s not found in endpoint struct", epCfg.method) - } - go func() { - fakeT := lib.waitForSubmitted() - fakeT.length = tc.ret - fakeT.status = tc.status - close(fakeT.done) - }() - opv := op.Func.Interface().(func(*Endpoint, []byte) (int, error)) - got, err := opv(ep, tc.buf) - if (err != nil) != tc.wantErr { - t.Errorf("%s: bulkInEP.Read(): got err: %v, err != nil is %v, want %v", tc.desc, err, err != nil, tc.wantErr) - continue - } - if got != tc.want { - t.Errorf("%s: bulkInEP.Read(): got %d bytes, want %d", tc.desc, got, tc.want) - } - } - }) - } -} + lib := newFakeLibusb() + libusb = lib -func TestEndpointWrongDirection(t *testing.T) { - ep := &Endpoint{ - InterfaceSetting: testBulkInSetting, - Info: testBulkInEP, - } - _, err := ep.Write([]byte{1, 2, 3}) - if err == nil { - t.Error("bulkInEP.Write(): got nil error, want non-nil") - } - ep = &Endpoint{ - InterfaceSetting: testIsoOutSetting, - Info: testIsoOutEP, - } - _, err = ep.Read(make([]byte, 64)) - if err == nil { - t.Error("isoOutEP.Read(): got nil error, want non-nil") + ep := InEndpoint{newEndpoint(nil, testBulkInSetting, testBulkInEP), time.Second} + go func() { + fakeT := lib.waitForSubmitted() + fakeT.length = tc.ret + fakeT.status = tc.status + close(fakeT.done) + }() + got, err := ep.Read(tc.buf) + if (err != nil) != tc.wantErr { + t.Errorf("%s: bulkInEP.Read(): got err: %v, err != nil is %v, want %v", tc.desc, err, err != nil, tc.wantErr) + continue + } + if got != tc.want { + t.Errorf("%s: bulkInEP.Read(): got %d bytes, want %d", tc.desc, got, tc.want) + } } }