Home | History | Annotate | Download | only in scripts
      1 #!/usr/bin/python2
      2 #
      3 # Copyright (C) 2017 The Android Open Source Project
      4 #
      5 # Licensed under the Apache License, Version 2.0 (the "License");
      6 # you may not use this file except in compliance with the License.
      7 # You may obtain a copy of the License at
      8 #
      9 #      http://www.apache.org/licenses/LICENSE-2.0
     10 #
     11 # Unless required by applicable law or agreed to in writing, software
     12 # distributed under the License is distributed on an "AS IS" BASIS,
     13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14 # See the License for the specific language governing permissions and
     15 # limitations under the License.
     16 #
     17 
     18 """Send an A/B update to an Android device over adb."""
     19 
     20 import argparse
     21 import BaseHTTPServer
     22 import hashlib
     23 import logging
     24 import os
     25 import socket
     26 import subprocess
     27 import sys
     28 import threading
     29 import xml.etree.ElementTree
     30 import zipfile
     31 
     32 import update_payload.payload
     33 
     34 
     35 # The path used to store the OTA package when applying the package from a file.
     36 OTA_PACKAGE_PATH = '/data/ota_package'
     37 
     38 # The path to the payload public key on the device.
     39 PAYLOAD_KEY_PATH = '/etc/update_engine/update-payload-key.pub.pem'
     40 
     41 # The port on the device that update_engine should connect to.
     42 DEVICE_PORT = 1234
     43 
     44 def CopyFileObjLength(fsrc, fdst, buffer_size=128 * 1024, copy_length=None):
     45   """Copy from a file object to another.
     46 
     47   This function is similar to shutil.copyfileobj except that it allows to copy
     48   less than the full source file.
     49 
     50   Args:
     51     fsrc: source file object where to read from.
     52     fdst: destination file object where to write to.
     53     buffer_size: size of the copy buffer in memory.
     54     copy_length: maximum number of bytes to copy, or None to copy everything.
     55 
     56   Returns:
     57     the number of bytes copied.
     58   """
     59   copied = 0
     60   while True:
     61     chunk_size = buffer_size
     62     if copy_length is not None:
     63       chunk_size = min(chunk_size, copy_length - copied)
     64       if not chunk_size:
     65         break
     66     buf = fsrc.read(chunk_size)
     67     if not buf:
     68       break
     69     fdst.write(buf)
     70     copied += len(buf)
     71   return copied
     72 
     73 
     74 class AndroidOTAPackage(object):
     75   """Android update payload using the .zip format.
     76 
     77   Android OTA packages traditionally used a .zip file to store the payload. When
     78   applying A/B updates over the network, a payload binary is stored RAW inside
     79   this .zip file which is used by update_engine to apply the payload. To do
     80   this, an offset and size inside the .zip file are provided.
     81   """
     82 
     83   # Android OTA package file paths.
     84   OTA_PAYLOAD_BIN = 'payload.bin'
     85   OTA_PAYLOAD_PROPERTIES_TXT = 'payload_properties.txt'
     86 
     87   def __init__(self, otafilename):
     88     self.otafilename = otafilename
     89 
     90     otazip = zipfile.ZipFile(otafilename, 'r')
     91     payload_info = otazip.getinfo(self.OTA_PAYLOAD_BIN)
     92     self.offset = payload_info.header_offset + len(payload_info.FileHeader())
     93     self.size = payload_info.file_size
     94     self.properties = otazip.read(self.OTA_PAYLOAD_PROPERTIES_TXT)
     95 
     96 
     97 class UpdateHandler(BaseHTTPServer.BaseHTTPRequestHandler):
     98   """A HTTPServer that supports single-range requests.
     99 
    100   Attributes:
    101     serving_payload: path to the only payload file we are serving.
    102     serving_range: the start offset and size tuple of the payload.
    103   """
    104 
    105   @staticmethod
    106   def _parse_range(range_str, file_size):
    107     """Parse an HTTP range string.
    108 
    109     Args:
    110       range_str: HTTP Range header in the request, not including "Header:".
    111       file_size: total size of the serving file.
    112 
    113     Returns:
    114       A tuple (start_range, end_range) with the range of bytes requested.
    115     """
    116     start_range = 0
    117     end_range = file_size
    118 
    119     if range_str:
    120       range_str = range_str.split('=', 1)[1]
    121       s, e = range_str.split('-', 1)
    122       if s:
    123         start_range = int(s)
    124         if e:
    125           end_range = int(e) + 1
    126       elif e:
    127         if int(e) < file_size:
    128           start_range = file_size - int(e)
    129     return start_range, end_range
    130 
    131 
    132   def do_GET(self):  # pylint: disable=invalid-name
    133     """Reply with the requested payload file."""
    134     if self.path != '/payload':
    135       self.send_error(404, 'Unknown request')
    136       return
    137 
    138     if not self.serving_payload:
    139       self.send_error(500, 'No serving payload set')
    140       return
    141 
    142     try:
    143       f = open(self.serving_payload, 'rb')
    144     except IOError:
    145       self.send_error(404, 'File not found')
    146       return
    147     # Handle the range request.
    148     if 'Range' in self.headers:
    149       self.send_response(206)
    150     else:
    151       self.send_response(200)
    152 
    153     serving_start, serving_size = self.serving_range
    154     start_range, end_range = self._parse_range(self.headers.get('range'),
    155                                                serving_size)
    156     logging.info('Serving request for %s from %s [%d, %d) length: %d',
    157                  self.path, self.serving_payload, serving_start + start_range,
    158                  serving_start + end_range, end_range - start_range)
    159 
    160     self.send_header('Accept-Ranges', 'bytes')
    161     self.send_header('Content-Range',
    162                      'bytes ' + str(start_range) + '-' + str(end_range - 1) +
    163                      '/' + str(end_range - start_range))
    164     self.send_header('Content-Length', end_range - start_range)
    165 
    166     stat = os.fstat(f.fileno())
    167     self.send_header('Last-Modified', self.date_time_string(stat.st_mtime))
    168     self.send_header('Content-type', 'application/octet-stream')
    169     self.end_headers()
    170 
    171     f.seek(serving_start + start_range)
    172     CopyFileObjLength(f, self.wfile, copy_length=end_range - start_range)
    173 
    174 
    175   def do_POST(self):  # pylint: disable=invalid-name
    176     """Reply with the omaha response xml."""
    177     if self.path != '/update':
    178       self.send_error(404, 'Unknown request')
    179       return
    180 
    181     if not self.serving_payload:
    182       self.send_error(500, 'No serving payload set')
    183       return
    184 
    185     try:
    186       f = open(self.serving_payload, 'rb')
    187     except IOError:
    188       self.send_error(404, 'File not found')
    189       return
    190 
    191     content_length = int(self.headers.getheader('Content-Length'))
    192     request_xml = self.rfile.read(content_length)
    193     xml_root = xml.etree.ElementTree.fromstring(request_xml)
    194     appid = None
    195     for app in xml_root.iter('app'):
    196       if 'appid' in app.attrib:
    197         appid = app.attrib['appid']
    198         break
    199     if not appid:
    200       self.send_error(400, 'No appid in Omaha request')
    201       return
    202 
    203     self.send_response(200)
    204     self.send_header("Content-type", "text/xml")
    205     self.end_headers()
    206 
    207     serving_start, serving_size = self.serving_range
    208     sha256 = hashlib.sha256()
    209     f.seek(serving_start)
    210     bytes_to_hash = serving_size
    211     while bytes_to_hash:
    212       buf = f.read(min(bytes_to_hash, 1024 * 1024))
    213       if not buf:
    214         self.send_error(500, 'Payload too small')
    215         return
    216       sha256.update(buf)
    217       bytes_to_hash -= len(buf)
    218 
    219     payload = update_payload.Payload(f, payload_file_offset=serving_start)
    220     payload.Init()
    221 
    222     response_xml = '''
    223         <?xml version="1.0" encoding="UTF-8"?>
    224         <response protocol="3.0">
    225           <app appid="{appid}">
    226             <updatecheck status="ok">
    227               <urls>
    228                 <url codebase="http://127.0.0.1:{port}/"/>
    229               </urls>
    230               <manifest version="0.0.0.1">
    231                 <actions>
    232                   <action event="install" run="payload"/>
    233                   <action event="postinstall" MetadataSize="{metadata_size}"/>
    234                 </actions>
    235                 <packages>
    236                   <package hash_sha256="{payload_hash}" name="payload" size="{payload_size}"/>
    237                 </packages>
    238               </manifest>
    239             </updatecheck>
    240           </app>
    241         </response>
    242     '''.format(appid=appid, port=DEVICE_PORT,
    243                metadata_size=payload.metadata_size,
    244                payload_hash=sha256.hexdigest(),
    245                payload_size=serving_size)
    246     self.wfile.write(response_xml.strip())
    247     return
    248 
    249 
    250 class ServerThread(threading.Thread):
    251   """A thread for serving HTTP requests."""
    252 
    253   def __init__(self, ota_filename, serving_range):
    254     threading.Thread.__init__(self)
    255     # serving_payload and serving_range are class attributes and the
    256     # UpdateHandler class is instantiated with every request.
    257     UpdateHandler.serving_payload = ota_filename
    258     UpdateHandler.serving_range = serving_range
    259     self._httpd = BaseHTTPServer.HTTPServer(('127.0.0.1', 0), UpdateHandler)
    260     self.port = self._httpd.server_port
    261 
    262   def run(self):
    263     try:
    264       self._httpd.serve_forever()
    265     except (KeyboardInterrupt, socket.error):
    266       pass
    267     logging.info('Server Terminated')
    268 
    269   def StopServer(self):
    270     self._httpd.socket.close()
    271 
    272 
    273 def StartServer(ota_filename, serving_range):
    274   t = ServerThread(ota_filename, serving_range)
    275   t.start()
    276   return t
    277 
    278 
    279 def AndroidUpdateCommand(ota_filename, payload_url, extra_headers):
    280   """Return the command to run to start the update in the Android device."""
    281   ota = AndroidOTAPackage(ota_filename)
    282   headers = ota.properties
    283   headers += 'USER_AGENT=Dalvik (something, something)\n'
    284   headers += 'NETWORK_ID=0\n'
    285   headers += extra_headers
    286 
    287   return ['update_engine_client', '--update', '--follow',
    288           '--payload=%s' % payload_url, '--offset=%d' % ota.offset,
    289           '--size=%d' % ota.size, '--headers="%s"' % headers]
    290 
    291 
    292 def OmahaUpdateCommand(omaha_url):
    293   """Return the command to run to start the update in a device using Omaha."""
    294   return ['update_engine_client', '--update', '--follow',
    295           '--omaha_url=%s' % omaha_url]
    296 
    297 
    298 class AdbHost(object):
    299   """Represents a device connected via ADB."""
    300 
    301   def __init__(self, device_serial=None):
    302     """Construct an instance.
    303 
    304     Args:
    305         device_serial: options string serial number of attached device.
    306     """
    307     self._device_serial = device_serial
    308     self._command_prefix = ['adb']
    309     if self._device_serial:
    310       self._command_prefix += ['-s', self._device_serial]
    311 
    312   def adb(self, command):
    313     """Run an ADB command like "adb push".
    314 
    315     Args:
    316       command: list of strings containing command and arguments to run
    317 
    318     Returns:
    319       the program's return code.
    320 
    321     Raises:
    322       subprocess.CalledProcessError on command exit != 0.
    323     """
    324     command = self._command_prefix + command
    325     logging.info('Running: %s', ' '.join(str(x) for x in command))
    326     p = subprocess.Popen(command, universal_newlines=True)
    327     p.wait()
    328     return p.returncode
    329 
    330   def adb_output(self, command):
    331     """Run an ADB command like "adb push" and return the output.
    332 
    333     Args:
    334       command: list of strings containing command and arguments to run
    335 
    336     Returns:
    337       the program's output as a string.
    338 
    339     Raises:
    340       subprocess.CalledProcessError on command exit != 0.
    341     """
    342     command = self._command_prefix + command
    343     logging.info('Running: %s', ' '.join(str(x) for x in command))
    344     return subprocess.check_output(command, universal_newlines=True)
    345 
    346 
    347 def main():
    348   parser = argparse.ArgumentParser(description='Android A/B OTA helper.')
    349   parser.add_argument('otafile', metavar='PAYLOAD', type=str,
    350                       help='the OTA package file (a .zip file) or raw payload \
    351                       if device uses Omaha.')
    352   parser.add_argument('--file', action='store_true',
    353                       help='Push the file to the device before updating.')
    354   parser.add_argument('--no-push', action='store_true',
    355                       help='Skip the "push" command when using --file')
    356   parser.add_argument('-s', type=str, default='', metavar='DEVICE',
    357                       help='The specific device to use.')
    358   parser.add_argument('--no-verbose', action='store_true',
    359                       help='Less verbose output')
    360   parser.add_argument('--public-key', type=str, default='',
    361                       help='Override the public key used to verify payload.')
    362   parser.add_argument('--extra-headers', type=str, default='',
    363                       help='Extra headers to pass to the device.')
    364   args = parser.parse_args()
    365   logging.basicConfig(
    366       level=logging.WARNING if args.no_verbose else logging.INFO)
    367 
    368   dut = AdbHost(args.s)
    369 
    370   server_thread = None
    371   # List of commands to execute on exit.
    372   finalize_cmds = []
    373   # Commands to execute when canceling an update.
    374   cancel_cmd = ['shell', 'su', '0', 'update_engine_client', '--cancel']
    375   # List of commands to perform the update.
    376   cmds = []
    377 
    378   help_cmd = ['shell', 'su', '0', 'update_engine_client', '--help']
    379   use_omaha = 'omaha' in dut.adb_output(help_cmd)
    380 
    381   if args.file:
    382     # Update via pushing a file to /data.
    383     device_ota_file = os.path.join(OTA_PACKAGE_PATH, 'debug.zip')
    384     payload_url = 'file://' + device_ota_file
    385     if not args.no_push:
    386       data_local_tmp_file = '/data/local/tmp/debug.zip'
    387       cmds.append(['push', args.otafile, data_local_tmp_file])
    388       cmds.append(['shell', 'su', '0', 'mv', data_local_tmp_file,
    389                    device_ota_file])
    390       cmds.append(['shell', 'su', '0', 'chcon',
    391                    'u:object_r:ota_package_file:s0', device_ota_file])
    392     cmds.append(['shell', 'su', '0', 'chown', 'system:cache', device_ota_file])
    393     cmds.append(['shell', 'su', '0', 'chmod', '0660', device_ota_file])
    394   else:
    395     # Update via sending the payload over the network with an "adb reverse"
    396     # command.
    397     payload_url = 'http://127.0.0.1:%d/payload' % DEVICE_PORT
    398     if use_omaha and zipfile.is_zipfile(args.otafile):
    399       ota = AndroidOTAPackage(args.otafile)
    400       serving_range = (ota.offset, ota.size)
    401     else:
    402       serving_range = (0, os.stat(args.otafile).st_size)
    403     server_thread = StartServer(args.otafile, serving_range)
    404     cmds.append(
    405         ['reverse', 'tcp:%d' % DEVICE_PORT, 'tcp:%d' % server_thread.port])
    406     finalize_cmds.append(['reverse', '--remove', 'tcp:%d' % DEVICE_PORT])
    407 
    408   if args.public_key:
    409     payload_key_dir = os.path.dirname(PAYLOAD_KEY_PATH)
    410     cmds.append(
    411         ['shell', 'su', '0', 'mount', '-t', 'tmpfs', 'tmpfs', payload_key_dir])
    412     # Allow adb push to payload_key_dir
    413     cmds.append(['shell', 'su', '0', 'chcon', 'u:object_r:shell_data_file:s0',
    414                  payload_key_dir])
    415     cmds.append(['push', args.public_key, PAYLOAD_KEY_PATH])
    416     # Allow update_engine to read it.
    417     cmds.append(['shell', 'su', '0', 'chcon', '-R', 'u:object_r:system_file:s0',
    418                  payload_key_dir])
    419     finalize_cmds.append(['shell', 'su', '0', 'umount', payload_key_dir])
    420 
    421   try:
    422     # The main update command using the configured payload_url.
    423     if use_omaha:
    424       update_cmd = \
    425           OmahaUpdateCommand('http://127.0.0.1:%d/update' % DEVICE_PORT)
    426     else:
    427       update_cmd = \
    428           AndroidUpdateCommand(args.otafile, payload_url, args.extra_headers)
    429     cmds.append(['shell', 'su', '0'] + update_cmd)
    430 
    431     for cmd in cmds:
    432       dut.adb(cmd)
    433   except KeyboardInterrupt:
    434     dut.adb(cancel_cmd)
    435   finally:
    436     if server_thread:
    437       server_thread.StopServer()
    438     for cmd in finalize_cmds:
    439       dut.adb(cmd)
    440 
    441   return 0
    442 
    443 if __name__ == '__main__':
    444   sys.exit(main())
    445