#!/usr/bin/env bash
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# This script invokes census_widendeep.py multiple times concurrently to test
# TensorFlow's distributed runtime over a Kubernetes (k8s) cluster with the
# grpc pods and service set up.
#
# Usage:
#    dist_census_widendeep_test.sh <worker_grpc_urls>
#        --num-workers <NUM_WORKERS>
#        --num-parameter-servers <NUM_PARAMETER_SERVERS>
#
# worker_grpc_urls is the list of IP addresses or GRPC URLs of the worker
# sessions, separated by spaces,
# e.g., "grpc://1.2.3.4:2222 grpc://5.6.7.8:2222"
#
# --num-workers <NUM_WORKERS>:
#   Specifies the number of worker pods to use
#
# --num-parameter-servers <NUM_PARAMETER_SERVERS>:
#   Specifies the number of parameter servers to use
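#
# Example invocation (the GRPC URLs below are hypothetical placeholders;
# substitute the addresses of your own worker pods):
#    dist_census_widendeep_test.sh "grpc://10.1.1.1:2222 grpc://10.1.1.2:2222" \
#        --num-workers 2 --num-parameter-servers 2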

# Configurations
TIMEOUT=120  # Timeout (in seconds) for replica sessions (currently unused in this script)

# Helper functions
# die: print an error message and exit with a non-zero status.
die() {
  echo "$@"
  exit 1
}

# Parse command-line arguments
WORKER_GRPC_URLS="$1"
if [[ -z "${WORKER_GRPC_URLS}" ]]; then
  die "Usage: $0 <worker_grpc_urls> [--num-workers N] [--num-parameter-servers M]"
fi
shift

# Process additional input arguments
N_WORKERS=2  # Default value
N_PS=2  # Default value
SYNC_REPLICAS=0

while [[ -n "$1" ]]; do
  if [[ "$1" == "--num-workers" ]]; then
    N_WORKERS="$2"
  elif [[ "$1" == "--num-parameter-servers" ]]; then
    N_PS="$2"
  elif [[ "$1" == "--sync-replicas" ]]; then
    SYNC_REPLICAS="1"
    die "ERROR: --sync-replicas (synchronized-replicas) mode is not fully "\
"supported by this test yet."
    # TODO(cais): Remove error message once sync-replicas is fully supported
  fi
  shift
done

echo "N_WORKERS = ${N_WORKERS}"
echo "N_PS = ${N_PS}"

# Directory to store the trained model and evaluation results.
# The root (e.g., /shared) must be a directory shared among the workers.
# See the volumeMounts fields in k8s_tensorflow.py
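# For illustration only, a hypothetical pod spec might mount such a shared
# volume like this (names are placeholders, not taken from k8s_tensorflow.py):
#   volumeMounts:
#   - name: shared-storage
#     mountPath: /shared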
MODEL_DIR="/shared/census_widendeep_model"

rm -rf "${MODEL_DIR}" || \
    die "Failed to remove existing model directory: ${MODEL_DIR}"

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PY_PATH="${SCRIPT_DIR}/../python/census_widendeep.py"
if [[ ! -f "${PY_PATH}" ]]; then
  die "ERROR: Python file does not exist: ${PY_PATH}"
fi

STAGGERED_START_DELAY_SEC=0  # Delay (in seconds) between consecutive worker launches
WKR_LOG_PREFIX="/tmp/worker_"  # Per-worker log files: /tmp/worker_<index>.log

IDX=0
LOG_FILES=""
for WORKER_GRPC_URL in ${WORKER_GRPC_URLS}; do
  if [[ ${IDX} != "0" ]]; then
    sleep "${STAGGERED_START_DELAY_SEC}"
  fi

  LOG_FILE="${WKR_LOG_PREFIX}${IDX}.log"
  LOG_FILES="${LOG_FILES} ${LOG_FILE}"
  python "${PY_PATH}" \
      --master_grpc_url="${WORKER_GRPC_URL}" \
      --num_parameter_servers="${N_PS}" \
      --worker_index="${IDX}" \
      --model_dir="${MODEL_DIR}" \
      --output_dir="/shared/output" \
      --train_steps=1000 \
      --eval_steps=2 2>&1 | tee "${LOG_FILE}" &

  echo "Worker ${IDX}: "
  echo "  GRPC URL: ${WORKER_GRPC_URL}"
  echo "  log file: ${LOG_FILE}"

  ((IDX++))
done

# Wait for all concurrent jobs to finish
wait

# Print logs from the workers
ORD=1
for LOG_FILE in ${LOG_FILES}; do
  echo "==================================================="
  echo "===        Log file from worker ${ORD} / ${N_WORKERS}          ==="
  cat "${LOG_FILE}"
  echo "==================================================="
  echo ""

  ((ORD++))
done

echo "Test for distributed training of Census Wide & Deep model PASSED"