Home | History | Annotate | Download | only in stress
      1 # Copyright 2016 gRPC authors.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 """Entry point for running stress tests."""
     15 
     16 import argparse
     17 from concurrent import futures
     18 import threading
     19 
     20 import grpc
     21 from six.moves import queue
     22 from src.proto.grpc.testing import metrics_pb2_grpc
     23 from src.proto.grpc.testing import test_pb2_grpc
     24 
     25 from tests.interop import methods
     26 from tests.interop import resources
     27 from tests.qps import histogram
     28 from tests.stress import metrics_server
     29 from tests.stress import test_runner
     30 
     31 
     32 def _args():
     33     parser = argparse.ArgumentParser(
     34         description='gRPC Python stress test client')
     35     parser.add_argument(
     36         '--server_addresses',
     37         help='comma seperated list of hostname:port to run servers on',
     38         default='localhost:8080',
     39         type=str)
     40     parser.add_argument(
     41         '--test_cases',
     42         help='comma seperated list of testcase:weighting of tests to run',
     43         default='large_unary:100',
     44         type=str)
     45     parser.add_argument(
     46         '--test_duration_secs',
     47         help='number of seconds to run the stress test',
     48         default=-1,
     49         type=int)
     50     parser.add_argument(
     51         '--num_channels_per_server',
     52         help='number of channels per server',
     53         default=1,
     54         type=int)
     55     parser.add_argument(
     56         '--num_stubs_per_channel',
     57         help='number of stubs to create per channel',
     58         default=1,
     59         type=int)
     60     parser.add_argument(
     61         '--metrics_port',
     62         help='the port to listen for metrics requests on',
     63         default=8081,
     64         type=int)
     65     parser.add_argument(
     66         '--use_test_ca',
     67         help='Whether to use our fake CA. Requires --use_tls=true',
     68         default=False,
     69         type=bool)
     70     parser.add_argument(
     71         '--use_tls', help='Whether to use TLS', default=False, type=bool)
     72     parser.add_argument(
     73         '--server_host_override',
     74         default="foo.test.google.fr",
     75         help='the server host to which to claim to connect',
     76         type=str)
     77     return parser.parse_args()
     78 
     79 
     80 def _test_case_from_arg(test_case_arg):
     81     for test_case in methods.TestCase:
     82         if test_case_arg == test_case.value:
     83             return test_case
     84     else:
     85         raise ValueError('No test case {}!'.format(test_case_arg))
     86 
     87 
     88 def _parse_weighted_test_cases(test_case_args):
     89     weighted_test_cases = {}
     90     for test_case_arg in test_case_args.split(','):
     91         name, weight = test_case_arg.split(':', 1)
     92         test_case = _test_case_from_arg(name)
     93         weighted_test_cases[test_case] = int(weight)
     94     return weighted_test_cases
     95 
     96 
     97 def _get_channel(target, args):
     98     if args.use_tls:
     99         if args.use_test_ca:
    100             root_certificates = resources.test_root_certificates()
    101         else:
    102             root_certificates = None  # will load default roots.
    103         channel_credentials = grpc.ssl_channel_credentials(
    104             root_certificates=root_certificates)
    105         options = ((
    106             'grpc.ssl_target_name_override',
    107             args.server_host_override,
    108         ),)
    109         channel = grpc.secure_channel(
    110             target, channel_credentials, options=options)
    111     else:
    112         channel = grpc.insecure_channel(target)
    113 
    114     # waits for the channel to be ready before we start sending messages
    115     grpc.channel_ready_future(channel).result()
    116     return channel
    117 
    118 
    119 def run_test(args):
    120     test_cases = _parse_weighted_test_cases(args.test_cases)
    121     test_server_targets = args.server_addresses.split(',')
    122     # Propagate any client exceptions with a queue
    123     exception_queue = queue.Queue()
    124     stop_event = threading.Event()
    125     hist = histogram.Histogram(1, 1)
    126     runners = []
    127 
    128     server = grpc.server(futures.ThreadPoolExecutor(max_workers=25))
    129     metrics_pb2_grpc.add_MetricsServiceServicer_to_server(
    130         metrics_server.MetricsServer(hist), server)
    131     server.add_insecure_port('[::]:{}'.format(args.metrics_port))
    132     server.start()
    133 
    134     for test_server_target in test_server_targets:
    135         for _ in xrange(args.num_channels_per_server):
    136             channel = _get_channel(test_server_target, args)
    137             for _ in xrange(args.num_stubs_per_channel):
    138                 stub = test_pb2_grpc.TestServiceStub(channel)
    139                 runner = test_runner.TestRunner(stub, test_cases, hist,
    140                                                 exception_queue, stop_event)
    141                 runners.append(runner)
    142 
    143     for runner in runners:
    144         runner.start()
    145     try:
    146         timeout_secs = args.test_duration_secs
    147         if timeout_secs < 0:
    148             timeout_secs = None
    149         raise exception_queue.get(block=True, timeout=timeout_secs)
    150     except queue.Empty:
    151         # No exceptions thrown, success
    152         pass
    153     finally:
    154         stop_event.set()
    155         for runner in runners:
    156             runner.join()
    157         runner = None
    158         server.stop(None)
    159 
    160 
# Run the stress test with flags from the command line, but only when this
# module is executed as a script rather than imported.
if __name__ == '__main__':
    run_test(_args())
    163