#!/usr/bin/env bash
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# Performs tests of TensorFlow's distributed runtime over a Kubernetes (k8s)
# container cluster.
#
# This script tears down any existing TensorFlow cluster, consisting of
# services, replication controllers and pods, before creating a new cluster.
# The cluster contains a number of parameter server services and a number of
# worker services. The parameter servers will hold parameters of the ML model,
# e.g., weights and biases of the NN layers, while the workers will hold the
# TensorFlow ops.
#
# Usage:
#   dist_test.sh [--setup_cluster_only]
#                [--model_name (MNIST | CENSUS_WIDENDEEP)]
#                [--num_workers <NUM_WORKERS>]
#                [--num_parameter_servers <NUM_PARAMETER_SERVERS>]
#                [--sync_replicas]
#
# --setup_cluster_only:
#   Lets the script only set up the k8s container network, then exit without
#   running any test. Has no effect when TF_DIST_GRPC_SERVER_URLS is set.
#
# --model_name <MODEL_NAME>:
#   Name of the model to test. Default is MNIST.
#
# --num_workers <NUM_WORKERS>:
#   Specifies the number of worker pods to start. Default is 2.
#
# --num_parameter_servers <NUM_PARAMETER_SERVERS>:
#   Specifies the number of parameter servers to start. Default is 2.
#
# --sync_replicas:
#   Use the synchronized-replica mode. The parameter updates from the replicas
#   (workers) will be aggregated before applied, which avoids stale parameter
#   updates.
#
# Unrecognized arguments are reported on stderr and otherwise ignored.
#
# This script obeys values in the following environment variables:
#   TF_DIST_GRPC_SERVER_URLS:  If it is set to a list of valid server urls,
#                              separated with spaces or commas
#                              (e.g., "grpc://1.2.3.4:2222 grpc://5.6.7.8:2222"),
#                              the script will bypass the cluster setup and
#                              teardown processes and just use these URLs.
#
# The cluster helpers (create_tf_cluster.sh, delete_tf_cluster.sh) and the
# per-model client scripts (dist_mnist_test.sh, dist_census_widendeep_test.sh)
# are looked up in the directory that contains this script.
#
# Exit status: 0 when the selected test (and the teardown, when performed)
# succeeds, or after a --setup_cluster_only setup; 1 otherwise.


# Helper functions

# Prints all arguments as a single line on stderr and exits with status 1.
die() {
  printf '%s\n' "$*" >&2
  exit 1
}

# Prints the number of whitespace-separated words in $1. Equivalent to
# `echo $1 | wc -w`, but without the padding some wc implementations emit,
# so the result can be compared as a plain string.
count_words() {
  local -a words=()
  # -d '' reads to end of input so words on every line are counted; read
  # returns non-zero on reaching EOF, which is expected here.
  read -r -d '' -a words <<< "$1" || true
  printf '%s\n' "${#words[@]}"
}

# Converts the whitespace-separated GRPC URLs in $1
# (e.g. "grpc://h1:2222 grpc://h2:2222") into a comma-separated host:port
# list (e.g. "h1:2222,h2:2222"), as expected by --worker_hosts/--ps_hosts.
grpc_urls_to_hosts() {
  local -a urls=()
  read -r -d '' -a urls <<< "$1" || true
  # "${urls[*]}" joins elements with the first character of IFS.
  local IFS=,
  printf '%s\n' "${urls[*]#grpc://}"
}

# Parse input arguments.
# Default values:
MODEL_NAME="MNIST"          # Model name, default is "MNIST"
NUM_WORKERS=2               # Number of worker containers
NUM_PARAMETER_SERVERS=2     # Number of parameter servers
SYNC_REPLICAS=0
SETUP_CLUSTER_ONLY=0

while [[ $# -gt 0 ]]; do
  case "$1" in
    --model_name)
      [[ $# -ge 2 ]] || die "ERROR: Option $1 requires a value"
      MODEL_NAME=$2
      shift
      ;;
    --num_workers)
      [[ $# -ge 2 ]] || die "ERROR: Option $1 requires a value"
      NUM_WORKERS=$2
      shift
      ;;
    --num_parameter_servers)
      [[ $# -ge 2 ]] || die "ERROR: Option $1 requires a value"
      NUM_PARAMETER_SERVERS=$2
      shift
      ;;
    --sync_replicas)
      SYNC_REPLICAS=1
      ;;
    --setup_cluster_only)
      SETUP_CLUSTER_ONLY=1
      ;;
    *)
      # Historically unknown arguments were silently skipped; keep skipping
      # them, but make the mistake visible.
      echo "WARNING: Ignoring unrecognized argument: \"$1\"" >&2
      ;;
  esac
  shift
done

echo "MODEL_NAME = \"${MODEL_NAME}\""
echo "NUM_WORKERS = ${NUM_WORKERS}"
echo "NUM_PARAMETER_SERVERS = ${NUM_PARAMETER_SERVERS}"
echo "SETUP_CLUSTER_ONLY = ${SETUP_CLUSTER_ONLY}"

# gcloud operation timeout (steps); passed to delete_tf_cluster.sh below.
GCLOUD_OP_MAX_STEPS=240

if [[ -n "${TF_DIST_GRPC_SERVER_URLS}" ]]; then
  # Accept commas as well as whitespace as URL separators.
  GRPC_SERVER_URLS=${TF_DIST_GRPC_SERVER_URLS//,/ }
fi

# Report server and replication parameters
echo "GRPC_SERVER_URLS: ${GRPC_SERVER_URLS}"
echo "SYNC_REPLICAS: ${SYNC_REPLICAS}"

# Get current script directory
DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

# Tear the cluster down at the end only if this script created it.
TEARDOWN_WHEN_DONE=1
if [[ -n "${GRPC_SERVER_URLS}" ]]; then
  TEARDOWN_WHEN_DONE=0

  # Verify the validity of every preset GRPC URL
  read -r -d '' -a PRESET_URLS <<< "${GRPC_SERVER_URLS}" || true
  if [[ ${#PRESET_URLS[@]} -eq 0 ]]; then
    die "Invalid GRPC_SERVER_URLS: \"${GRPC_SERVER_URLS}\" contains no URLs"
  fi
  for GRPC_SERVER_URL in "${PRESET_URLS[@]}"; do
    if [[ ! "${GRPC_SERVER_URL}" =~ ^grpc://.+:[0-9]+ ]]; then
      die "Invalid GRPC_SERVER_URL: \"${GRPC_SERVER_URL}\""
    fi
  done

  # The preset URLs are the worker servers, so derive --worker_hosts from
  # them (previously an empty value was passed to the MNIST client).
  # NOTE(review): there is no preset source for parameter-server URLs, so
  # PS_HOSTS / GRPC_PS_URLS are not set in this mode.
  WORKER_HOSTS=$(grpc_urls_to_hosts "${GRPC_SERVER_URLS}")

  echo "The preset GRPC_SERVER_URLS appears to be valid: ${GRPC_SERVER_URLS}"
  echo "Will bypass the TensorFlow k8s cluster setup and teardown process"
  echo ""

else
  TMP=$(mktemp) || die "FAILED to create a temporary file"
  # Remove the captured output on every exit path, including die().
  trap 'rm -f -- "${TMP}"' EXIT

  "${DIR}/create_tf_cluster.sh" "${NUM_WORKERS}" "${NUM_PARAMETER_SERVERS}" \
    2>&1 | tee "${TMP}"
  # Without pipefail the pipeline's status is tee's, so check the status of
  # the cluster-creation script itself as well.
  CREATE_STATUS=("${PIPESTATUS[@]}")
  if [[ "${CREATE_STATUS[0]}" != "0" || "${CREATE_STATUS[1]}" != "0" ]]; then
    die "Creation of TensorFlow k8s cluster FAILED"
  fi

  GRPC_SERVER_URLS=$(grep "GRPC URLs of tf-worker instances: .*" "${TMP}" \
    | sed -e 's/GRPC URLs of tf-worker instances://g')

  GRPC_PS_URLS=$(grep "GRPC URLs of tf-ps instances: .*" "${TMP}" \
    | sed -e 's/GRPC URLs of tf-ps instances://g')

  rm -f -- "${TMP}"

  if [[ "$(count_words "${GRPC_SERVER_URLS}")" != "${NUM_WORKERS}" ]]; then
    die "FAILED to determine GRPC server URLs of all workers"
  fi
  if [[ "$(count_words "${GRPC_PS_URLS}")" != "${NUM_PARAMETER_SERVERS}" ]]; then
    die "FAILED to determine GRPC server URLs of all parameter servers"
  fi

  WORKER_HOSTS=$(grpc_urls_to_hosts "${GRPC_SERVER_URLS}")
  PS_HOSTS=$(grpc_urls_to_hosts "${GRPC_PS_URLS}")

  echo "WORKER_HOSTS = ${WORKER_HOSTS}"
  echo "PS_HOSTS = ${PS_HOSTS}"

  if [[ "${SETUP_CLUSTER_ONLY}" == "1" ]]; then
    echo "Skipping testing of distributed runtime due to" \
      "option flag --setup_cluster_only"
    exit 0
  fi
fi


# Test routine for model "MNIST".
# Runs dist_mnist_test.sh against the existing servers in WORKER_HOSTS and
# PS_HOSTS, adding --sync_replicas when requested.
# Returns 0 if the client script succeeds, 1 otherwise.
test_MNIST() {
  local mnist_dist_test_bin="${DIR}/dist_mnist_test.sh"
  if [[ ! -f "${mnist_dist_test_bin}" ]]; then
    echo "FAILED to find distributed mnist client test script at" \
      "${mnist_dist_test_bin}"
    return 1
  fi

  echo "Performing distributed MNIST training through worker grpc sessions @" \
    "${GRPC_SERVER_URLS}..."

  echo "and ps grpc sessions @ ${GRPC_PS_URLS}"

  local -a sync_replicas_flag=()
  if [[ "${SYNC_REPLICAS}" == "1" ]]; then
    sync_replicas_flag=(--sync_replicas)
  fi

  if "${mnist_dist_test_bin}" \
      --existing_servers True \
      --ps_hosts "${PS_HOSTS}" \
      --worker_hosts "${WORKER_HOSTS}" \
      --num_gpus 0 \
      "${sync_replicas_flag[@]}"; then
    echo "MNIST-replica test PASSED"
  else
    echo "MNIST-replica test FAILED"
    return 1
  fi
  echo ""
}

# Test routine for model "CENSUS_WIDENDEEP".
# Runs dist_census_widendeep_test.sh with the worker GRPC URLs and the
# configured worker / parameter-server counts.
# Returns 0 if the client script succeeds, 1 otherwise.
test_CENSUS_WIDENDEEP() {
  local census_dist_test_bin="${DIR}/dist_census_widendeep_test.sh"
  if [[ ! -f "${census_dist_test_bin}" ]]; then
    echo "FAILED to find distributed wide&deep client test script at" \
      "${census_dist_test_bin}"
    return 1
  fi

  echo "Performing distributed wide&deep (census) training through grpc" \
    "sessions @ ${GRPC_SERVER_URLS}..."

  if "${census_dist_test_bin}" "${GRPC_SERVER_URLS}" \
      --num-workers "${NUM_WORKERS}" \
      --num-parameter-servers "${NUM_PARAMETER_SERVERS}"; then
    echo "Census Wide & Deep test PASSED"
    echo ""
  else
    echo "Census Wide & Deep test FAILED"
    echo ""
    return 1
  fi
}

# Validate model name: a test_<MODEL_NAME> function must exist above.
if [[ "$(type -t "test_${MODEL_NAME}")" != "function" ]]; then
  die "ERROR: Unsupported model: \"${MODEL_NAME}\""
fi

# Invoke test routine according to model name; remember the outcome so the
# cluster is still torn down when the test fails.
FAILED=0
"test_${MODEL_NAME}" || FAILED=1

# Tear down current k8s TensorFlow cluster
if [[ "${TEARDOWN_WHEN_DONE}" == "1" ]]; then
  echo "Tearing down k8s TensorFlow cluster..."
  if "${DIR}/delete_tf_cluster.sh" "${GCLOUD_OP_MAX_STEPS}"; then
    echo "Cluster tear-down SUCCEEDED"
  else
    die "Cluster tear-down FAILED"
  fi
fi

if [[ "${FAILED}" == "1" ]]; then
  die "Test of distributed training of model ${MODEL_NAME} FAILED"
fi
echo "SUCCESS: Test of distributed TensorFlow runtime PASSED"
echo ""