1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/grappler/costs/virtual_placer.h" 17 #include "tensorflow/core/framework/node_def.pb.h" 18 #include "tensorflow/core/grappler/clusters/cluster.h" 19 #include "tensorflow/core/grappler/devices.h" 20 #include "tensorflow/core/lib/strings/str_util.h" 21 #include "tensorflow/core/util/device_name_utils.h" 22 23 namespace tensorflow { 24 namespace grappler { 25 26 VirtualPlacer::VirtualPlacer(const Cluster* cluster) { 27 CHECK(cluster); 28 29 // Default job name for canonical device name. Needs to be set before the 30 // first call to to_lfqn_or_empty() 31 default_job_name_lowercase_ = "localhost"; 32 33 devices_ = cluster->GetDevices(); 34 lfqn_map_.reserve(devices_.size()); 35 for (const auto& kv : devices_) { 36 const auto lfqn = to_lfqn_or_empty(kv.first); 37 if (lfqn.empty()) { 38 LOG(ERROR) << "VirtualPlacer couldn't parse device name from cluster: " 39 << kv.first; 40 } else { 41 lfqn_map_[lfqn] = kv.first; 42 } 43 } 44 45 if (devices_.empty()) { 46 // If there are no devices in the cluster, add a single device, "UNKNOWN" to 47 // the cluster. 48 default_device_name_ = "UNKNOWN"; 49 DeviceProperties& prop = devices_["UNKNOWN"]; 50 prop.set_type("UNKNOWN"); 51 } else if (devices_.size() == 1) { 52 // If there is only one device in the cluster, use it as default device, 53 // whatever it is. 54 default_device_name_ = devices_.begin()->first; 55 } else { 56 // Default device is set from the devices in the cluster in the following 57 // priority: /gpu:0, /cpu:0, or any device. 58 // TODO(dyoon): This logic assumes single machine with CPU and GPU devices. 59 // Make it more general to support multiple machines, job types, and devices 60 // other than CPU and GPU. 61 std::map<int, string> cpu_devices; // CPU device map: id -> device name. 62 std::map<int, string> gpu_devices; // GPU device map: id -> device name. 63 for (const auto& kv : lfqn_map_) { 64 const auto& lfqn = kv.first; 65 const auto& cluster_device_name = kv.second; 66 DeviceNameUtils::ParsedName parsed_name; 67 bool parsed = DeviceNameUtils::ParseFullName(lfqn, &parsed_name); 68 if (parsed) { 69 // Parsed devices are stored to cpu_devices or gpu_devices map, 70 // addressed (and ordered) by device id. 71 const auto type = str_util::Lowercase(parsed_name.type); 72 if (type == "gpu") { 73 gpu_devices[parsed_name.id] = cluster_device_name; 74 } else if (type == "cpu") { 75 cpu_devices[parsed_name.id] = cluster_device_name; 76 } 77 } 78 } 79 80 if (!gpu_devices.empty()) { 81 // GPU:0 (or GPU with smallest device id). 82 default_device_name_ = gpu_devices.begin()->second; 83 } else if (!cpu_devices.empty()) { 84 // CPU:0 (or CPU with smallest device id). 85 default_device_name_ = cpu_devices.begin()->second; 86 } else { 87 default_device_name_ = devices_.begin()->first; // Any device. 88 } 89 } 90 91 // Scan the device names from the cluster, and if there is one job name used, 92 // use it for canonical device name. 93 std::unordered_set<string> job_names_from_cluster; 94 for (const auto& device : lfqn_map_) { 95 const auto& lfqn = device.first; 96 DeviceNameUtils::ParsedName parsed_name; 97 bool parsed = DeviceNameUtils::ParseFullName(lfqn, &parsed_name); 98 if (parsed && !parsed_name.job.empty()) { 99 job_names_from_cluster.insert(parsed_name.job); 100 if (job_names_from_cluster.size() > 1) { 101 break; 102 } 103 } 104 } 105 // If there is only type of job name in all the devices in the cluster, use 106 // that one as default job name; otherwise, use localhost. 107 // TODO(dyoon): this should be improved, especially when the cluster is 108 // composed of multiple worker, PS, and other types of jobs. 109 if (job_names_from_cluster.size() == 1) { 110 auto it = job_names_from_cluster.begin(); 111 default_job_name_lowercase_ = *it; 112 } 113 } 114 115 const DeviceProperties& VirtualPlacer::get_device(const NodeDef& node) const { 116 string device = get_canonical_device_name(node); 117 VLOG(3) << "node.name=" << node.name() << " node.device=" << node.device() 118 << " is placed on: " << device; 119 auto it = devices_.find(device); 120 DCHECK(it != devices_.end()); 121 return it->second; 122 } 123 124 string VirtualPlacer::get_canonical_device_name(const NodeDef& node) const { 125 if (node.device().empty()) { 126 return default_device_name_; 127 } 128 129 const auto lfqn = to_lfqn_or_empty(node.device()); 130 if (lfqn.empty()) { 131 return default_device_name_; 132 } 133 134 const auto it = lfqn_map_.find(lfqn); 135 if (it != lfqn_map_.end()) { 136 return it->second; 137 } 138 139 return default_device_name_; 140 } 141 142 string VirtualPlacer::to_lfqn_or_empty(const string& device_name) const { 143 DeviceNameUtils::ParsedName parsed_name; 144 const auto lowercase_name = str_util::Lowercase(device_name); 145 bool parsed = DeviceNameUtils::ParseFullName(lowercase_name, &parsed_name); 146 if (!parsed) { 147 parsed = DeviceNameUtils::ParseLocalName(lowercase_name, &parsed_name); 148 parsed_name.job = "localhost"; 149 } 150 if (!parsed) { 151 if (lowercase_name == "gpu" || lowercase_name == "cpu") { 152 parsed_name.job = "localhost"; 153 parsed_name.type = lowercase_name; 154 parsed = true; 155 } 156 } 157 if (!parsed) { 158 return {}; 159 } 160 161 if (parsed_name.job.empty()) { 162 parsed_name.job = default_job_name_lowercase_; 163 } 164 165 // Have to do this, because parser returns uppercase types for CPU and GPU. 166 parsed_name.type = str_util::Lowercase(parsed_name.type); 167 168 string lfqn = strings::StrCat( 169 "/job:", parsed_name.job, "/replica:", parsed_name.replica, 170 "/task:", parsed_name.task, "/device:", parsed_name.type, ":", 171 parsed_name.id); 172 return lfqn; 173 } 174 175 } // end namespace grappler 176 } // end namespace tensorflow 177