Home | History | Annotate | Download | only in usb_gadget
      1 #!/usr/bin/python
      2 # Copyright 2014 The Chromium Authors. All rights reserved.
      3 # Use of this source code is governed by a BSD-style license that can be
      4 # found in the LICENSE file.
      5 
      6 import unittest
      7 
      8 import mock
      9 
     10 import hid_constants
     11 import hid_descriptors
     12 import hid_gadget
     13 import usb_constants
     14 
     15 
     16 report_desc = hid_descriptors.ReportDescriptor(
     17     hid_descriptors.UsagePage(0xFF00),  # Vendor Defined
     18     hid_descriptors.Usage(0x00),
     19     hid_descriptors.Collection(
     20         hid_constants.CollectionType.APPLICATION,
     21         hid_descriptors.LogicalMinimum(0, force_length=1),
     22         hid_descriptors.LogicalMaximum(255, force_length=2),
     23         hid_descriptors.ReportSize(8),
     24         hid_descriptors.ReportCount(8),
     25         hid_descriptors.Input(hid_descriptors.Data,
     26                               hid_descriptors.Variable,
     27                               hid_descriptors.Absolute,
     28                               hid_descriptors.BufferedBytes),
     29         hid_descriptors.Output(hid_descriptors.Data,
     30                                hid_descriptors.Variable,
     31                                hid_descriptors.Absolute,
     32                                hid_descriptors.BufferedBytes),
     33         hid_descriptors.Feature(hid_descriptors.Data,
     34                                 hid_descriptors.Variable,
     35                                 hid_descriptors.Absolute,
     36                                 hid_descriptors.BufferedBytes)
     37     )
     38 )
     39 
     40 combo_report_desc = hid_descriptors.ReportDescriptor(
     41     hid_descriptors.ReportID(1),
     42     report_desc,
     43     hid_descriptors.ReportID(2),
     44     report_desc
     45 )
     46 
     47 
     48 class HidGadgetTest(unittest.TestCase):
     49 
     50   def test_bad_intervals(self):
     51     with self.assertRaisesRegexp(ValueError, 'Full speed'):
     52       hid_gadget.HidGadget(report_desc, features={}, interval_ms=50000,
     53                            vendor_id=0, product_id=0)
     54     with self.assertRaisesRegexp(ValueError, 'High speed'):
     55       hid_gadget.HidGadget(report_desc, features={}, interval_ms=5000,
     56                            vendor_id=0, product_id=0)
     57 
     58   def test_get_string_descriptor(self):
     59     g = hid_gadget.HidGadget(report_desc=report_desc, features={},
     60                              vendor_id=0, product_id=0)
     61     g.AddStringDescriptor(2, 'HID Gadget')
     62     desc = g.ControlRead(0x80, 6, 0x0302, 0x0409, 255)
     63     self.assertEquals(desc, '\x16\x03H\0I\0D\0 \0G\0a\0d\0g\0e\0t\0')
     64 
     65   def test_get_report_descriptor(self):
     66     g = hid_gadget.HidGadget(report_desc=report_desc, features={},
     67                              vendor_id=0, product_id=0)
     68     desc = g.ControlRead(0x81, 6, 0x2200, 0, 63)
     69     self.assertEquals(desc, report_desc)
     70 
     71   def test_set_idle(self):
     72     g = hid_gadget.HidGadget(report_desc=report_desc, features={},
     73                              vendor_id=0, product_id=0)
     74     self.assertTrue(g.ControlWrite(0x21, 0x0A, 0, 0, ''))
     75 
     76   def test_class_wrong_target(self):
     77     g = hid_gadget.HidGadget(report_desc=report_desc, features={},
     78                              vendor_id=0, product_id=0)
     79     self.assertIsNone(g.ControlRead(0xA0, 0, 0, 0, 0))  # Device
     80     self.assertIsNone(g.ControlRead(0xA1, 0, 0, 1, 0))  # Interface 1
     81     self.assertIsNone(g.ControlWrite(0x20, 0, 0, 0, ''))  # Device
     82     self.assertIsNone(g.ControlWrite(0x21, 0, 0, 1, ''))  # Interface 1
     83 
     84   def test_send_report_zero(self):
     85     g = hid_gadget.HidGadget(report_desc=report_desc, features={},
     86                              vendor_id=0, product_id=0)
     87     chip = mock.Mock()
     88     g.Connected(chip, usb_constants.Speed.HIGH)
     89     g.SendReport(0, 'Hello world!')
     90     chip.SendPacket.assert_called_once_with(0x81, 'Hello world!')
     91 
     92   def test_send_multiple_reports(self):
     93     g = hid_gadget.HidGadget(report_desc=report_desc, features={},
     94                              vendor_id=0, product_id=0)
     95     chip = mock.Mock()
     96     g.Connected(chip, usb_constants.Speed.HIGH)
     97     g.SendReport(1, 'Hello!')
     98     g.SendReport(2, 'World!')
     99     chip.SendPacket.assert_has_calls([
    100         mock.call(0x81, '\x01Hello!'),
    101         mock.call(0x81, '\x02World!'),
    102     ])
    103 
    104 
    105 class TestFeature(hid_gadget.HidFeature):
    106 
    107   def SetInputReport(self, data):
    108     self.input_report = data
    109     return True
    110 
    111   def SetOutputReport(self, data):
    112     self.output_report = data
    113     return True
    114 
    115   def SetFeatureReport(self, data):
    116     self.feature_report = data
    117     return True
    118 
    119   def GetInputReport(self):
    120     return 'Input report.'
    121 
    122   def GetOutputReport(self):
    123     return 'Output report.'
    124 
    125   def GetFeatureReport(self):
    126     return 'Feature report.'
    127 
    128 
    129 class HidFeatureTest(unittest.TestCase):
    130 
    131   def test_disconnected(self):
    132     feature = TestFeature()
    133     with self.assertRaisesRegexp(RuntimeError, 'not connected'):
    134       feature.SendReport('Hello world!')
    135 
    136   def test_send_report(self):
    137     feature = TestFeature()
    138     g = hid_gadget.HidGadget(report_desc, features={1: feature},
    139                              vendor_id=0, product_id=0)
    140     chip = mock.Mock()
    141     g.Connected(chip, usb_constants.Speed.HIGH)
    142     feature.SendReport('Hello world!')
    143     chip.SendPacket.assert_called_once_with(0x81, '\x01Hello world!')
    144     g.Disconnected()
    145 
    146   def test_get_bad_report(self):
    147     feature = TestFeature()
    148     g = hid_gadget.HidGadget(report_desc, features={1: feature},
    149                              vendor_id=0, product_id=0)
    150     self.assertIsNone(g.ControlRead(0xA1, 1, 0x0102, 0, 8))
    151 
    152   def test_set_bad_report(self):
    153     feature = TestFeature()
    154     g = hid_gadget.HidGadget(report_desc, features={1: feature},
    155                              vendor_id=0, product_id=0)
    156     self.assertIsNone(g.ControlWrite(0x21, 0x09, 0x0102, 0, 'Hello!'))
    157 
    158   def test_get_input_report(self):
    159     feature = TestFeature()
    160     g = hid_gadget.HidGadget(report_desc, features={1: feature},
    161                              vendor_id=0, product_id=0)
    162     report = g.ControlRead(0xA1, 1, 0x0101, 0, 8)
    163     self.assertEquals(report, 'Input re')
    164 
    165   def test_set_input_report(self):
    166     feature = TestFeature()
    167     g = hid_gadget.HidGadget(report_desc, features={1: feature},
    168                              vendor_id=0, product_id=0)
    169     self.assertTrue(g.ControlWrite(0x21, 0x09, 0x0101, 0, 'Hello!'))
    170     self.assertEquals(feature.input_report, 'Hello!')
    171 
    172   def test_get_output_report(self):
    173     feature = TestFeature()
    174     g = hid_gadget.HidGadget(report_desc, features={1: feature},
    175                              vendor_id=0, product_id=0)
    176     report = g.ControlRead(0xA1, 1, 0x0201, 0, 8)
    177     self.assertEquals(report, 'Output r')
    178 
    179   def test_set_output_report(self):
    180     feature = TestFeature()
    181     g = hid_gadget.HidGadget(report_desc, features={1: feature},
    182                              vendor_id=0, product_id=0)
    183     self.assertTrue(g.ControlWrite(0x21, 0x09, 0x0201, 0, 'Hello!'))
    184     self.assertEquals(feature.output_report, 'Hello!')
    185 
    186   def test_receive_interrupt(self):
    187     feature = TestFeature()
    188     g = hid_gadget.HidGadget(report_desc, features={1: feature},
    189                              vendor_id=0, product_id=0)
    190     chip = mock.Mock()
    191     g.Connected(chip, usb_constants.Speed.HIGH)
    192     g.ReceivePacket(0x01, '\x01Hello!')
    193     self.assertFalse(chip.HaltEndpoint.called)
    194     self.assertEquals(feature.output_report, 'Hello!')
    195 
    196   def test_receive_interrupt_report_zero(self):
    197     feature = TestFeature()
    198     g = hid_gadget.HidGadget(report_desc, features={0: feature},
    199                              vendor_id=0, product_id=0)
    200     chip = mock.Mock()
    201     g.Connected(chip, usb_constants.Speed.HIGH)
    202     g.ReceivePacket(0x01, 'Hello!')
    203     self.assertFalse(chip.HaltEndpoint.called)
    204     self.assertEquals(feature.output_report, 'Hello!')
    205 
    206   def test_receive_bad_interrupt(self):
    207     feature = TestFeature()
    208     g = hid_gadget.HidGadget(report_desc, features={1: feature},
    209                              vendor_id=0, product_id=0)
    210     chip = mock.Mock()
    211     g.Connected(chip, usb_constants.Speed.HIGH)
    212     g.ReceivePacket(0x01, '\x00Hello!')
    213     chip.HaltEndpoint.assert_called_once_with(0x01)
    214 
    215   def test_get_feature_report(self):
    216     feature = TestFeature()
    217     g = hid_gadget.HidGadget(report_desc, features={1: feature},
    218                              vendor_id=0, product_id=0)
    219     report = g.ControlRead(0xA1, 1, 0x0301, 0, 8)
    220     self.assertEquals(report, 'Feature ')
    221 
    222   def test_set_feature_report(self):
    223     feature = TestFeature()
    224     g = hid_gadget.HidGadget(report_desc, features={1: feature},
    225                              vendor_id=0, product_id=0)
    226     self.assertTrue(g.ControlWrite(0x21, 0x09, 0x0301, 0, 'Hello!'))
    227     self.assertEquals(feature.feature_report, 'Hello!')
    228 
    229 
    230 if __name__ == '__main__':
    231   unittest.main()
    232