Home | History | Annotate | Download | only in scripts
      1 #!/usr/bin/env bash
      2 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      3 #
      4 # Licensed under the Apache License, Version 2.0 (the "License");
      5 # you may not use this file except in compliance with the License.
      6 # You may obtain a copy of the License at
      7 #
      8 #     http://www.apache.org/licenses/LICENSE-2.0
      9 #
     10 # Unless required by applicable law or agreed to in writing, software
     11 # distributed under the License is distributed on an "AS IS" BASIS,
     12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 # See the License for the specific language governing permissions and
     14 # limitations under the License.
     15 # ==============================================================================
     16 #
     17 # Create a Kubernetes (k8s) cluster of TensorFlow workers
     18 #
     19 # Usage:
     20 #   create_tf_cluster.sh <num_workers> <num_parameter_servers>
     21 #
     22 # In addition, this script obeys values in the following environment variables:
#   TF_DIST_LOCAL_CLUSTER:        if non-empty and not "0", create the
#                                 TensorFlow cluster on the local machine
     24 #   TF_DIST_SERVER_DOCKER_IMAGE:  overrides the default docker image to launch
     25 #                                 TensorFlow (GRPC) servers with
#   TF_DIST_GCLOUD_PROJECT:       gcloud project containing the GKE cluster
#                                 (used only when TF_DIST_LOCAL_CLUSTER does
#                                 not request a local cluster).
     29 #   TF_DIST_GCLOUD_COMPUTE_ZONE:  gcloud compute zone.
     30 #   TF_DIST_CONTAINER_CLUSTER:    name of the GKE cluster
     31 #   TF_DIST_GCLOUD_KEY_FILE:      if non-empty, will override GCLOUD_KEY_FILE
     32 #   TF_DIST_GRPC_PORT:            overrides the default port (2222)
     33 #                                 to run the GRPC servers on
     34 
     35 # Configurations
     36 # gcloud operation timeout (steps)
     37 GCLOUD_OP_MAX_STEPS=360
     38 
     39 GRPC_PORT=${TF_DIST_GRPC_PORT:-2222}
     40 
     41 DEFAULT_GCLOUD_BIN=/var/gcloud/google-cloud-sdk/bin/gcloud
     42 GCLOUD_KEY_FILE=${TF_DIST_GCLOUD_KEY_FILE:-\
     43 "/var/gcloud/secrets/tensorflow-testing.json"}
     44 GCLOUD_PROJECT=${TF_DIST_GCLOUD_PROJECT:-"tensorflow-testing"}
     45 
     46 GCLOUD_COMPUTE_ZONE=${TF_DIST_GCLOUD_COMPUTE_ZONE:-"us-central1-f"}
     47 CONTAINER_CLUSTER=${TF_DIST_CONTAINER_CLUSTER:-"test-cluster"}
     48 
     49 SERVER_DOCKER_IMAGE=${TF_DIST_SERVER_DOCKER_IMAGE:-\
     50 "tensorflow/tf_grpc_test_server"}
     51 
     52 # Get current script directory
     53 DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
     54 
     55 # Get utility functions
     56 source "${DIR}/utils.sh"
     57 
     58 # Check input arguments
     59 if [[ $# != 2 ]]; then
     60   die "Usage: $0 <num_workers> <num_parameter_servers>"
     61 fi
     62 
     63 NUM_WORKERS=$1
     64 NUM_PARAMETER_SERVERS=$2
     65 
     66 # Verify port string
     67 if [[ -z $(echo "${GRPC_PORT}" | grep -E "^[0-9]{1,5}") ]]; then
     68   die "Invalid GRPC port: \"${GRPC_PORT}\""
     69 fi
     70 echo "GRPC port to be used when creating the k8s TensorFlow cluster: "\
     71 "${GRPC_PORT}"
     72 
     73 if [[ -z "${TF_DIST_LOCAL_CLUSTER}" ]] ||
     74    [[ "${TF_DIST_LOCAL_CLUSTER}" == "0" ]]; then
     75   IS_LOCAL_CLUSTER="0"
     76 else
     77   IS_LOCAL_CLUSTER="1"
     78 fi
     79 
     80 if [[ ${IS_LOCAL_CLUSTER} == "0" ]]; then
     81   # Locate gcloud binary path
     82   GCLOUD_BIN=$(which gcloud)
     83   if [[ -z "${GCLOUD_BIN}" ]]; then
     84     GCLOUD_BIN="${DEFAULT_GCLOUD_BIN}"
     85   fi
     86 
     87   if [[ ! -f "${GCLOUD_BIN}" ]]; then
     88     die "gcloud binary cannot be found at: ${GCLOUD_BIN}"
     89   fi
     90   echo "Path to gcloud binary: ${GCLOUD_BIN}"
     91 
     92   # Path to gcloud service key file
     93   if [[ ! -f "${GCLOUD_KEY_FILE}" ]]; then
     94     die "gcloud service account key file cannot be found at: ${GCLOUD_KEY_FILE}"
     95   fi
     96   echo "Path to gcloud key file: ${GCLOUD_KEY_FILE}"
     97 
     98   echo "GCLOUD_PROJECT: ${GCLOUD_PROJECT}"
     99   echo "GCLOUD_COMPUTER_ZONE: ${GCLOUD_COMPUTE_ZONE}"
    100   echo "CONTAINER_CLUSTER: ${CONTAINER_CLUSTER}"
    101 
    102   # Activate gcloud service account
    103   "${GCLOUD_BIN}" auth activate-service-account --key-file "${GCLOUD_KEY_FILE}"
    104 
    105   # See: https://github.com/kubernetes/kubernetes/issues/30617
    106   "${GCLOUD_BIN}" config set container/use_client_certificate True
    107 
    108   # Set gcloud project
    109   "${GCLOUD_BIN}" config set project "${GCLOUD_PROJECT}"
    110 
    111   # Set compute zone
    112   "${GCLOUD_BIN}" config set compute/zone "${GCLOUD_COMPUTE_ZONE}"
    113 
    114   # Set container cluster
    115   "${GCLOUD_BIN}" config set container/cluster "${CONTAINER_CLUSTER}"
    116 
    117   # Get container cluster credentials
    118   "${GCLOUD_BIN}" container clusters get-credentials "${CONTAINER_CLUSTER}"
    119   if [[ $? != "0" ]]; then
    120     die "FAILED to get credentials for container cluster: ${CONTAINER_CLUSTER}"
    121   fi
    122 
    123   # If there is any existing tf k8s cluster, delete it first
    124   "${DIR}/delete_tf_cluster.sh" "${GCLOUD_OP_MAX_STEPS}"
    125 fi
    126 
    127 # Path to kubectl binary
    128 KUBECTL_BIN=$(dirname "${GCLOUD_BIN}")/kubectl
    129 if [[ ! -f "${KUBECTL_BIN}" ]]; then
    130   die "kubectl binary cannot be found at: ${KUBECTL_BIN}"
    131 fi
    132 echo "Path to kubectl binary: ${KUBECTL_BIN}"
    133 
    134 # Create yaml file for k8s TensorFlow cluster creation
    135 # Path to the (Python) script for generating k8s yaml file
    136 K8S_GEN_TF_YAML="${DIR}/k8s_tensorflow.py"
    137 if [[ ! -f ${K8S_GEN_TF_YAML} ]]; then
    138   die "FAILED to find yaml-generating script at: ${K8S_GEN_TF_YAML}"
    139 fi
    140 
    141 K8S_YAML="/tmp/k8s_tf_lb.yaml"
    142 rm -f "${K8S_YAML}"
    143 
    144 echo ""
    145 echo "Generating k8s cluster yaml config file with the following settings"
    146 echo "  Server docker image: ${SERVER_DOCKER_IMAGE}"
    147 echo "  Number of workers: ${NUM_WORKERS}"
    148 echo "  Number of parameter servers: ${NUM_PARAMETER_SERVERS}"
    149 echo "  GRPC port: ${GRPC_PORT}"
    150 echo ""
    151 
    152 ${K8S_GEN_TF_YAML} \
    153     --docker_image "${SERVER_DOCKER_IMAGE}" \
    154     --num_workers "${NUM_WORKERS}" \
    155     --num_parameter_servers "${NUM_PARAMETER_SERVERS}" \
    156     --grpc_port "${GRPC_PORT}" \
    157     --request_load_balancer=True \
    158     > "${K8S_YAML}" || \
    159     die "Generation of the yaml configuration file for k8s cluster FAILED"
    160 
    161 if [[ ! -f "${K8S_YAML}" ]]; then
    162     die "FAILED to generate yaml file for TensorFlow k8s container cluster"
    163 else
    164     echo "Generated yaml configuration file for k8s TensorFlow cluster: "\
    165 "${K8S_YAML}"
    166     cat "${K8S_YAML}"
    167 fi
    168 
    169 # Create tf k8s container cluster
    170 "${KUBECTL_BIN}" create -f "${K8S_YAML}"
    171 
    172 # Wait for external IP of worker services to become available
    173 get_tf_external_ip() {
    174   # Usage: gen_tf_worker_external_ip <JOB_NAME> <TASK_INDEX>
    175   # E.g.,  gen_tf_worker_external_ip ps 2
    176   echo $("${KUBECTL_BIN}" get svc | grep "^tf-${1}${2}" | \
    177          awk '{print $3}' | grep -E "[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+")
    178 }
    179 
    180 if [[ ${IS_LOCAL_CLUSTER} == "0" ]]; then
    181   echo "Waiting for external IP of tf-worker0 service to emerge..."
    182   echo ""
    183 
    184   COUNTER=0
    185   while true; do
    186     sleep 1
    187     ((COUNTER++))
    188     if [[ "${COUNTER}" -gt "${GCLOUD_OP_MAX_STEPS}" ]]; then
    189       die "Reached maximum polling steps while waiting for external IP "\
    190 "of tf-worker0 service to emerge"
    191     fi
    192 
    193     WORKER_EXTERN_IPS=""
    194     WORKER_INDEX=0
    195     N_AVAILABLE_WORKER_EXTERNAL_IPS=0
    196     while true; do
    197       SVC_EXTERN_IP=$(get_tf_external_ip worker ${WORKER_INDEX})
    198 
    199       if [[ ! -z "${SVC_EXTERN_IP}" ]]; then
    200         WORKER_EXTERN_IPS="${WORKER_EXTERN_IPS} ${SVC_EXTERN_IP}"
    201 
    202         ((N_AVAILABLE_WORKER_EXTERNAL_IPS++))
    203       fi
    204 
    205       ((WORKER_INDEX++))
    206       if [[ ${WORKER_INDEX} == ${NUM_WORKERS} ]]; then
    207         break;
    208       fi
    209     done
    210 
    211     PS_EXTERN_IPS=""
    212     PS_INDEX=0
    213     N_AVAILABLE_PS_EXTERNAL_IPS=0
    214     while true; do
    215       SVC_EXTERN_IP=$(get_tf_external_ip ps ${PS_INDEX})
    216 
    217       if [[ ! -z "${SVC_EXTERN_IP}" ]]; then
    218         PS_EXTERN_IPS="${PS_EXTERN_IPS} ${SVC_EXTERN_IP}"
    219 
    220         ((N_AVAILABLE_PS_EXTERNAL_IPS++))
    221       fi
    222 
    223       ((PS_INDEX++))
    224       if [[ ${PS_INDEX} == ${NUM_PARAMETER_SERVERS} ]]; then
    225         break;
    226       fi
    227     done
    228 
    229     if [[ ${N_AVAILABLE_WORKER_EXTERNAL_IPS} == ${NUM_WORKERS} ]] && \
    230        [[ ${N_AVAILABLE_PS_EXTERNAL_IPS} == ${NUM_PARAMETER_SERVERS} ]]; then
    231       break;
    232     fi
    233   done
    234 
    235   GRPC_SERVER_URLS=""
    236   for IP in ${WORKER_EXTERN_IPS}; do
    237     GRPC_SERVER_URLS="${GRPC_SERVER_URLS} grpc://${IP}:${GRPC_PORT}"
    238   done
    239 
    240   GRPC_PS_URLS=""
    241   for IP in ${PS_EXTERN_IPS}; do
    242     GRPC_PS_URLS="${GRPC_PS_URLS} grpc://${IP}:${GRPC_PORT}"
    243   done
    244 
    245   echo "GRPC URLs of tf-worker instances: ${GRPC_SERVER_URLS}"
    246   echo "GRPC URLs of tf-ps instances: ${GRPC_PS_URLS}"
    247 
    248 else
    249   echo "Waiting for tf pods to be all running..."
    250   echo ""
    251 
    252   COUNTER=0
    253   while true; do
    254     sleep 1
    255     ((COUNTER++))
    256     if [[ "${COUNTER}" -gt "${GCLOUD_OP_MAX_STEPS}" ]]; then
    257       die "Reached maximum polling steps while waiting for all tf pods to "\
    258 "be running in local k8s TensorFlow cluster"
    259     fi
    260 
    261     PODS_STAT=$(are_all_pods_running "${KUBECTL_BIN}")
    262 
    263     if [[ ${PODS_STAT} == "2" ]]; then
    264       # Error has occurred
    265       die "Error(s) occurred while tring to launch tf k8s cluster. "\
    266 "One possible cause is that the Docker image used to launch the cluster is "\
    267 "invalid: \"${SERVER_DOCKER_IMAGE}\""
    268     fi
    269 
    270     if [[ ${PODS_STAT} == "1" ]]; then
    271       break
    272     fi
    273   done
    274 
    275   # Determine the tf-worker0 docker container id
    276   WORKER0_ID=$(docker ps | grep "k8s_tf-worker0" | awk '{print $1}')
    277   echo "WORKER0 Docker container ID: ${WORKER0_ID}"
    278 
    279 fi
    280 
    281 
    282 echo "Cluster setup complete."
    283 echo ""
    284