Home | History | Annotate | Download | only in framework
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/cc/framework/scope_internal.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/strings/str_util.h"
     23 #include "tensorflow/core/lib/strings/str_util.h"
     24 
     25 namespace tensorflow {
     26 
     27 Scope::Scope(Impl* impl) : impl_(impl) {}
     28 
     29 Scope::Scope(const Scope& other) : impl_(new Impl(*other.impl())) {}
     30 
     31 Scope::~Scope() {}
     32 
     33 Scope& Scope::operator=(const Scope& other) {
     34   // We can't copy Impls because of the const members, use copy ctor instead
     35   impl_.reset(new Impl(*other.impl_));
     36   return *this;
     37 }
     38 
     39 namespace {
     40 const char kScopeSeparator[] = "/";
     41 const char kSuffixSeparator[] = "_";
     42 }  // namespace
     43 
// Constructs a root Impl from raw pointers.
//
// Each pointer is stored in the matching shared_ptr member with the default
// deleter. As a result this Impl owns `graph`, `status`, `name_map` and
// `refiner`, together with every scope later derived from it (those copy the
// shared_ptrs), and the objects are deleted when the last owner goes away.
// Callers that must keep ownership use the shared_ptr constructor with no-op
// deleters instead; InternalScope::NewScope does this.
//
// `disable_shape_inference` makes DoShapeInference a no-op for this scope and
// everything derived from it.
Scope::Impl::Impl(Graph* graph, Status* status, NameMap* name_map,
                  ShapeRefiner* refiner, bool disable_shape_inference)
    : graph_(graph),
      status_(status),
      name_map_(name_map),
      refiner_(refiner),
      // A null flag marks an ordinary (not single-use) scope; presumably this
      // is what single_use_scope() tests. Confirm in scope_internal.h.
      scope_used_(nullptr),
      colocation_constraints_(),
      disable_shape_inference_(disable_shape_inference) {}
     53 
     54 Scope::Impl::Impl(const std::shared_ptr<Graph>& graph,
     55                   const std::shared_ptr<Status>& status,
     56                   const std::shared_ptr<NameMap>& name_map,
     57                   const std::shared_ptr<ShapeRefiner>& refiner)
     58     : graph_(graph),
     59       status_(status),
     60       name_map_(name_map),
     61       refiner_(refiner),
     62       scope_used_(nullptr),
     63       colocation_constraints_(),
     64       disable_shape_inference_(refiner_ == nullptr) {}
     65 
     66 Scope Scope::NewRootScope() {
     67   Graph* graph = new Graph(OpRegistry::Global());
     68   ShapeRefiner* refiner =
     69       new ShapeRefiner(graph->versions(), graph->op_registry());
     70   return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner,
     71                         /* disable_shape_inference */ false));
     72 }
     73 
     74 Scope Scope::DisabledShapeInferenceScope() {
     75   Graph* graph = new Graph(OpRegistry::Global());
     76   ShapeRefiner* refiner =
     77       new ShapeRefiner(graph->versions(), graph->op_registry());
     78   return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner,
     79                         /* disable_shape_inference */ true));
     80 }
     81 
// Constructs the Impl for a (sub-)scope called `name`, derived from `other`.
//
// - op_name_ is reset to empty and scope_used_ to null (not single-use).
// - If `copy_names` is true, the new scope shares `other`'s name map, so names
//   are uniqued against the same set. Otherwise it starts with a fresh, empty
//   map.
// - Everything else is inherited from `other`: graph, status, refiner, control
//   deps, exit-on-error flag, kernel label, devices, XLA cluster, colocation
//   constraints and the shape-inference flag.
// NOTE: the initializer order presumably mirrors the member declaration order
// in scope_internal.h (not visible here). Keep them in sync.
Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name,
                  bool copy_names)
    : graph_(other.impl()->graph_),
      status_(other.impl()->status_),
      name_map_(copy_names ? other.impl()->name_map_
                           : std::shared_ptr<NameMap>(new NameMap)),
      refiner_(other.impl()->refiner_),
      scope_used_(nullptr),
      control_deps_(other.impl()->control_deps_),
      name_(name),
      op_name_(""),
      exit_on_error_(other.impl()->exit_on_error_),
      kernel_label_(other.impl()->kernel_label_),
      device_(other.impl()->device_),
      assigned_device_(other.impl()->assigned_device_),
      xla_cluster_(other.impl()->xla_cluster_),
      colocation_constraints_(other.impl()->colocation_constraints_),
      disable_shape_inference_(other.impl()->disable_shape_inference_) {}
    100 
// Constructs the Impl for `other` with scope name `name` and preset op name
// `op_name`. Scope::GetUniqueNameForOp prefers op_name_ over its default name
// when op_name_ is set.
// All other state, including the name map and the single-use flag pointer, is
// shared with or copied from `other`.
Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name,
                  const string& op_name)
    : graph_(other.impl()->graph_),
      status_(other.impl()->status_),
      name_map_(other.impl()->name_map_),
      refiner_(other.impl()->refiner_),
      scope_used_(other.impl()->scope_used_),
      control_deps_(other.impl()->control_deps_),
      name_(name),
      op_name_(op_name),
      exit_on_error_(other.impl()->exit_on_error_),
      kernel_label_(other.impl()->kernel_label_),
      device_(other.impl()->device_),
      assigned_device_(other.impl()->assigned_device_),
      xla_cluster_(other.impl()->xla_cluster_),
      colocation_constraints_(other.impl()->colocation_constraints_),
      disable_shape_inference_(other.impl()->disable_shape_inference_) {}
    118 
    119 Scope::Impl::Impl(const Scope& other, Tags::ControlDeps,
    120                   std::vector<Operation> control_deps, bool clear_control_deps)
    121     : graph_(other.impl()->graph_),
    122       status_(other.impl()->status_),
    123       name_map_(other.impl()->name_map_),
    124       refiner_(other.impl()->refiner_),
    125       scope_used_(other.impl()->scope_used_),
    126       control_deps_(
    127           clear_control_deps
    128               ? std::vector<Operation>()
    129               : (control_deps.insert(control_deps.begin(),
    130                                      other.impl()->control_deps_.begin(),
    131                                      other.impl()->control_deps_.end()),
    132                  control_deps)),
    133       name_(other.impl()->name_),
    134       op_name_(other.impl()->op_name_),
    135       exit_on_error_(other.impl()->exit_on_error_),
    136       kernel_label_(other.impl()->kernel_label_),
    137       device_(other.impl()->device_),
    138       assigned_device_(other.impl()->assigned_device_),
    139       xla_cluster_(other.impl()->xla_cluster_),
    140       colocation_constraints_(other.impl()->colocation_constraints_),
    141       disable_shape_inference_(other.impl()->disable_shape_inference_) {}
    142 
// Constructs the Impl for `other` with the requested device replaced by
// `device`. Scope::UpdateBuilder passes a non-empty device_ to
// NodeBuilder::Device. Every other setting is inherited from `other`.
Scope::Impl::Impl(const Scope& other, Tags::Device, const string& device)
    : graph_(other.impl()->graph_),
      status_(other.impl()->status_),
      name_map_(other.impl()->name_map_),
      refiner_(other.impl()->refiner_),
      scope_used_(other.impl()->scope_used_),
      control_deps_(other.impl()->control_deps_),
      name_(other.impl()->name_),
      op_name_(other.impl()->op_name_),
      exit_on_error_(other.impl()->exit_on_error_),
      kernel_label_(other.impl()->kernel_label_),
      device_(device),
      assigned_device_(other.impl()->assigned_device_),
      xla_cluster_(other.impl()->xla_cluster_),
      colocation_constraints_(other.impl()->colocation_constraints_),
      disable_shape_inference_(other.impl()->disable_shape_inference_) {}
    159 
// Constructs a single-use Impl derived from `other` (created by
// Scope::GetCompositeOpScopes).
//
// A fresh `scope_used_` flag, initially false, is allocated rather than
// shared with `other`, and `op_name` becomes the preset op name.
// Scope::GetUniqueNameForOp hands that name out once and sets the flag, and
// later requests fail. Presumably the non-null flag is what
// single_use_scope() checks; confirm in scope_internal.h. All other settings,
// including the name map, are inherited from `other`.
Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope,
                  const string& op_name)
    : graph_(other.impl()->graph_),
      status_(other.impl()->status_),
      name_map_(other.impl()->name_map_),
      refiner_(other.impl()->refiner_),
      scope_used_(new bool(false)),
      control_deps_(other.impl()->control_deps_),
      name_(other.impl()->name_),
      op_name_(op_name),
      exit_on_error_(other.impl()->exit_on_error_),
      kernel_label_(other.impl()->kernel_label_),
      device_(other.impl()->device_),
      assigned_device_(other.impl()->assigned_device_),
      xla_cluster_(other.impl()->xla_cluster_),
      colocation_constraints_(other.impl()->colocation_constraints_),
      disable_shape_inference_(other.impl()->disable_shape_inference_) {}
    177 
// Constructs the Impl for `other` with exit_on_error_ forced to true. With it
// set, Scope::UpdateStatus aborts via LOG(FATAL) as soon as the shared status
// is not OK. Every other setting is inherited from `other`.
Scope::Impl::Impl(const Scope& other, Tags::ExitOnError)
    : graph_(other.impl()->graph_),
      status_(other.impl()->status_),
      name_map_(other.impl()->name_map_),
      refiner_(other.impl()->refiner_),
      scope_used_(other.impl()->scope_used_),
      control_deps_(other.impl()->control_deps_),
      name_(other.impl()->name_),
      op_name_(other.impl()->op_name_),
      exit_on_error_(true),
      kernel_label_(other.impl()->kernel_label_),
      device_(other.impl()->device_),
      assigned_device_(other.impl()->assigned_device_),
      xla_cluster_(other.impl()->xla_cluster_),
      colocation_constraints_(other.impl()->colocation_constraints_),
      disable_shape_inference_(other.impl()->disable_shape_inference_) {}
    194 
// Constructs the Impl for `other` with the kernel label replaced by
// `kernel_label`. Scope::UpdateBuilder emits a non-empty label as the
// "_kernel" attr. Every other setting is inherited from `other`.
Scope::Impl::Impl(const Scope& other, Tags::KernelLabel,
                  const string& kernel_label)
    : graph_(other.impl()->graph_),
      status_(other.impl()->status_),
      name_map_(other.impl()->name_map_),
      refiner_(other.impl()->refiner_),
      scope_used_(other.impl()->scope_used_),
      control_deps_(other.impl()->control_deps_),
      name_(other.impl()->name_),
      op_name_(other.impl()->op_name_),
      exit_on_error_(other.impl()->exit_on_error_),
      kernel_label_(kernel_label),
      device_(other.impl()->device_),
      assigned_device_(other.impl()->assigned_device_),
      xla_cluster_(other.impl()->xla_cluster_),
      colocation_constraints_(other.impl()->colocation_constraints_),
      disable_shape_inference_(other.impl()->disable_shape_inference_) {}
    212 
// Constructs the Impl for `other` with updated colocation constraints.
//
// If `clear_colocations` is true, the constraint set is emptied and
// `colocate_with_op` is not inspected (ClearColocation passes a default
// Operation). Otherwise the set is `other`'s constraints plus those derived
// from `colocate_with_op` by GetColocationConstraints. Every other setting is
// inherited from `other`.
Scope::Impl::Impl(const Scope& other, Tags::Colocate,
                  const Operation& colocate_with_op, bool clear_colocations)
    : graph_(other.impl()->graph_),
      status_(other.impl()->status_),
      name_map_(other.impl()->name_map_),
      refiner_(other.impl()->refiner_),
      scope_used_(other.impl()->scope_used_),
      control_deps_(other.impl()->control_deps_),
      name_(other.impl()->name_),
      op_name_(other.impl()->op_name_),
      exit_on_error_(other.impl()->exit_on_error_),
      kernel_label_(other.impl()->kernel_label_),
      device_(other.impl()->device_),
      assigned_device_(other.impl()->assigned_device_),
      xla_cluster_(other.impl()->xla_cluster_),
      colocation_constraints_(
          clear_colocations
              ? std::unordered_set<string>()
              : other.impl()->GetColocationConstraints(colocate_with_op)),
      disable_shape_inference_(other.impl()->disable_shape_inference_) {}
    233 
// Constructs the Impl for `other` with the assigned device replaced by
// `assigned_device`. Scope::UpdateBuilder passes a non-empty value to
// NodeBuilder::AssignedDevice. Every other setting is inherited from `other`.
Scope::Impl::Impl(const Scope& other, Tags::AssignedDevice,
                  const string& assigned_device)
    : graph_(other.impl()->graph_),
      status_(other.impl()->status_),
      name_map_(other.impl()->name_map_),
      refiner_(other.impl()->refiner_),
      scope_used_(other.impl()->scope_used_),
      control_deps_(other.impl()->control_deps_),
      name_(other.impl()->name_),
      op_name_(other.impl()->op_name_),
      exit_on_error_(other.impl()->exit_on_error_),
      kernel_label_(other.impl()->kernel_label_),
      device_(other.impl()->device_),
      assigned_device_(assigned_device),
      xla_cluster_(other.impl()->xla_cluster_),
      colocation_constraints_(other.impl()->colocation_constraints_),
      disable_shape_inference_(other.impl()->disable_shape_inference_) {}
    251 
// Constructs the Impl for `other` with the XLA cluster replaced by
// `xla_cluster`. Scope::UpdateBuilder passes a non-empty value to
// NodeBuilder::XlaCluster. Every other setting is inherited from `other`.
Scope::Impl::Impl(const Scope& other, Tags::XlaCluster,
                  const string& xla_cluster)
    : graph_(other.impl()->graph_),
      status_(other.impl()->status_),
      name_map_(other.impl()->name_map_),
      refiner_(other.impl()->refiner_),
      scope_used_(other.impl()->scope_used_),
      control_deps_(other.impl()->control_deps_),
      name_(other.impl()->name_),
      op_name_(other.impl()->op_name_),
      exit_on_error_(other.impl()->exit_on_error_),
      kernel_label_(other.impl()->kernel_label_),
      device_(other.impl()->device_),
      assigned_device_(other.impl()->assigned_device_),
      xla_cluster_(xla_cluster),
      colocation_constraints_(other.impl()->colocation_constraints_),
      disable_shape_inference_(other.impl()->disable_shape_inference_) {}
    269 
    270 std::unordered_set<string> Scope::Impl::GetColocationConstraints(
    271     const Operation& colocate_with_op) const {
    272   std::unordered_set<string> current_constraints(colocation_constraints_);
    273   const AttrSlice attrs = colocate_with_op.node()->attrs();
    274   std::vector<string> node_constraints;
    275   if (GetNodeAttr(attrs, kColocationAttrName, &node_constraints).ok()) {
    276     for (const string& entry : node_constraints) {
    277       StringPiece s(entry);
    278       if (str_util::ConsumePrefix(&s, kColocationGroupPrefix)) {
    279         current_constraints.emplace(s);
    280       }
    281     }
    282   } else {
    283     current_constraints.insert(colocate_with_op.node()->name());
    284   }
    285   return current_constraints;
    286 }
    287 
    288 bool Scope::ok() const { return impl()->status_->ok(); }
    289 
    290 Graph* Scope::graph() const { return impl()->graph_.get(); }
    291 
    292 std::shared_ptr<Graph> Scope::graph_as_shared_ptr() const {
    293   return impl()->graph_;
    294 }
    295 
    296 Status Scope::status() const { return *impl()->status_; }
    297 
    298 const std::vector<Operation>& Scope::control_deps() const {
    299   return impl()->control_deps_;
    300 }
    301 
    302 void Scope::UpdateStatus(const Status s) const {
    303   impl()->status_->Update(s);
    304   if (impl()->exit_on_error_ && !ok()) {
    305     LOG(FATAL) << *impl()->status_;
    306   }
    307 }
    308 
    309 Status Scope::ToGraphDef(GraphDef* gdef) const {
    310   if (!ok()) {
    311     return *impl()->status_;
    312   }
    313   graph()->ToGraphDef(gdef);
    314   return Status::OK();
    315 }
    316 
    317 Status Scope::ToGraph(Graph* g, GraphConstructorOptions opts) const {
    318   if (ok()) {
    319     GraphDef graph_def;
    320     graph()->ToGraphDef(&graph_def);
    321     UpdateStatus(ConvertGraphDefToGraph(opts, graph_def, g));
    322   }
    323   return *impl()->status_;
    324 }
    325 
    326 void Scope::UpdateBuilder(NodeBuilder* builder) const {
    327   std::vector<Node*> control_inputs;
    328   for (const auto& op : impl()->control_deps_) {
    329     control_inputs.push_back(op.node());
    330   }
    331   builder->ControlInputs(control_inputs);
    332 
    333   if (!impl()->kernel_label_.empty()) {
    334     builder->Attr("_kernel", impl()->kernel_label_);
    335   }
    336 
    337   if (!impl()->colocation_constraints_.empty()) {
    338     std::vector<string> constraints(impl()->colocation_constraints_.begin(),
    339                                     impl()->colocation_constraints_.end());
    340     // Sort the set.
    341     std::sort(constraints.begin(), constraints.end());
    342     // Add loc:@ prefix
    343     std::transform(constraints.begin(), constraints.end(), constraints.begin(),
    344                    [](const string& s) {
    345                      return strings::StrCat(kColocationGroupPrefix, s);
    346                    });
    347     builder->Attr(kColocationAttrName, constraints);
    348   }
    349   if (!impl()->device_.empty()) {
    350     builder->Device(impl()->device_);
    351   }
    352   if (!impl()->assigned_device_.empty()) {
    353     builder->AssignedDevice(impl()->assigned_device_);
    354   }
    355   if (!impl()->xla_cluster_.empty()) {
    356     builder->XlaCluster(impl()->xla_cluster_);
    357   }
    358 }
    359 
    360 string Scope::Impl::GetUniqueName(const string& prefix,
    361                                   bool check_single_use) const {
    362   if (check_single_use && single_use_scope()) {
    363     if (*scope_used_) {
    364       *status_ =
    365           errors::AlreadyExists(prefix, " already exists in the current scope");
    366       return "";
    367     }
    368     *scope_used_ = true;
    369     return prefix;
    370   }
    371   auto entry = name_map_->find(prefix);
    372   if (entry == name_map_->end()) {
    373     name_map_->insert({prefix, 0});
    374     return prefix;
    375   }
    376   string unique_name;
    377   do {
    378     unique_name = strings::StrCat(prefix, kSuffixSeparator, ++entry->second);
    379   } while (name_map_->find(unique_name) != name_map_->end());
    380   name_map_->insert({unique_name, 0});
    381   return unique_name;
    382 }
    383 
    384 string Scope::Impl::GetNameForOp(const string& default_name) const {
    385   const string unique_name =
    386       GetUniqueName(default_name, true /* check_single_use */);
    387   const string sep =
    388       name_.empty() || unique_name.empty() ? "" : kScopeSeparator;
    389   return strings::StrCat(name_, sep, unique_name);
    390 }
    391 
    392 string Scope::GetUniqueNameForOp(const string& default_name) const {
    393   if (impl()->single_use_scope()) {
    394     if (impl()->op_name_.empty() || *impl()->scope_used_) {
    395       *impl()->status_ =
    396           errors::InvalidArgument("Cannot get a unique name in this scope");
    397       return "";
    398     }
    399     *impl()->scope_used_ = true;
    400     return impl()->op_name_;
    401   }
    402   return impl()->op_name_.empty() ? impl()->GetNameForOp(default_name)
    403                                   : impl()->GetNameForOp(impl()->op_name_);
    404 }
    405 
    406 Scope Scope::NewSubScope(const string& child_scope_name) const {
    407   if (child_scope_name.empty()) {
    408     return Scope(new Impl(*this, Impl::Tags::ScopeName(), impl()->name_,
    409                           true /* copy_names */));
    410   }
    411   const string unique_name =
    412       impl()->GetUniqueName(child_scope_name, false /* check_single_use */);
    413   const string sep =
    414       impl()->name_.empty() || unique_name.empty() ? "" : kScopeSeparator;
    415   return Scope(new Impl(*this, Impl::Tags::ScopeName(),
    416                         strings::StrCat(impl()->name_, sep, unique_name),
    417                         false /* copy_names */));
    418 }
    419 
    420 Scope Scope::WithOpNameImpl(const string& op_name) const {
    421   if (impl()->single_use_scope()) {
    422     UpdateStatus(errors::InvalidArgument("Cannot set op name ", op_name,
    423                                          " on this scope"));
    424     return *this;
    425   }
    426   return Scope(new Impl(*this, Impl::Tags::OpName(), impl()->name_, op_name));
    427 }
    428 
    429 Scope Scope::WithControlDependencies(
    430     const gtl::ArraySlice<Operation>& control_deps) const {
    431   return Scope(
    432       new Impl(*this, Impl::Tags::ControlDeps(),
    433                std::vector<Operation>(control_deps.begin(), control_deps.end()),
    434                /* clear_control_deps */ false));
    435 }
    436 
    437 Scope Scope::WithControlDependencies(const Output& control_dep) const {
    438   return Scope(new Impl(*this, Impl::Tags::ControlDeps(),
    439                         std::vector<Operation>(1, control_dep.op()),
    440                         /* clear_control_deps */ false));
    441 }
    442 
    443 Scope Scope::WithNoControlDependencies() const {
    444   return Scope(new Impl(*this, Impl::Tags::ControlDeps(),
    445                         std::vector<Operation>(),
    446                         /* clear_control_deps */ true));
    447 }
    448 
    449 Scope Scope::WithDevice(const string& device) const {
    450   return Scope(new Impl(*this, Impl::Tags::Device(), device));
    451 }
    452 
    453 Scope Scope::WithAssignedDevice(const string& assigned_device) const {
    454   return Scope(new Impl(*this, Impl::Tags::AssignedDevice(), assigned_device));
    455 }
    456 
    457 Scope Scope::WithXlaCluster(const string& xla_cluster) const {
    458   return Scope(new Impl(*this, Impl::Tags::XlaCluster(), xla_cluster));
    459 }
    460 
    461 Scope Scope::ColocateWith(const Operation& op) const {
    462   return Scope(new Impl(*this, Impl::Tags::Colocate(), op,
    463                         /* clear_colocations */ false));
    464 }
    465 
    466 Scope Scope::ClearColocation() const {
    467   return Scope(new Impl(*this, Impl::Tags::Colocate(), Operation(),
    468                         /* clear_colocations */ true));
    469 }
    470 
    471 Scope Scope::ExitOnError() const {
    472   return Scope(new Impl(*this, Impl::Tags::ExitOnError()));
    473 }
    474 
    475 Scope Scope::WithKernelLabel(const string& kernel_label) const {
    476   return Scope(new Impl(*this, Impl::Tags::KernelLabel(), kernel_label));
    477 }
    478 
// Returns the pair of scopes used to build a composite op.
//
// A name is required: the scope's preset op name if set, else
// `composite_op_name`. If both are empty, InvalidArgument is recorded and
// {*this, *this} is returned.
//
// For an ordinary scope, the first scope is a new sub-scope named after the
// composite op (for its internal ops). The second is a single-use scope
// derived from that sub-scope, whose preset op name is `child_name`.
// For a scope that is already single-use, the first scope is named after this
// scope's op name and shares its name map, and the second is this scope
// itself.
CompositeOpScopes Scope::GetCompositeOpScopes(
    const string& composite_op_name) const {
  if (impl()->op_name_.empty() && composite_op_name.empty()) {
    UpdateStatus(errors::InvalidArgument(
        "Cannot create composite op scopes with empty name"));
    return {*this, *this};
  }
  if (!impl()->single_use_scope()) {
    Scope child = NewSubScope(impl()->op_name_.empty() ? composite_op_name
                                                       : impl()->op_name_);
    // NOTE(review): child.impl()->name_ already begins with impl()->name_,
    // because NewSubScope prepends the parent name. When this scope is named,
    // child_name therefore repeats that prefix (e.g. "a" -> "a_a/op").
    // Confirm this is the intended single-use op name.
    const string child_op_sep = impl()->name_.empty() ? "" : kSuffixSeparator;
    const string child_name =
        strings::StrCat(impl()->name_, child_op_sep, child.impl()->name_);
    return {child,
            Scope(new Impl(child, Impl::Tags::SingleUseScope(), child_name))};
  } else {
    return {Scope(new Impl(*this, Impl::Tags::ScopeName(), impl()->op_name_,
                           true /* copy_names */)),
            *this};
  }
}
    500 
    501 Status Scope::DoShapeInference(Node* node) const {
    502   if (impl_->disable_shape_inference_) return Status::OK();
    503   return impl_->refiner_->AddNode(node);
    504 }
    505 
    506 class InternalScope {
    507  public:
    508   // NewScope doesn't take ownership of the inputs.
    509   static Scope NewScope(Graph* graph, Status* status, ShapeRefiner* refiner) {
    510     Scope::Impl::NameMap* name_map = new Scope::Impl::NameMap;
    511     for (const Node* node : graph->nodes()) {
    512       const string& name = node->name();
    513       (*name_map)[name] = 0;
    514       // Add all name prefixes ('/' separated).
    515       size_t idx = -1;
    516       while ((idx = name.find(kScopeSeparator, idx + 1)) != string::npos) {
    517         (*name_map)[name.substr(0, idx)] = 0;
    518       }
    519     }
    520     // We provide null destructors for these shared ptrs (except for name_map)
    521     // since the caller owns them and doesn't want the scope to destroy them.
    522     return Scope(new Scope::Impl(
    523         std::shared_ptr<Graph>(graph, [](Graph*) {}),
    524         std::shared_ptr<Status>(status, [](Status*) {}),
    525         std::shared_ptr<Scope::Impl::NameMap>(name_map),
    526         std::shared_ptr<ShapeRefiner>(refiner, [](ShapeRefiner*) {})));
    527   }
    528 };
    529 
    530 Scope NewInternalScope(Graph* graph, Status* status, ShapeRefiner* refiner) {
    531   return InternalScope::NewScope(graph, status, refiner);
    532 }
    533 
    534 }  // namespace tensorflow
    535