Home | History | Annotate | Download | only in lib
      1 #!/usr/bin/env python
      2 #
      3 # Copyright 2016 - The Android Open Source Project
      4 #
      5 # Licensed under the Apache License, Version 2.0 (the "License");
      6 # you may not use this file except in compliance with the License.
      7 # You may obtain a copy of the License at
      8 #
      9 #     http://www.apache.org/licenses/LICENSE-2.0
     10 #
     11 # Unless required by applicable law or agreed to in writing, software
     12 # distributed under the License is distributed on an "AS IS" BASIS,
     13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14 # See the License for the specific language governing permissions and
     15 # limitations under the License.
     16 
     17 """Common Utilities.
     18 
     19 The following code is copied from chromite with modifications.
     20   - class TempDir: chromite/lib/osutils.py
     21 
     22 """
     23 
     24 import base64
     25 import binascii
     26 import errno
     27 import getpass
     28 import logging
     29 import os
     30 import shutil
     31 import struct
     32 import subprocess
     33 import sys
     34 import tarfile
     35 import tempfile
     36 import time
     37 import uuid
     38 
     39 from acloud.public import errors
     40 
     41 
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# Base ssh-keygen command line requesting a 4096-bit RSA key.
# CreateSshKeyPairIfNotExist appends the "-C" (comment) and "-f" (output
# file) options before running it.
SSH_KEYGEN_CMD = ["ssh-keygen", "-t", "rsa", "-b", "4096"]
     46 
     47 
     48 class TempDir(object):
     49     """Object that creates a temporary directory.
     50 
     51     This object can either be used as a context manager or just as a simple
     52     object. The temporary directory is stored as self.tempdir in the object, and
     53     is returned as a string by a 'with' statement.
     54     """
     55 
     56     def __init__(self, prefix='tmp', base_dir=None, delete=True):
     57         """Constructor. Creates the temporary directory.
     58 
     59         Args:
     60             prefix: See tempfile.mkdtemp documentation.
     61             base_dir: The directory to place the temporary directory.
     62                       If None, will choose from system default tmp dir.
     63             delete: Whether the temporary dir should be deleted as part of cleanup.
     64         """
     65         self.delete = delete
     66         self.tempdir = tempfile.mkdtemp(prefix=prefix, dir=base_dir)
     67         os.chmod(self.tempdir, 0o700)
     68 
     69     def Cleanup(self):
     70         """Clean up the temporary directory."""
     71         # Note that _TempDirSetup may have failed, resulting in these attributes
     72         # not being set; this is why we use getattr here (and must).
     73         tempdir = getattr(self, 'tempdir', None)
     74         if tempdir is not None and self.delete:
     75             try:
     76                 shutil.rmtree(tempdir)
     77             except EnvironmentError as e:
     78                 # Ignore error if directory or file does not exist.
     79                 if e.errno != errno.ENOENT:
     80                     raise
     81             finally:
     82                 self.tempdir = None
     83 
     84     def __enter__(self):
     85         """Return the temporary directory."""
     86         return self.tempdir
     87 
     88     def __exit__(self, exc_type, exc_value, exc_traceback):
     89         """Exit the context manager."""
     90         try:
     91             self.Cleanup()
     92         except Exception:  # pylint: disable=W0703
     93             if exc_type:
     94                 # If an exception from inside the context was already in progress,
     95                 # log our cleanup exception, then allow the original to resume.
     96                 logger.error('While exiting %s:', self, exc_info=True)
     97 
     98                 if self.tempdir:
     99                     # Log all files in tempdir at the time of the failure.
    100                     try:
    101                         logger.error('Directory contents were:')
    102                         for name in os.listdir(self.tempdir):
    103                             logger.error('  %s', name)
    104                     except OSError:
    105                         logger.error('  Directory did not exist.')
    106             else:
    107                 # If there was not an exception from the context, raise ours.
    108                 raise
    109 
    110     def __del__(self):
    111         """Delete the object."""
    112         self.Cleanup()
    113 
    114 def RetryOnException(retry_checker, max_retries, sleep_multiplier=0,
    115                      retry_backoff_factor=1):
    116   """Decorater which retries the function call if |retry_checker| returns true.
    117 
    118   Args:
    119     retry_checker: A callback function which should take an exception instance
    120                    and return True if functor(*args, **kwargs) should be retried
    121                    when such exception is raised, and return False if it should
    122                    not be retried.
    123     max_retries: Maximum number of retries allowed.
    124     sleep_multiplier: Will sleep sleep_multiplier * attempt_count seconds if
    125                       retry_backoff_factor is 1.  Will sleep
    126                       sleep_multiplier * (
    127                           retry_backoff_factor ** (attempt_count -  1))
    128                       if retry_backoff_factor != 1.
    129     retry_backoff_factor: See explanation of sleep_multiplier.
    130 
    131   Returns:
    132     The function wrapper.
    133   """
    134   def _Wrapper(func):
    135     def _FunctionWrapper(*args, **kwargs):
    136       return Retry(retry_checker, max_retries, func, sleep_multiplier,
    137                    retry_backoff_factor,
    138                    *args, **kwargs)
    139     return _FunctionWrapper
    140   return _Wrapper
    141 
    142 
    143 def Retry(retry_checker, max_retries, functor, sleep_multiplier=0,
    144           retry_backoff_factor=1, *args, **kwargs):
    145   """Conditionally retry a function.
    146 
    147   Args:
    148     retry_checker: A callback function which should take an exception instance
    149                    and return True if functor(*args, **kwargs) should be retried
    150                    when such exception is raised, and return False if it should
    151                    not be retried.
    152     max_retries: Maximum number of retries allowed.
    153     functor: The function to call, will call functor(*args, **kwargs).
    154     sleep_multiplier: Will sleep sleep_multiplier * attempt_count seconds if
    155                       retry_backoff_factor is 1.  Will sleep
    156                       sleep_multiplier * (
    157                           retry_backoff_factor ** (attempt_count -  1))
    158                       if retry_backoff_factor != 1.
    159     retry_backoff_factor: See explanation of sleep_multiplier.
    160     *args: Arguments to pass to the functor.
    161     **kwargs: Key-val based arguments to pass to the functor.
    162 
    163   Returns:
    164     The return value of the functor.
    165 
    166   Raises:
    167     Exception: The exception that functor(*args, **kwargs) throws.
    168   """
    169   attempt_count = 0
    170   while attempt_count <= max_retries:
    171     try:
    172       attempt_count += 1
    173       return_value = functor(*args, **kwargs)
    174       return return_value
    175     except Exception as e:  # pylint: disable=W0703
    176       if retry_checker(e) and attempt_count <= max_retries:
    177         if retry_backoff_factor != 1:
    178           sleep = sleep_multiplier * (
    179               retry_backoff_factor ** (attempt_count -  1))
    180         else:
    181           sleep = sleep_multiplier * attempt_count
    182         time.sleep(sleep)
    183       else:
    184         raise
    185 
    186 
    187 def RetryExceptionType(exception_types, max_retries, functor, *args, **kwargs):
    188   """Retry exception if it is one of the given types.
    189 
    190   Args:
    191     exception_types: A tuple of exception types, e.g. (ValueError, KeyError)
    192     max_retries: Max number of retries allowed.
    193     functor: The function to call. Will be retried if exception is raised and
    194              the exception is one of the exception_types.
    195     *args: Arguments to pass to Retry function.
    196     **kwargs: Key-val based arguments to pass to Retry functions.
    197 
    198   Returns:
    199     The value returned by calling functor.
    200   """
    201   return Retry(lambda e: isinstance(e, exception_types), max_retries,
    202                functor, *args, **kwargs)
    203 
    204 
    205 def PollAndWait(func, expected_return, timeout_exception, timeout_secs,
    206                 sleep_interval_secs, *args, **kwargs):
    207     """Call a function until the function returns expected value or times out.
    208 
    209     Args:
    210         func: Function to call.
    211         expected_return: The expected return value.
    212         timeout_exception: Exception to raise when it hits timeout.
    213         timeout_secs: Timeout seconds.
    214                       If 0 or less than zero, the function will run once and
    215                       we will not wait on it.
    216         sleep_interval_secs: Time to sleep between two attemps.
    217         *args: list of args to pass to func.
    218         **kwargs: dictionary of keyword based args to pass to func.
    219 
    220     Raises:
    221         timeout_exception: if the run of function times out.
    222     """
    223     # TODO(fdeng): Currently this method does not kill
    224     # |func|, if |func| takes longer than |timeout_secs|.
    225     # We can use a more robust version from chromite.
    226     start = time.time()
    227     while True:
    228         return_value = func(*args, **kwargs)
    229         if return_value == expected_return:
    230             return
    231         elif time.time() - start > timeout_secs:
    232             raise timeout_exception
    233         else:
    234             if sleep_interval_secs > 0:
    235                 time.sleep(sleep_interval_secs)
    236 
    237 
    238 def GenerateUniqueName(prefix=None, suffix=None):
    239     """Generate a random unque name using uuid4.
    240 
    241     Args:
    242         prefix: String, desired prefix to prepend to the generated name.
    243         suffix: String, desired suffix to append to the generated name.
    244 
    245     Returns:
    246         String, a random name.
    247     """
    248     name = uuid.uuid4().hex
    249     if prefix:
    250         name = "-".join([prefix, name])
    251     if suffix:
    252         name = "-".join([name, suffix])
    253     return name
    254 
    255 
    256 def MakeTarFile(src_dict, dest):
    257     """Archive files in tar.gz format to a file named as |dest|.
    258 
    259     Args:
    260         src_dict: A dictionary that maps a path to be archived
    261                   to the corresponding name that appears in the archive.
    262         dest: String, path to output file, e.g. /tmp/myfile.tar.gz
    263     """
    264     logger.info("Compressing %s into %s.", src_dict.keys(), dest)
    265     with tarfile.open(dest, "w:gz") as tar:
    266         for src, arcname in src_dict.iteritems():
    267             tar.add(src, arcname=arcname)
    268 
    269 
    270 def CreateSshKeyPairIfNotExist(private_key_path, public_key_path):
    271     """Create the ssh key pair if they don't exist.
    272 
    273     Check if the public and private key pairs exist at
    274     the given places. If not, create them.
    275 
    276     Args:
    277         private_key_path: Path to the private key file.
    278                           e.g. ~/.ssh/acloud_rsa
    279         public_key_path: Path to the public key file.
    280                          e.g. ~/.ssh/acloud_rsa.pub
    281     Raises:
    282         error.DriverError: If failed to create the key pair.
    283     """
    284     public_key_path = os.path.expanduser(public_key_path)
    285     private_key_path = os.path.expanduser(private_key_path)
    286     create_key = (
    287             not os.path.exists(public_key_path) and
    288             not os.path.exists(private_key_path))
    289     if not create_key:
    290         logger.debug("The ssh private key (%s) or public key (%s) already exist,"
    291                      "will not automatically create the key pairs.",
    292                      private_key_path, public_key_path)
    293         return
    294     cmd = SSH_KEYGEN_CMD + ["-C", getpass.getuser(), "-f", private_key_path]
    295     logger.info("The ssh private key (%s) and public key (%s) do not exist, "
    296                 "automatically creating key pair, calling: %s",
    297                 private_key_path, public_key_path, " ".join(cmd))
    298     try:
    299         subprocess.check_call(cmd, stdout=sys.stderr, stderr=sys.stdout)
    300     except subprocess.CalledProcessError as e:
    301         raise errors.DriverError(
    302                 "Failed to create ssh key pair: %s" % str(e))
    303     except OSError as e:
    304         raise errors.DriverError(
    305                 "Failed to create ssh key pair, please make sure "
    306                 "'ssh-keygen' is installed: %s" % str(e))
    307 
    308     # By default ssh-keygen will create a public key file
    309     # by append .pub to the private key file name. Rename it
    310     # to what's requested by public_key_path.
    311     default_pub_key_path = "%s.pub" % private_key_path
    312     try:
    313         if default_pub_key_path != public_key_path:
    314             os.rename(default_pub_key_path, public_key_path)
    315     except OSError as e:
    316         raise errors.DriverError(
    317                 "Failed to rename %s to %s: %s" %
    318                 (default_pub_key_path, public_key_path, str(e)))
    319 
    320     logger.info("Created ssh private key (%s) and public key (%s)",
    321                 private_key_path, public_key_path)
    322 
    323 
    324 def VerifyRsaPubKey(rsa):
    325     """Verify the format of rsa public key.
    326 
    327     Args:
    328         rsa: content of rsa public key. It should follow the format of
    329              ssh-rsa AAAAB3NzaC1yc2EA.... test (at] test.com
    330 
    331     Raises:
    332         DriverError if the format is not correct.
    333     """
    334     if not rsa or not all(ord(c) < 128 for c in rsa):
    335         raise errors.DriverError(
    336             "rsa key is empty or contains non-ascii character: %s" % rsa)
    337 
    338     elements = rsa.split()
    339     if len(elements) != 3:
    340         raise errors.DriverError("rsa key is invalid, wrong format: %s" % rsa)
    341 
    342     key_type, data, _ = elements
    343     try:
    344         binary_data = base64.decodestring(data)
    345         # number of bytes of int type
    346         int_length = 4
    347         # binary_data is like "7ssh-key..." in a binary format.
    348         # The first 4 bytes should represent 7, which should be
    349         # the length of the following string "ssh-key".
    350         # And the next 7 bytes should be string "ssh-key".
    351         # We will verify that the rsa conforms to this format.
    352         # ">I" in the following line means "big-endian unsigned integer".
    353         type_length = struct.unpack(">I", binary_data[:int_length])[0]
    354         if binary_data[int_length:int_length + type_length] != key_type:
    355             raise errors.DriverError("rsa key is invalid: %s" % rsa)
    356     except (struct.error, binascii.Error) as e:
    357         raise errors.DriverError("rsa key is invalid: %s, error: %s" %
    358                                  (rsa, str(e)))
    359 
    360 
    361 class BatchHttpRequestExecutor(object):
    362     """A helper class that executes requests in batch with retry.
    363 
    364     This executor executes http requests in a batch and retry
    365     those that have failed. It iteratively updates the dictionary
    366     self._final_results with latest results, which can be retrieved
    367     via GetResults.
    368     """
    369 
    370     def __init__(self,
    371                  execute_once_functor,
    372                  requests,
    373                  retry_http_codes=None,
    374                  max_retry=None,
    375                  sleep=None,
    376                  backoff_factor=None,
    377                  other_retriable_errors=None):
    378         """Initializes the executor.
    379 
    380         Args:
    381             execute_once_functor: A function that execute requests in batch once.
    382                                   It should return a dictionary like
    383                                   {request_id: (response, exception)}
    384             requests: A dictionary where key is request id picked by caller,
    385                       and value is a apiclient.http.HttpRequest.
    386             retry_http_codes: A list of http codes to retry.
    387             max_retry: See utils.Retry.
    388             sleep: See utils.Retry.
    389             backoff_factor: See utils.Retry.
    390             other_retriable_errors: A tuple of error types that should be retried
    391                                     other than errors.HttpError.
    392         """
    393         self._execute_once_functor = execute_once_functor
    394         self._requests = requests
    395         # A dictionary that maps request id to pending request.
    396         self._pending_requests = {}
    397         # A dictionary that maps request id to a tuple (response, exception).
    398         self._final_results = {}
    399         self._retry_http_codes = retry_http_codes
    400         self._max_retry = max_retry
    401         self._sleep = sleep
    402         self._backoff_factor = backoff_factor
    403         self._other_retriable_errors = other_retriable_errors
    404 
    405     def _ShoudRetry(self, exception):
    406         """Check if an exception is retriable."""
    407         if isinstance(exception, self._other_retriable_errors):
    408             return True
    409 
    410         if (isinstance(exception, errors.HttpError) and
    411                 exception.code in self._retry_http_codes):
    412             return True
    413         return False
    414 
    415     def _ExecuteOnce(self):
    416         """Executes pending requests and update it with failed, retriable ones.
    417 
    418         Raises:
    419             HasRetriableRequestsError: if some requests fail and are retriable.
    420         """
    421         results = self._execute_once_functor(self._pending_requests)
    422         # Update final_results with latest results.
    423         self._final_results.update(results)
    424         # Clear pending_requests
    425         self._pending_requests.clear()
    426         for request_id, result in results.iteritems():
    427             exception = result[1]
    428             if exception is not None and self._ShoudRetry(exception):
    429                 # If this is a retriable exception, put it in pending_requests
    430                 self._pending_requests[request_id] = self._requests[request_id]
    431         if self._pending_requests:
    432             # If there is still retriable requests pending, raise an error
    433             # so that Retry will retry this function with pending_requests.
    434             raise errors.HasRetriableRequestsError(
    435                 "Retriable errors: %s" % [str(results[rid][1])
    436                                           for rid in self._pending_requests])
    437 
    438     def Execute(self):
    439         """Executes the requests and retry if necessary.
    440 
    441         Will populate self._final_results.
    442         """
    443         def _ShouldRetryHandler(exc):
    444             """Check if |exc| is a retriable exception.
    445 
    446             Args:
    447                 exc: An exception.
    448 
    449             Returns:
    450                 True if exception is of type HasRetriableRequestsError; False otherwise.
    451             """
    452             should_retry = isinstance(exc, errors.HasRetriableRequestsError)
    453             if should_retry:
    454                 logger.info("Will retry failed requests.", exc_info=True)
    455                 logger.info("%s", exc)
    456             return should_retry
    457 
    458         try:
    459             self._pending_requests = self._requests.copy()
    460             Retry(
    461                 _ShouldRetryHandler, max_retries=self._max_retry,
    462                 functor=self._ExecuteOnce,
    463                 sleep_multiplier=self._sleep,
    464                 retry_backoff_factor=self._backoff_factor)
    465         except errors.HasRetriableRequestsError:
    466             logger.debug("Some requests did not succeed after retry.")
    467 
    468     def GetResults(self):
    469         """Returns final results.
    470 
    471         Returns:
    472             results, a dictionary in the following format
    473             {request_id: (response, exception)}
    474             request_ids are those from requests; response
    475             is the http response for the request or None on error;
    476             exception is an instance of DriverError or None if no error.
    477         """
    478         return self._final_results
    479