Home | History | Annotate | Download | only in scripts
      1 #!/usr/bin/env bash
      2 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      3 #
      4 # Licensed under the Apache License, Version 2.0 (the "License");
      5 # you may not use this file except in compliance with the License.
      6 # You may obtain a copy of the License at
      7 #
      8 #     http://www.apache.org/licenses/LICENSE-2.0
      9 #
     10 # Unless required by applicable law or agreed to in writing, software
     11 # distributed under the License is distributed on an "AS IS" BASIS,
     12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 # See the License for the specific language governing permissions and
     14 # limitations under the License.
     15 # ==============================================================================
     16 #
     17 # Performs tests of TensorFlow's distributed runtime over a Kubernetes (k8s)
     18 # container cluster.
     19 #
     20 # This script tears down any existing TensorFlow cluster, consisting of
     21 # services, replication controllers and pods, before creating a new cluster.
# The cluster contains a number of parameter server services and a number of
     23 # worker services. The parameter servers will hold parameters of the ML model,
     24 # e.g., weights and biases of the NN layers, while the workers will hold the
     25 # TensorFlow ops.
     26 #
     27 # Usage:
     28 #   dist_test.sh [--setup_cluster_only]
     29 #                [--model_name (MNIST | CENSUS_WIDENDEEP)]
     30 #                [--num_workers <NUM_WORKERS>]
     31 #                [--num_parameter_servers <NUM_PARAMETER_SERVERS>]
     32 #                [--sync_replicas]
     33 #
     34 # --setup_cluster_only:
     35 #   Lets the script only set up the k8s container network
     36 #
     37 # --model_name
     38 #   Name of the model to test. Default is MNIST.
     39 #
# --num_workers <NUM_WORKERS>:
     41 #   Specifies the number of worker pods to start
     42 #
     43 # --num_parameter_servers <NUM_PARAMETER_SERVERS>:
     44 #   Specifies the number of parameter servers to start
     45 #
     46 # --sync_replicas
     47 #   Use the synchronized-replica mode. The parameter updates from the replicas
     48 #   (workers) will be aggregated before applied, which avoids stale parameter
     49 #   updates.
     50 #
     51 #
     52 # This script obeys values in the following environment variables:
     53 #   TF_DIST_GRPC_SERVER_URLS:     If it is set to a list of valid server urls,
     54 #                                 separated with spaces or commas
#                                 (e.g., "grpc://1.2.3.4:2222 grpc://5.6.7.8:2222"),
#                                 the script will bypass the cluster setup and
#                                 teardown processes and just use these URLs.
     58 
     59 
     60 # Helper functions
     61 die() {
     62   echo $@
     63   exit 1
     64 }
     65 
     66 # Parse input arguments: number of workers
     67 # Default values:
     68 MODEL_NAME="MNIST"  # Model name, default is "MNIST"
     69 NUM_WORKERS=2  # Number of worker container
     70 NUM_PARAMETER_SERVERS=2  # Number of parameter servers
     71 SYNC_REPLICAS=0
     72 SETUP_CLUSTER_ONLY=0
     73 
     74 while true; do
     75   if [[ "$1" == "--model_name" ]]; then
     76     MODEL_NAME=$2
     77   elif [[ "$1" == "--num_workers" ]]; then
     78     NUM_WORKERS=$2
     79   elif [[ "$1" == "--num_parameter_servers" ]]; then
     80     NUM_PARAMETER_SERVERS=$2
     81   elif [[ "$1" == "--sync_replicas" ]]; then
     82     SYNC_REPLICAS=1
     83   elif [[ "$1" == "--setup_cluster_only" ]]; then
     84     SETUP_CLUSTER_ONLY=1
     85   fi
     86   shift
     87 
     88   if [[ -z "$1" ]]; then
     89     break
     90   fi
     91 done
     92 
     93 echo "MODEL_NAME = \"MODEL_NAME\""
     94 echo "NUM_WORKERS = ${NUM_WORKERS}"
     95 echo "NUM_PARAMETER_SERVERS = ${NUM_PARAMETER_SERVERS}"
     96 echo "SETUP_CLUSTER_ONLY = ${SETUP_CLUSTER_ONLY}"
     97 
     98 # gcloud operation timeout (steps)
     99 GCLOUD_OP_MAX_STEPS=240
    100 
    101 if [[ ! -z ${TF_DIST_GRPC_SERVER_URLS} ]]; then
    102   GRPC_SERVER_URLS=${TF_DIST_GRPC_SERVER_URLS}
    103   GRPC_SERVER_URLS=$(echo ${GRPC_SERVER_URLS} | sed -e 's/,/ /g')
    104 fi
    105 
    106 # Report gcloud / GKE parameters
    107 echo "GRPC_SERVER_URLS: ${GRPC_SERVER_URLS}"
    108 echo "SYNC_REPLICAS: ${SYNC_REPLICAS}"
    109 
    110 # Get current script directory
    111 DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
    112 
    113 # Locate path to kubectl binary
    114 TEARDOWN_WHEN_DONE=1
    115 if [[ ! -z "${GRPC_SERVER_URLS}" ]]; then
    116   TEARDOWN_WHEN_DONE=0
    117   # Verify the validity of the GRPC URL
    118   for GRPC_SERVER_URL in ${GRPC_SERVER_URLS}; do
    119     if [[ -z $(echo "${GRPC_SERVER_URL}" | \
    120       grep -E "^grpc://.+:[0-9]+") ]]; then
    121       die "Invalid GRPC_SERVER_URL: \"${GRPC_SERVER_URL}\""
    122     fi
    123   done
    124 
    125   echo "The preset GRPC_SERVER_URLS appears to be valid: ${GRPC_SERVER_URLS}"
    126   echo "Will bypass the TensorFlow k8s cluster setup and teardown process"
    127   echo ""
    128 
    129 else
    130   TMP=$(mktemp)
    131   "${DIR}/create_tf_cluster.sh" ${NUM_WORKERS} ${NUM_PARAMETER_SERVERS} 2>&1 | \
    132       tee "${TMP}" || \
    133       die "Creation of TensorFlow k8s cluster FAILED"
    134 
    135   GRPC_SERVER_URLS=$(cat ${TMP} | grep "GRPC URLs of tf-worker instances: .*" | \
    136       sed -e 's/GRPC URLs of tf-worker instances://g')
    137 
    138   GRPC_PS_URLS=$(cat ${TMP} | grep "GRPC URLs of tf-ps instances: .*" | \
    139       sed -e 's/GRPC URLs of tf-ps instances://g')
    140 
    141   if [[ $(echo ${GRPC_SERVER_URLS} | wc -w) != ${NUM_WORKERS} ]]; then
    142     die "FAILED to determine GRPC server URLs of all workers"
    143   fi
    144   if [[ $(echo ${GRPC_PS_URLS} | wc -w) != ${NUM_PARAMETER_SERVERS} ]]; then
    145     die "FAILED to determine GRPC server URLs of all parameter servers"
    146   fi
    147 
    148   WORKER_HOSTS=$(echo "${GRPC_SERVER_URLS}" | sed -e 's/^[[:space:]]*//' | \
    149                  sed -e 's/grpc:\/\///g' | sed -e 's/ /,/g')
    150   PS_HOSTS=$(echo "${GRPC_PS_URLS}" | sed -e 's/^[[:space:]]*//' | \
    151              sed -e 's/grpc:\/\///g' | sed -e 's/ /,/g')
    152 
    153   echo "WORKER_HOSTS = ${WORKER_HOSTS}"
    154   echo "PS_HOSTS = ${PS_HOSTS}"
    155 
    156   rm -f ${TMP}
    157 
    158   if [[ ${SETUP_CLUSTER_ONLY} == "1" ]]; then
    159     echo "Skipping testing of distributed runtime due to "\
    160 "option flag --setup_cluster_only"
    161     exit 0
    162   fi
    163 fi
    164 
    165 
    166 # Test routine for model "MNIST"
    167 test_MNIST() {
    168   # Invoke script to perform distributed MNIST training
    169   MNIST_DIST_TEST_BIN="${DIR}/dist_mnist_test.sh"
    170   if [[ ! -f "${MNIST_DIST_TEST_BIN}" ]]; then
    171     echo "FAILED to find distributed mnist client test script at "\
    172   "${MNIST_DIST_TEST_BIN}"
    173     return 1
    174   fi
    175 
    176   echo "Performing distributed MNIST training through worker grpc sessions @ "\
    177   "${GRPC_SERVER_URLS}..."
    178 
    179   echo "and ps grpc sessions @ ${GRPC_PS_URLS}"
    180 
    181   SYNC_REPLICAS_FLAG=""
    182   if [[ ${SYNC_REPLICAS} == "1" ]]; then
    183     SYNC_REPLICAS_FLAG="--sync_replicas"
    184   fi
    185 
    186   "${MNIST_DIST_TEST_BIN}" \
    187       --existing_servers True \
    188       --ps_hosts "${PS_HOSTS}" \
    189       --worker_hosts "${WORKER_HOSTS}" \
    190       --num_gpus 0 \
    191       ${SYNC_REPLICAS_FLAG}
    192 
    193   if [[ $? == "0" ]]; then
    194     echo "MNIST-replica test PASSED"
    195   else
    196     echo "MNIST-replica test FAILED"
    197     return 1
    198   fi
    199   echo ""
    200 }
    201 
    202 # Test routine for model "CENSUS_WIDENDEEP"
    203 test_CENSUS_WIDENDEEP() {
    204   # Invoke script to perform distributed census_widendeep training
    205   CENSUS_WIDENDEEP_DIST_TEST_BIN="${DIR}/dist_census_widendeep_test.sh"
    206   if [[ ! -f "${CENSUS_WIDENDEEP_DIST_TEST_BIN}" ]]; then
    207     echo "FAILED to find distributed widen&deep client test script at "\
    208   "${CENSUS_WIDENDEEP_DIST_TEST_BIN}"
    209     return 1
    210   fi
    211 
    212   echo "Performing distributed wide&deep (census) training through grpc "\
    213   "sessions @ ${GRPC_SERVER_URLS}..."
    214 
    215   "${CENSUS_WIDENDEEP_DIST_TEST_BIN}" "${GRPC_SERVER_URLS}" \
    216       --num-workers "${NUM_WORKERS}" \
    217       --num-parameter-servers "${NUM_PARAMETER_SERVERS}"
    218 
    219   if [[ $? == "0" ]]; then
    220     echo "Census Wide & Deep test PASSED"
    221     echo ""
    222   else
    223     echo "Census Wide & Deep test FAILED"
    224     echo ""
    225     return 1
    226   fi
    227 }
    228 
    229 # Validate model name
    230 if [[ $(type -t "test_${MODEL_NAME}") != "function" ]]; then
    231   die "ERROR: Unsupported model: \"${MODEL_NAME}\""
    232 fi
    233 
    234 # Invoke test routine according to model name
    235 "test_${MODEL_NAME}" && \
    236     FAILED=0 || \
    237     FAILED=1
    238 
    239 # Tear down current k8s TensorFlow cluster
    240 if [[ "${TEARDOWN_WHEN_DONE}" == "1" ]]; then
    241   echo "Tearing down k8s TensorFlow cluster..."
    242   "${DIR}/delete_tf_cluster.sh" "${GCLOUD_OP_MAX_STEPS}" && \
    243       echo "Cluster tear-down SUCCEEDED" || \
    244       die "Cluster tear-down FAILED"
    245 fi
    246 
    247 if [[ "${FAILED}" == 1 ]]; then
    248   die "Test of distributed training of model ${MODEL_NAME} FAILED"
    249 else
    250   echo "SUCCESS: Test of distributed TensorFlow runtime PASSED"
    251   echo ""
    252 fi
    253