Home | History | Annotate | Download | only in tcp_client
      1 #
      2 # Copyright (C) 2016 The Android Open Source Project
      3 #
      4 # Licensed under the Apache License, Version 2.0 (the "License");
      5 # you may not use this file except in compliance with the License.
      6 # You may obtain a copy of the License at
      7 #
      8 #      http://www.apache.org/licenses/LICENSE-2.0
      9 #
     10 # Unless required by applicable law or agreed to in writing, software
     11 # distributed under the License is distributed on an "AS IS" BASIS,
     12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 # See the License for the specific language governing permissions and
     14 # limitations under the License.
     15 #
     16 
     17 import json
     18 import logging
     19 import os
     20 import socket
     21 import time
     22 import types
     23 
     24 from vts.proto import AndroidSystemControlMessage_pb2 as SysMsg_pb2
     25 from vts.proto import ComponentSpecificationMessage_pb2 as CompSpecMsg_pb2
     26 from vts.runners.host import const
     27 from vts.runners.host import errors
     28 from vts.utils.python.mirror import mirror_object
     29 
     30 from google.protobuf import text_format
     31 
     32 TARGET_IP = os.environ.get("TARGET_IP", None)
     33 TARGET_PORT = os.environ.get("TARGET_PORT", None)
     34 _DEFAULT_SOCKET_TIMEOUT_SECS = 1800
     35 _SOCKET_CONN_TIMEOUT_SECS = 60
     36 _SOCKET_CONN_RETRY_NUMBER = 5
     37 COMMAND_TYPE_NAME = {
     38     1: "LIST_HALS",
     39     2: "SET_HOST_INFO",
     40     101: "CHECK_DRIVER_SERVICE",
     41     102: "LAUNCH_DRIVER_SERVICE",
     42     103: "VTS_AGENT_COMMAND_READ_SPECIFICATION",
     43     201: "LIST_APIS",
     44     202: "CALL_API",
     45     203: "VTS_AGENT_COMMAND_GET_ATTRIBUTE",
     46     301: "VTS_AGENT_COMMAND_EXECUTE_SHELL_COMMAND"
     47 }
     48 
     49 
     50 class VtsTcpClient(object):
     51     """VTS TCP Client class.
     52 
     53     Attribute:
     54         connection: a TCP socket instance.
     55         channel: a file to write and read data.
     56         _mode: the connection mode (adb_forwarding or ssh_tunnel)
     57     """
     58 
     59     def __init__(self, mode="adb_forwarding"):
     60         self.connection = None
     61         self.channel = None
     62         self._mode = mode
     63 
     64     def Connect(self,
     65                 ip=TARGET_IP,
     66                 command_port=TARGET_PORT,
     67                 callback_port=None,
     68                 retry=_SOCKET_CONN_RETRY_NUMBER):
     69         """Connects to a target device.
     70 
     71         Args:
     72             ip: string, the IP address of a target device.
     73             command_port: int, the TCP port which can be used to connect to
     74                           a target device.
     75             callback_port: int, the TCP port number of a host-side callback
     76                            server.
     77             retry: int, the number of times to retry connecting before giving
     78                    up.
     79 
     80         Returns:
     81             True if success, False otherwise
     82 
     83         Raises:
     84             socket.error when the connection fails.
     85         """
     86         if not command_port:
     87             logging.error("ip %s, command_port %s, callback_port %s invalid",
     88                           ip, command_port, callback_port)
     89             return False
     90 
     91         for i in xrange(retry):
     92             try:
     93                 self.connection = socket.create_connection(
     94                     (ip, command_port), _SOCKET_CONN_TIMEOUT_SECS)
     95                 self.connection.settimeout(_DEFAULT_SOCKET_TIMEOUT_SECS)
     96             except socket.error as e:
     97                 # Wait a bit and retry.
     98                 logging.exception("Connect failed %s", e)
     99                 time.sleep(1)
    100                 if i + 1 == retry:
    101                     raise errors.VtsTcpClientCreationError(
    102                         "Couldn't connect to %s:%s" % (ip, command_port))
    103         self.channel = self.connection.makefile(mode="brw")
    104 
    105         if callback_port is not None:
    106             self.SendCommand(
    107                 SysMsg_pb2.SET_HOST_INFO, callback_port=callback_port)
    108             resp = self.RecvResponse()
    109             if (resp.response_code != SysMsg_pb2.SUCCESS):
    110                 return False
    111         return True
    112 
    113     def Disconnect(self):
    114         """Disconnects from the target device.
    115 
    116         TODO(yim): Send a msg to the target side to teardown handler session
    117         and release memory before closing the socket.
    118         """
    119         if self.connection is not None:
    120             self.channel = None
    121             self.connection.close()
    122             self.connection = None
    123 
    124     def ListHals(self, base_paths):
    125         """RPC to LIST_HALS."""
    126         self.SendCommand(SysMsg_pb2.LIST_HALS, paths=base_paths)
    127         resp = self.RecvResponse()
    128         if (resp.response_code == SysMsg_pb2.SUCCESS):
    129             return resp.file_names
    130         return None
    131 
    132     def CheckDriverService(self, service_name):
    133         """RPC to CHECK_DRIVER_SERVICE."""
    134         self.SendCommand(
    135             SysMsg_pb2.CHECK_DRIVER_SERVICE, service_name=service_name)
    136         resp = self.RecvResponse()
    137         return (resp.response_code == SysMsg_pb2.SUCCESS)
    138 
    139     def LaunchDriverService(self,
    140                             driver_type,
    141                             service_name,
    142                             bits,
    143                             file_path=None,
    144                             target_class=None,
    145                             target_type=None,
    146                             target_version=None,
    147                             target_package=None,
    148                             target_component_name=None,
    149                             hw_binder_service_name=None):
    150         """RPC to LAUNCH_DRIVER_SERVICE."""
    151         logging.info("service_name: %s", service_name)
    152         logging.info("file_path: %s", file_path)
    153         logging.info("bits: %s", bits)
    154         logging.info("driver_type: %s", driver_type)
    155         self.SendCommand(
    156             SysMsg_pb2.LAUNCH_DRIVER_SERVICE,
    157             driver_type=driver_type,
    158             service_name=service_name,
    159             bits=bits,
    160             file_path=file_path,
    161             target_class=target_class,
    162             target_type=target_type,
    163             target_version=target_version,
    164             target_package=target_package,
    165             target_component_name=target_component_name,
    166             hw_binder_service_name=hw_binder_service_name)
    167         resp = self.RecvResponse()
    168         logging.info("resp for LAUNCH_DRIVER_SERVICE: %s", resp)
    169         if driver_type == SysMsg_pb2.VTS_DRIVER_TYPE_HAL_HIDL \
    170                 or driver_type == SysMsg_pb2.VTS_DRIVER_TYPE_HAL_CONVENTIONAL \
    171                 or driver_type == SysMsg_pb2.VTS_DRIVER_TYPE_HAL_LEGACY:
    172             if resp.response_code == SysMsg_pb2.SUCCESS:
    173                 return int(resp.result)
    174             else:
    175                 return -1
    176         else:
    177             return (resp.response_code == SysMsg_pb2.SUCCESS)
    178 
    179     def ListApis(self):
    180         """RPC to LIST_APIS."""
    181         self.SendCommand(SysMsg_pb2.LIST_APIS)
    182         resp = self.RecvResponse()
    183         logging.info("resp for LIST_APIS: %s", resp)
    184         if (resp.response_code == SysMsg_pb2.SUCCESS):
    185             return resp.spec
    186         return None
    187 
    188     def GetPythonDataOfVariableSpecMsg(self, var_spec_msg):
    189         """Returns the python native data structure for a given message.
    190 
    191         Args:
    192             var_spec_msg: VariableSpecificationMessage
    193 
    194         Returns:
    195             python native data structure (e.g., string, integer, list).
    196 
    197         Raises:
    198             VtsUnsupportedTypeError if unsupported type is specified.
    199             VtsMalformedProtoStringError if StringDataValueMessage is
    200                 not populated.
    201         """
    202         if var_spec_msg.type == CompSpecMsg_pb2.TYPE_SCALAR:
    203             scalar_type = getattr(var_spec_msg, "scalar_type", "")
    204             if scalar_type:
    205                 return getattr(var_spec_msg.scalar_value, scalar_type)
    206         elif var_spec_msg.type == CompSpecMsg_pb2.TYPE_ENUM:
    207             scalar_type = getattr(var_spec_msg, "scalar_type", "")
    208             if scalar_type:
    209                 return getattr(var_spec_msg.scalar_value, scalar_type)
    210             else:
    211                 return var_spec_msg.scalar_value.int32_t
    212         elif var_spec_msg.type == CompSpecMsg_pb2.TYPE_STRING:
    213             if hasattr(var_spec_msg, "string_value"):
    214                 return getattr(var_spec_msg.string_value, "message", "")
    215             raise errors.VtsMalformedProtoStringError()
    216         elif var_spec_msg.type == CompSpecMsg_pb2.TYPE_STRUCT:
    217             result = {}
    218             index = 1
    219             for struct_value in var_spec_msg.struct_value:
    220                 if len(struct_value.name) > 0:
    221                     result[struct_value.
    222                            name] = self.GetPythonDataOfVariableSpecMsg(
    223                                struct_value)
    224                 else:
    225                     result["attribute%d" %
    226                            index] = self.GetPythonDataOfVariableSpecMsg(
    227                                struct_value)
    228                 index += 1
    229             return result
    230         elif var_spec_msg.type == CompSpecMsg_pb2.TYPE_UNION:
    231             result = VtsReturnValueObject()
    232             index = 1
    233             for union_value in var_spec_msg.union_value:
    234                 if len(union_value.name) > 0:
    235                     result[union_value.
    236                            name] = self.GetPythonDataOfVariableSpecMsg(
    237                                union_value)
    238                 else:
    239                     result["attribute%d" %
    240                            index] = self.GetPythonDataOfVariableSpecMsg(
    241                                union_value)
    242                 index += 1
    243             return result
    244         elif (var_spec_msg.type == CompSpecMsg_pb2.TYPE_VECTOR or
    245               var_spec_msg.type == CompSpecMsg_pb2.TYPE_ARRAY):
    246             result = []
    247             for vector_value in var_spec_msg.vector_value:
    248                 result.append(
    249                     self.GetPythonDataOfVariableSpecMsg(vector_value))
    250             return result
    251         elif (var_spec_msg.type == CompSpecMsg_pb2.TYPE_HIDL_INTERFACE):
    252             logging.debug("var_spec_msg: %s", var_spec_msg)
    253             return var_spec_msg
    254 
    255         raise errors.VtsUnsupportedTypeError("unsupported type %s" %
    256                                              var_spec_msg.type)
    257 
    258     def CallApi(self, arg, caller_uid=None):
    259         """RPC to CALL_API."""
    260         self.SendCommand(SysMsg_pb2.CALL_API, arg=arg, caller_uid=caller_uid)
    261         resp = self.RecvResponse()
    262         resp_code = resp.response_code
    263         if (resp_code == SysMsg_pb2.SUCCESS):
    264             result = CompSpecMsg_pb2.FunctionSpecificationMessage()
    265             if resp.result == "error":
    266                 raise errors.VtsTcpCommunicationError(
    267                     "API call error by the VTS driver.")
    268             try:
    269                 text_format.Merge(resp.result, result)
    270             except text_format.ParseError as e:
    271                 logging.exception(e)
    272                 logging.error("Paring error\n%s", resp.result)
    273             if result.return_type.type == CompSpecMsg_pb2.TYPE_SUBMODULE:
    274                 logging.info("returned a submodule spec")
    275                 logging.info("spec: %s", result.return_type_submodule_spec)
    276                 return mirror_object.MirrorObject(
    277                     self, result.return_type_submodule_spec, None)
    278 
    279             logging.info("result: %s", result.return_type_hidl)
    280             if len(result.return_type_hidl) == 1:
    281                 result_value = self.GetPythonDataOfVariableSpecMsg(
    282                     result.return_type_hidl[0])
    283             elif len(result.return_type_hidl) > 1:
    284                 result_value = []
    285                 for return_type_hidl in result.return_type_hidl:
    286                     result_value.append(
    287                         self.GetPythonDataOfVariableSpecMsg(return_type_hidl))
    288             else:  # For non-HIDL return value
    289                 if hasattr(result, "return_type"):
    290                     result_value = result
    291                 else:
    292                     result_value = None
    293 
    294             if hasattr(result, "raw_coverage_data"):
    295                 return result_value, {"coverage": result.raw_coverage_data}
    296             else:
    297                 return result_value
    298 
    299         logging.error("NOTICE - Likely a crash discovery!")
    300         logging.error("SysMsg_pb2.SUCCESS is %s", SysMsg_pb2.SUCCESS)
    301         raise errors.VtsTcpCommunicationError(
    302             "RPC Error, response code for %s is %s" % (arg, resp_code))
    303 
    304     def GetAttribute(self, arg):
    305         """RPC to VTS_AGENT_COMMAND_GET_ATTRIBUTE."""
    306         self.SendCommand(SysMsg_pb2.VTS_AGENT_COMMAND_GET_ATTRIBUTE, arg=arg)
    307         resp = self.RecvResponse()
    308         resp_code = resp.response_code
    309         if (resp_code == SysMsg_pb2.SUCCESS):
    310             result = CompSpecMsg_pb2.FunctionSpecificationMessage()
    311             if resp.result == "error":
    312                 raise errors.VtsTcpCommunicationError(
    313                     "Get attribute request failed on target.")
    314             try:
    315                 text_format.Merge(resp.result, result)
    316             except text_format.ParseError as e:
    317                 logging.exception(e)
    318                 logging.error("Paring error\n%s", resp.result)
    319             if result.return_type.type == CompSpecMsg_pb2.TYPE_SUBMODULE:
    320                 logging.info("returned a submodule spec")
    321                 logging.info("spec: %s", result.return_type_submodule_spec)
    322                 return mirror_object.MirrorObject(
    323                     self, result.return_type_submodule_spec, None)
    324             elif result.return_type.type == CompSpecMsg_pb2.TYPE_SCALAR:
    325                 return getattr(result.return_type.scalar_value,
    326                                result.return_type.scalar_type)
    327             return result
    328         logging.error("NOTICE - Likely a crash discovery!")
    329         logging.error("SysMsg_pb2.SUCCESS is %s", SysMsg_pb2.SUCCESS)
    330         raise errors.VtsTcpCommunicationError(
    331             "RPC Error, response code for %s is %s" % (arg, resp_code))
    332 
    333     def ExecuteShellCommand(self, command, no_except=False):
    334         """RPC to VTS_AGENT_COMMAND_EXECUTE_SHELL_COMMAND.
    335 
    336         Args:
    337             command: string or list of string, command to execute on device
    338             no_except: bool, whether to throw exceptions. If set to True,
    339                        when exception happens, return code will be -1 and
    340                        str(err) will be in stderr. Result will maintain the
    341                        same length as with input command.
    342 
    343         Returns:
    344             dictionary of list, command results that contains stdout,
    345             stderr, and exit_code.
    346         """
    347         if not no_except:
    348             return self.__ExecuteShellCommand(command)
    349 
    350         try:
    351             return self.__ExecuteShellCommand(command)
    352         except Exception as e:
    353             logging.exception(e)
    354             return {
    355                 const.STDOUT: [""] * len(command),
    356                 const.STDERR: [str(e)] * len(command),
    357                 const.EXIT_CODE: [-1] * len(command)
    358             }
    359 
    360     def __ExecuteShellCommand(self, command):
    361         """RPC to VTS_AGENT_COMMAND_EXECUTE_SHELL_COMMAND.
    362 
    363         Args:
    364             command: string or list of string, command to execute on device
    365 
    366         Returns:
    367             dictionary of list, command results that contains stdout,
    368             stderr, and exit_code.
    369         """
    370         self.SendCommand(
    371             SysMsg_pb2.VTS_AGENT_COMMAND_EXECUTE_SHELL_COMMAND,
    372             shell_command=command)
    373         resp = self.RecvResponse(retries=2)
    374         logging.info("resp for VTS_AGENT_COMMAND_EXECUTE_SHELL_COMMAND: %s",
    375                      resp)
    376 
    377         stdout = None
    378         stderr = None
    379         exit_code = None
    380 
    381         if not resp:
    382             logging.error("resp is: %s.", resp)
    383         elif resp.response_code != SysMsg_pb2.SUCCESS:
    384             logging.error("resp response code is not success: %s.",
    385                           resp.response_code)
    386         else:
    387             stdout = resp.stdout
    388             stderr = resp.stderr
    389             exit_code = resp.exit_code
    390 
    391         return {
    392             const.STDOUT: stdout,
    393             const.STDERR: stderr,
    394             const.EXIT_CODE: exit_code
    395         }
    396 
    397     def Ping(self):
    398         """RPC to send a PING request.
    399 
    400         Returns:
    401             True if the agent is alive, False otherwise.
    402         """
    403         self.SendCommand(SysMsg_pb2.PING)
    404         resp = self.RecvResponse()
    405         logging.info("resp for PING: %s", resp)
    406         if resp is not None and resp.response_code == SysMsg_pb2.SUCCESS:
    407             return True
    408         return False
    409 
    410     def ReadSpecification(self,
    411                           interface_name,
    412                           target_class,
    413                           target_type,
    414                           target_version,
    415                           target_package,
    416                           recursive=False):
    417         """RPC to VTS_AGENT_COMMAND_READ_SPECIFICATION.
    418 
    419         Args:
    420             other args: see SendCommand
    421             recursive: boolean, set to recursively read the imported
    422                        specification(s) and return the merged one.
    423         """
    424         self.SendCommand(
    425             SysMsg_pb2.VTS_AGENT_COMMAND_READ_SPECIFICATION,
    426             service_name=interface_name,
    427             target_class=target_class,
    428             target_type=target_type,
    429             target_version=target_version,
    430             target_package=target_package)
    431         resp = self.RecvResponse(retries=2)
    432         logging.info("resp for VTS_AGENT_COMMAND_EXECUTE_READ_INTERFACE: %s",
    433                      resp)
    434         logging.info("proto: %s", resp.result)
    435         result = CompSpecMsg_pb2.ComponentSpecificationMessage()
    436         if resp.result == "error":
    437             raise errors.VtsTcpCommunicationError(
    438                 "API call error by the VTS driver.")
    439         try:
    440             text_format.Merge(resp.result, result)
    441         except text_format.ParseError as e:
    442             logging.exception(e)
    443             logging.error("Paring error\n%s", resp.result)
    444 
    445         if recursive and hasattr(result, "import"):
    446             for imported_interface in getattr(result, "import"):
    447                 if imported_interface == "android.hidl.base (at] 1.0::types":
    448                     logging.warn("import android.hidl.base (at] 1.0::types skipped")
    449                     continue
    450                 imported_result = self.ReadSpecification(
    451                     imported_interface.split("::")[1],
    452                     # TODO(yim): derive target_class and
    453                     # target_type from package path or remove them
    454                     msg.component_class
    455                     if target_class is None else target_class,
    456                     msg.component_type if target_type is None else target_type,
    457                     float(imported_interface.split("@")[1].split("::")[0]),
    458                     imported_interface.split("@")[0])
    459                 result.MergeFrom(imported_result)
    460 
    461         return result
    462 
    463     def SendCommand(self,
    464                     command_type,
    465                     paths=None,
    466                     file_path=None,
    467                     bits=None,
    468                     target_class=None,
    469                     target_type=None,
    470                     target_version=None,
    471                     target_package=None,
    472                     target_component_name=None,
    473                     hw_binder_service_name=None,
    474                     module_name=None,
    475                     service_name=None,
    476                     callback_port=None,
    477                     driver_type=None,
    478                     shell_command=None,
    479                     caller_uid=None,
    480                     arg=None):
    481         """Sends a command.
    482 
    483         Args:
    484             command_type: integer, the command type.
    485             each of the other args are to fill in a field in
    486             AndroidSystemControlCommandMessage.
    487         """
    488         if not self.channel:
    489             raise errors.VtsTcpCommunicationError(
    490                 "channel is None, unable to send command.")
    491 
    492         command_msg = SysMsg_pb2.AndroidSystemControlCommandMessage()
    493         command_msg.command_type = command_type
    494         logging.info("sending a command (type %s)",
    495                      COMMAND_TYPE_NAME[command_type])
    496         if command_type == 202:
    497             logging.info("target API: %s", arg)
    498 
    499         if target_class is not None:
    500             command_msg.target_class = target_class
    501 
    502         if target_type is not None:
    503             command_msg.target_type = target_type
    504 
    505         if target_version is not None:
    506             command_msg.target_version = int(target_version * 100)
    507 
    508         if target_package is not None:
    509             command_msg.target_package = target_package
    510 
    511         if target_component_name is not None:
    512             command_msg.target_component_name = target_component_name
    513 
    514         if hw_binder_service_name is not None:
    515             command_msg.hw_binder_service_name = hw_binder_service_name
    516 
    517         if module_name is not None:
    518             command_msg.module_name = module_name
    519 
    520         if service_name is not None:
    521             command_msg.service_name = service_name
    522 
    523         if driver_type is not None:
    524             command_msg.driver_type = driver_type
    525 
    526         if paths is not None:
    527             command_msg.paths.extend(paths)
    528 
    529         if file_path is not None:
    530             command_msg.file_path = file_path
    531 
    532         if bits is not None:
    533             command_msg.bits = bits
    534 
    535         if callback_port is not None:
    536             command_msg.callback_port = callback_port
    537 
    538         if caller_uid is not None:
    539             command_msg.driver_caller_uid = caller_uid
    540 
    541         if arg is not None:
    542             command_msg.arg = arg
    543 
    544         if shell_command is not None:
    545             if isinstance(shell_command, types.ListType):
    546                 command_msg.shell_command.extend(shell_command)
    547             else:
    548                 command_msg.shell_command.append(shell_command)
    549 
    550         logging.info("command %s" % command_msg)
    551         message = command_msg.SerializeToString()
    552         message_len = len(message)
    553         logging.debug("sending %d bytes", message_len)
    554         self.channel.write(str(message_len) + b'\n')
    555         self.channel.write(message)
    556         self.channel.flush()
    557 
    558     def RecvResponse(self, retries=0):
    559         """Receives and parses the response, and returns the relevant ResponseMessage.
    560 
    561         Args:
    562             retries: an integer indicating the max number of retries in case of
    563                      session timeout error.
    564         """
    565         for index in xrange(1 + retries):
    566             try:
    567                 if index != 0:
    568                     logging.info("retrying...")
    569                 header = self.channel.readline().strip("\n")
    570                 length = int(header) if header else 0
    571                 logging.info("resp %d bytes", length)
    572                 data = self.channel.read(length)
    573                 response_msg = SysMsg_pb2.AndroidSystemControlResponseMessage()
    574                 response_msg.ParseFromString(data)
    575                 logging.debug("Response %s", "success" if
    576                               response_msg.response_code == SysMsg_pb2.SUCCESS
    577                               else "fail")
    578                 return response_msg
    579             except socket.timeout as e:
    580                 logging.exception(e)
    581         return None
    582