#!/usr/bin/env bash
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# This script invokes census_widendeep.py multiple times concurrently (once per
# worker GRPC URL) to test TensorFlow's distributed runtime over a Kubernetes
# (k8s) cluster with the grpc pods and service set up.
#
# Usage:
#    dist_census_widendeep_test.sh <worker_grpc_urls>
#                                  --num-workers <NUM_WORKERS>
#                                  --num-parameter-servers <NUM_PARAMETER_SERVERS>
#
# worker_grpc_urls is the list of IP addresses or the GRPC URLs of the worker
# sessions, separated with spaces,
# e.g., "grpc://1.2.3.4:2222 grpc://5.6.7.8:2222"
#
# --num-workers <NUM_WORKERS>:
#   Specifies the number of worker pods to use
#
# --num-parameter-servers <NUM_PARAMETER_SERVERS>:
#   Specifies the number of parameter servers to use
#
# --sync-replicas:
#   Recognized, but currently aborts the script with an error because
#   synchronized-replicas mode is not fully supported by this test yet.
#
# Exit status:
#   0 if every worker process exits with status 0; 1 otherwise (including
#   argument errors and a missing Python file).

# Configurations
TIMEOUT=120  # Timeout for replica sessions (not referenced by this script)

# Helper functions

# Prints all arguments as one line on stderr and exits with status 1.
die() {
  printf '%s\n' "$*" >&2
  exit 1
}

# Parse command-line arguments
WORKER_GRPC_URLS=${1:-}
if [[ -z "${WORKER_GRPC_URLS}" ]]; then
  # Without any URL no worker would be launched and the test would trivially
  # "pass", so treat it as a usage error.
  die "Usage: ${0##*/} <worker_grpc_urls>" \
      "[--num-workers <NUM_WORKERS>]" \
      "[--num-parameter-servers <NUM_PARAMETER_SERVERS>]"
fi
shift

# Process additional input arguments
N_WORKERS=2  # Default value
N_PS=2  # Default value
SYNC_REPLICAS=0

# Unrecognized arguments are skipped silently.
while [[ $# -gt 0 ]]; do
  case "$1" in
    --num-workers)
      [[ $# -ge 2 ]] || die "ERROR: --num-workers requires a value"
      N_WORKERS=$2
      shift
      ;;
    --num-parameter-servers)
      [[ $# -ge 2 ]] || die "ERROR: --num-parameter-servers requires a value"
      N_PS=$2
      shift
      ;;
    --sync-replicas)
      SYNC_REPLICAS="1"
      die "ERROR: --sync-replicas (synchronized-replicas) mode is not fully" \
          "supported by this test yet."
      # TODO(cais): Remove error message once sync-replicas is fully supported
      ;;
  esac
  shift
done

echo "N_WORKERS = ${N_WORKERS}"
echo "N_PS = ${N_PS}"

# Directory to store the trained model and evaluation results.
# The root (e.g., /shared) must be a directory shared among the workers.
# See volumeMounts fields in k8s_tensorflow.py
MODEL_DIR="/shared/census_widendeep_model"

# ":?" aborts instead of running "rm -rf" on an empty path.
rm -rf -- "${MODEL_DIR:?}" || \
    die "Failed to remove existing model directory: ${MODEL_DIR}"

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PY_PATH="${SCRIPT_DIR}/../python/census_widendeep.py"
if [[ ! -f "${PY_PATH}" ]]; then
  die "ERROR: Python file does not exist: ${PY_PATH}"
fi

STAGGERED_START_DELAY_SEC=0
WKR_LOG_PREFIX="/tmp/worker_"

IDX=0
LOG_FILES=()
WORKER_PIDS=()
# WORKER_GRPC_URLS is intentionally unquoted: it is a whitespace-separated list.
for WORKER_GRPC_URL in ${WORKER_GRPC_URLS}; do
  if (( IDX > 0 )); then
    sleep "${STAGGERED_START_DELAY_SEC}"
  fi

  LOG_FILE="${WKR_LOG_PREFIX}${IDX}.log"
  LOG_FILES+=("${LOG_FILE}")

  # Run the pipeline in a subshell with pipefail so that the background job's
  # exit status reflects a failure of python rather than only that of tee.
  (
    set -o pipefail
    python "${PY_PATH}" \
        --master_grpc_url="${WORKER_GRPC_URL}" \
        --num_parameter_servers="${N_PS}" \
        --worker_index="${IDX}" \
        --model_dir="${MODEL_DIR}" \
        --output_dir="/shared/output" \
        --train_steps=1000 \
        --eval_steps=2 2>&1 | tee "${LOG_FILE}"
  ) &
  WORKER_PIDS+=("$!")

  echo "Worker ${IDX}: "
  echo "  GRPC URL: ${WORKER_GRPC_URL}"
  echo "  log file: ${LOG_FILE}"

  IDX=$((IDX + 1))
done

# Wait for all concurrent jobs to finish, remembering how many of them failed.
N_FAILED=0
for I in "${!WORKER_PIDS[@]}"; do
  if ! wait "${WORKER_PIDS[${I}]}"; then
    echo "ERROR: worker ${I} exited with a non-zero status" >&2
    N_FAILED=$((N_FAILED + 1))
  fi
done

# Print logs from the workers (done before reporting failure so that the logs
# are available for diagnosis).
ORD=1
for LOG_FILE in "${LOG_FILES[@]}"; do
  echo "==================================================="
  echo "===        Log file from worker ${ORD} / ${N_WORKERS}          ==="
  cat "${LOG_FILE}"
  echo "==================================================="
  echo ""

  ORD=$((ORD + 1))
done

if (( N_FAILED > 0 )); then
  die "Test for distributed training of Census Wide & Deep model FAILED:" \
      "${N_FAILED} worker(s) exited with a non-zero status"
fi

echo "Test for distributed training of Census Wide & Deep model PASSED"