Home | History | Annotate | Download | only in util
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
#include "tensorflow/core/util/device_name_utils.h"

#include <limits>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
     22 
     23 namespace tensorflow {
     24 
     25 static bool IsAlpha(char c) {
     26   return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z');
     27 }
     28 
     29 static bool IsAlphaNum(char c) { return IsAlpha(c) || (c >= '0' && c <= '9'); }
     30 
     31 // Returns true iff "in" is a valid job name.
     32 static bool IsJobName(StringPiece in) {
     33   if (in.empty()) return false;
     34   if (!IsAlpha(in[0])) return false;
     35   for (size_t i = 1; i < in.size(); ++i) {
     36     if (!(IsAlphaNum(in[i]) || in[i] == '_')) return false;
     37   }
     38   return true;
     39 }
     40 
     41 // Returns true and fills in "*job" iff "*in" starts with a job name.
     42 static bool ConsumeJobName(StringPiece* in, string* job) {
     43   if (in->empty()) return false;
     44   if (!IsAlpha((*in)[0])) return false;
     45   size_t i = 1;
     46   for (; i < in->size(); ++i) {
     47     const char c = (*in)[i];
     48     if (c == '/') break;
     49     if (!(IsAlphaNum(c) || c == '_')) {
     50       return false;
     51     }
     52   }
     53   job->assign(in->data(), i);
     54   in->remove_prefix(i);
     55   return true;
     56 }
     57 
     58 // Returns true and fills in "*device_type" iff "*in" starts with a device type
     59 // name.
     60 static bool ConsumeDeviceType(StringPiece* in, string* device_type) {
     61   if (in->empty()) return false;
     62   if (!IsAlpha((*in)[0])) return false;
     63   size_t i = 1;
     64   for (; i < in->size(); ++i) {
     65     const char c = (*in)[i];
     66     if (c == '/' || c == ':') break;
     67     if (!(IsAlphaNum(c) || c == '_')) {
     68       return false;
     69     }
     70   }
     71   device_type->assign(in->data(), i);
     72   in->remove_prefix(i);
     73   return true;
     74 }
     75 
     76 // Returns true and fills in "*val" iff "*in" starts with a decimal
     77 // number.
     78 static bool ConsumeNumber(StringPiece* in, int* val) {
     79   uint64 tmp;
     80   if (str_util::ConsumeLeadingDigits(in, &tmp)) {
     81     *val = tmp;
     82     return true;
     83   } else {
     84     return false;
     85   }
     86 }
     87 
     88 // Returns a fully qualified device name given the parameters.
     89 static string DeviceName(const string& job, int replica, int task,
     90                          const string& device_prefix, const string& device_type,
     91                          int id) {
     92   CHECK(IsJobName(job)) << job;
     93   CHECK_LE(0, replica);
     94   CHECK_LE(0, task);
     95   CHECK(!device_type.empty());
     96   CHECK_LE(0, id);
     97   return strings::StrCat("/job:", job, "/replica:", replica, "/task:", task,
     98                          device_prefix, device_type, ":", id);
     99 }
    100 
    101 /* static */
    102 string DeviceNameUtils::FullName(const string& job, int replica, int task,
    103                                  const string& type, int id) {
    104   return DeviceName(job, replica, task, "/device:", type, id);
    105 }
    106 
    107 namespace {
    108 string LegacyName(const string& job, int replica, int task, const string& type,
    109                   int id) {
    110   return DeviceName(job, replica, task, "/", str_util::Lowercase(type), id);
    111 }
    112 }  // anonymous namespace
    113 
    114 bool DeviceNameUtils::ParseFullName(StringPiece fullname, ParsedName* p) {
    115   p->Clear();
    116   if (fullname == "/") {
    117     return true;
    118   }
    119   while (!fullname.empty()) {
    120     bool progress = false;
    121     if (str_util::ConsumePrefix(&fullname, "/job:")) {
    122       p->has_job = !str_util::ConsumePrefix(&fullname, "*");
    123       if (p->has_job && !ConsumeJobName(&fullname, &p->job)) {
    124         return false;
    125       }
    126       progress = true;
    127     }
    128     if (str_util::ConsumePrefix(&fullname, "/replica:")) {
    129       p->has_replica = !str_util::ConsumePrefix(&fullname, "*");
    130       if (p->has_replica && !ConsumeNumber(&fullname, &p->replica)) {
    131         return false;
    132       }
    133       progress = true;
    134     }
    135     if (str_util::ConsumePrefix(&fullname, "/task:")) {
    136       p->has_task = !str_util::ConsumePrefix(&fullname, "*");
    137       if (p->has_task && !ConsumeNumber(&fullname, &p->task)) {
    138         return false;
    139       }
    140       progress = true;
    141     }
    142     if (str_util::ConsumePrefix(&fullname, "/device:")) {
    143       p->has_type = !str_util::ConsumePrefix(&fullname, "*");
    144       if (p->has_type && !ConsumeDeviceType(&fullname, &p->type)) {
    145         return false;
    146       }
    147       if (!str_util::ConsumePrefix(&fullname, ":")) {
    148         p->has_id = false;
    149       } else {
    150         p->has_id = !str_util::ConsumePrefix(&fullname, "*");
    151         if (p->has_id && !ConsumeNumber(&fullname, &p->id)) {
    152           return false;
    153         }
    154       }
    155       progress = true;
    156     }
    157 
    158     // Handle legacy naming convention for cpu and gpu.
    159     if (str_util::ConsumePrefix(&fullname, "/cpu:") ||
    160         str_util::ConsumePrefix(&fullname, "/CPU:")) {
    161       p->has_type = true;
    162       p->type = "CPU";  // Treat '/cpu:..' as uppercase '/device:CPU:...'
    163       p->has_id = !str_util::ConsumePrefix(&fullname, "*");
    164       if (p->has_id && !ConsumeNumber(&fullname, &p->id)) {
    165         return false;
    166       }
    167       progress = true;
    168     }
    169     if (str_util::ConsumePrefix(&fullname, "/gpu:") ||
    170         str_util::ConsumePrefix(&fullname, "/GPU:")) {
    171       p->has_type = true;
    172       p->type = "GPU";  // Treat '/gpu:..' as uppercase '/device:GPU:...'
    173       p->has_id = !str_util::ConsumePrefix(&fullname, "*");
    174       if (p->has_id && !ConsumeNumber(&fullname, &p->id)) {
    175         return false;
    176       }
    177       progress = true;
    178     }
    179 
    180     if (!progress) {
    181       return false;
    182     }
    183   }
    184   return true;
    185 }
    186 
    187 /* static */
    188 string DeviceNameUtils::CanonicalizeDeviceName(StringPiece fullname) {
    189   ParsedName parsed_name;
    190   if (ParseLocalName(fullname, &parsed_name)) {
    191     return ParsedNameToString(parsed_name);
    192   }
    193   if (ParseFullName(fullname, &parsed_name)) {
    194     return ParsedNameToString(parsed_name);
    195   }
    196   return "";
    197 }
    198 
    199 /* static */
    200 string DeviceNameUtils::ParsedNameToString(const ParsedName& pn) {
    201   string buf;
    202   if (pn.has_job) strings::StrAppend(&buf, "/job:", pn.job);
    203   if (pn.has_replica) strings::StrAppend(&buf, "/replica:", pn.replica);
    204   if (pn.has_task) strings::StrAppend(&buf, "/task:", pn.task);
    205   if (pn.has_type) {
    206     strings::StrAppend(&buf, "/device:", pn.type, ":");
    207     if (pn.has_id) {
    208       strings::StrAppend(&buf, pn.id);
    209     } else {
    210       strings::StrAppend(&buf, "*");
    211     }
    212   }
    213   return buf;
    214 }
    215 
    216 /* static */
    217 bool DeviceNameUtils::IsSpecification(const ParsedName& less_specific,
    218                                       const ParsedName& more_specific) {
    219   if (less_specific.has_job &&
    220       (!more_specific.has_job || (less_specific.job != more_specific.job))) {
    221     return false;
    222   }
    223   if (less_specific.has_replica &&
    224       (!more_specific.has_replica ||
    225        (less_specific.replica != more_specific.replica))) {
    226     return false;
    227   }
    228   if (less_specific.has_task &&
    229       (!more_specific.has_task || (less_specific.task != more_specific.task))) {
    230     return false;
    231   }
    232   if (less_specific.has_type &&
    233       (!more_specific.has_type || (less_specific.type != more_specific.type))) {
    234     return false;
    235   }
    236   if (less_specific.has_id &&
    237       (!more_specific.has_id || (less_specific.id != more_specific.id))) {
    238     return false;
    239   }
    240   return true;
    241 }
    242 
    243 /* static */
    244 bool DeviceNameUtils::IsCompleteSpecification(const ParsedName& pattern,
    245                                               const ParsedName& name) {
    246   CHECK(name.has_job && name.has_replica && name.has_task && name.has_type &&
    247         name.has_id);
    248 
    249   if (pattern.has_job && (pattern.job != name.job)) return false;
    250   if (pattern.has_replica && (pattern.replica != name.replica)) return false;
    251   if (pattern.has_task && (pattern.task != name.task)) return false;
    252   if (pattern.has_type && (pattern.type != name.type)) return false;
    253   if (pattern.has_id && (pattern.id != name.id)) return false;
    254   return true;
    255 }
    256 
    257 /* static */
    258 Status DeviceNameUtils::MergeDevNames(ParsedName* target,
    259                                       const ParsedName& other,
    260                                       bool allow_soft_placement) {
    261   if (other.has_job) {
    262     if (target->has_job && target->job != other.job) {
    263       return errors::InvalidArgument(
    264           "Cannot merge devices with incompatible jobs: '",
    265           ParsedNameToString(*target), "' and '", ParsedNameToString(other),
    266           "'");
    267     } else {
    268       target->has_job = other.has_job;
    269       target->job = other.job;
    270     }
    271   }
    272 
    273   if (other.has_replica) {
    274     if (target->has_replica && target->replica != other.replica) {
    275       return errors::InvalidArgument(
    276           "Cannot merge devices with incompatible replicas: '",
    277           ParsedNameToString(*target), "' and '", ParsedNameToString(other),
    278           "'");
    279     } else {
    280       target->has_replica = other.has_replica;
    281       target->replica = other.replica;
    282     }
    283   }
    284 
    285   if (other.has_task) {
    286     if (target->has_task && target->task != other.task) {
    287       return errors::InvalidArgument(
    288           "Cannot merge devices with incompatible tasks: '",
    289           ParsedNameToString(*target), "' and '", ParsedNameToString(other),
    290           "'");
    291     } else {
    292       target->has_task = other.has_task;
    293       target->task = other.task;
    294     }
    295   }
    296 
    297   if (other.has_type) {
    298     if (target->has_type && target->type != other.type) {
    299       if (!allow_soft_placement) {
    300         return errors::InvalidArgument(
    301             "Cannot merge devices with incompatible types: '",
    302             ParsedNameToString(*target), "' and '", ParsedNameToString(other),
    303             "'");
    304       } else {
    305         target->has_id = false;
    306         target->has_type = false;
    307         return Status::OK();
    308       }
    309     } else {
    310       target->has_type = other.has_type;
    311       target->type = other.type;
    312     }
    313   }
    314 
    315   if (other.has_id) {
    316     if (target->has_id && target->id != other.id) {
    317       if (!allow_soft_placement) {
    318         return errors::InvalidArgument(
    319             "Cannot merge devices with incompatible ids: '",
    320             ParsedNameToString(*target), "' and '", ParsedNameToString(other),
    321             "'");
    322       } else {
    323         target->has_id = false;
    324         return Status::OK();
    325       }
    326     } else {
    327       target->has_id = other.has_id;
    328       target->id = other.id;
    329     }
    330   }
    331 
    332   return Status::OK();
    333 }
    334 
    335 /* static */
    336 bool DeviceNameUtils::IsSameAddressSpace(const ParsedName& a,
    337                                          const ParsedName& b) {
    338   return (a.has_job && b.has_job && (a.job == b.job)) &&
    339          (a.has_replica && b.has_replica && (a.replica == b.replica)) &&
    340          (a.has_task && b.has_task && (a.task == b.task));
    341 }
    342 
    343 /* static */
    344 bool DeviceNameUtils::IsSameAddressSpace(StringPiece src, StringPiece dst) {
    345   ParsedName x;
    346   ParsedName y;
    347   return ParseFullName(src, &x) && ParseFullName(dst, &y) &&
    348          IsSameAddressSpace(x, y);
    349 }
    350 
    351 /* static */
    352 string DeviceNameUtils::LocalName(StringPiece type, int id) {
    353   return strings::StrCat("/device:", type, ":", id);
    354 }
    355 
    356 namespace {
    357 // Returns the legacy local device name given its "type" and "id" (which is
    358 // '/device:type:id').
    359 string LegacyLocalName(StringPiece type, int id) {
    360   return strings::StrCat(type, ":", id);
    361 }
    362 }  // anonymous namespace
    363 
    364 /* static */
    365 string DeviceNameUtils::LocalName(StringPiece fullname) {
    366   ParsedName x;
    367   CHECK(ParseFullName(fullname, &x)) << fullname;
    368   return LocalName(x.type, x.id);
    369 }
    370 
    371 /* static */
    372 bool DeviceNameUtils::ParseLocalName(StringPiece name, ParsedName* p) {
    373   if (!ConsumeDeviceType(&name, &p->type)) {
    374     return false;
    375   }
    376   p->has_type = true;
    377   if (!str_util::ConsumePrefix(&name, ":")) {
    378     return false;
    379   }
    380   if (!ConsumeNumber(&name, &p->id)) {
    381     return false;
    382   }
    383   p->has_id = true;
    384   return name.empty();
    385 }
    386 
    387 /* static */
    388 bool DeviceNameUtils::SplitDeviceName(StringPiece name, string* task,
    389                                       string* device) {
    390   ParsedName pn;
    391   if (ParseFullName(name, &pn) && pn.has_type && pn.has_id) {
    392     task->clear();
    393     task->reserve(
    394         (pn.has_job ? (5 + pn.job.size()) : 0) +
    395         (pn.has_replica ? (9 + 4 /*estimated UB for # replica digits*/) : 0) +
    396         (pn.has_task ? (6 + 4 /*estimated UB for # task digits*/) : 0));
    397     if (pn.has_job) {
    398       strings::StrAppend(task, "/job:", pn.job);
    399     }
    400     if (pn.has_replica) {
    401       strings::StrAppend(task, "/replica:", pn.replica);
    402     }
    403     if (pn.has_task) {
    404       strings::StrAppend(task, "/task:", pn.task);
    405     }
    406     device->clear();
    407     strings::StrAppend(device, pn.type, ":", pn.id);
    408     return true;
    409   }
    410   return false;
    411 }
    412 
    413 std::vector<string> DeviceNameUtils::GetNamesForDeviceMappings(
    414     const ParsedName& pn) {
    415   if (pn.has_job && pn.has_replica && pn.has_task && pn.has_type && pn.has_id) {
    416     return {
    417         DeviceNameUtils::FullName(pn.job, pn.replica, pn.task, pn.type, pn.id),
    418         LegacyName(pn.job, pn.replica, pn.task, pn.type, pn.id)};
    419   } else {
    420     return {};
    421   }
    422 }
    423 
    424 std::vector<string> DeviceNameUtils::GetLocalNamesForDeviceMappings(
    425     const ParsedName& pn) {
    426   if (pn.has_type && pn.has_id) {
    427     return {DeviceNameUtils::LocalName(pn.type, pn.id),
    428             LegacyLocalName(pn.type, pn.id)};
    429   } else {
    430     return {};
    431   }
    432 }
    433 
    434 }  // namespace tensorflow
    435