Home | History | Annotate | Download | only in framework
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
#include <algorithm>
#include <utility>
#include <vector>
     18 
     19 #include "tensorflow/cc/framework/scope_internal.h"
     20 #include "tensorflow/core/common_runtime/shape_refiner.h"
     21 #include "tensorflow/core/framework/node_def_util.h"
     22 #include "tensorflow/core/graph/graph_constructor.h"
     23 #include "tensorflow/core/graph/node_builder.h"
     24 
     25 namespace tensorflow {
     26 
     27 Scope::Scope(Impl* impl) : impl_(impl) {}
     28 
     29 Scope::Scope(const Scope& other) : impl_(new Impl(*other.impl())) {}
     30 
     31 Scope::~Scope() {}
     32 
     33 Scope& Scope::operator=(const Scope& other) {
     34   // We can't copy Impls because of the const members, use copy ctor instead
     35   impl_.reset(new Impl(*other.impl_));
     36   return *this;
     37 }
     38 
// Root constructor used by NewRootScope() and DisabledShapeInferenceScope().
// The raw pointers are handed to the corresponding smart-pointer members, so
// this Impl (and the Impls derived from it, which copy graph_, status_ and
// refiner_) takes ownership of them.
// Members not listed here (control_deps_, name_, op_name_, exit_on_error_,
// kernel_label_, device_) keep whatever defaults the class declaration gives
// them. NOTE(review): those defaults live in scope_internal.h and are not
// visible here.
Scope::Impl::Impl(Graph* graph, Status* status, NameMap* name_map,
                  ShapeRefiner* refiner, bool disable_shape_inference)
    : graph_(graph),
      status_(status),
      name_map_(name_map),
      refiner_(refiner),
      // A null flag marks a regular (not single-use) scope; presumably
      // single_use_scope() tests this pointer -- confirm in the header.
      scope_used_(nullptr),
      colocation_constraints_(),
      disable_shape_inference_(disable_shape_inference) {}

// Root constructor taking already-built shared pointers. It is used by
// InternalScope::NewScope(), which supplies no-op deleters for the objects
// its caller keeps owning. Shape inference is always enabled for scopes
// built this way.
Scope::Impl::Impl(const std::shared_ptr<Graph>& graph,
                  const std::shared_ptr<Status>& status,
                  const std::shared_ptr<NameMap>& name_map,
                  const std::shared_ptr<ShapeRefiner>& refiner)
    : graph_(graph),
      status_(status),
      name_map_(name_map),
      refiner_(refiner),
      scope_used_(nullptr),
      colocation_constraints_(),
      disable_shape_inference_(false) {}
     60 
     61 Scope Scope::NewRootScope() {
     62   Graph* graph = new Graph(OpRegistry::Global());
     63   ShapeRefiner* refiner =
     64       new ShapeRefiner(graph->versions(), graph->op_registry());
     65   return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner,
     66                         /* disable_shape_inference */ false));
     67 }
     68 
     69 Scope Scope::DisabledShapeInferenceScope() {
     70   Graph* graph = new Graph(OpRegistry::Global());
     71   ShapeRefiner* refiner =
     72       new ShapeRefiner(graph->versions(), graph->op_registry());
     73   return Scope(new Impl(graph, new Status, new Impl::NameMap, refiner,
     74                         /* disable_shape_inference */ true));
     75 }
     76 
// Derives a scope named `name`; used by NewSubScope() and
// GetCompositeOpScopes(). All state is inherited from `other` except:
//   - name_map_: shared with `other` when copy_names is true; otherwise a
//     fresh, empty NameMap, so names inside the new scope are made unique
//     independently of the parent's;
//   - scope_used_: reset to nullptr, so the new scope is not single-use;
//   - op_name_: cleared.
Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name,
                  bool copy_names)
    : graph_(other.impl()->graph_),
      status_(other.impl()->status_),
      name_map_(copy_names ? other.impl()->name_map_
                           : std::shared_ptr<NameMap>(new NameMap)),
      refiner_(other.impl()->refiner_),
      scope_used_(nullptr),
      control_deps_(other.impl()->control_deps_),
      name_(name),
      op_name_(""),
      exit_on_error_(other.impl()->exit_on_error_),
      kernel_label_(other.impl()->kernel_label_),
      device_(other.impl()->device_),
      colocation_constraints_(other.impl()->colocation_constraints_),
      disable_shape_inference_(other.impl()->disable_shape_inference_) {}

// Derives a scope whose next op will be named after `op_name`; used by
// WithOpName(). Only name_ and op_name_ are replaced. Everything else,
// including the name map and the scope_used_ pointer, is copied or shared
// from `other`.
Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name,
                  const string& op_name)
    : graph_(other.impl()->graph_),
      status_(other.impl()->status_),
      name_map_(other.impl()->name_map_),
      refiner_(other.impl()->refiner_),
      scope_used_(other.impl()->scope_used_),
      control_deps_(other.impl()->control_deps_),
      name_(name),
      op_name_(op_name),
      exit_on_error_(other.impl()->exit_on_error_),
      kernel_label_(other.impl()->kernel_label_),
      device_(other.impl()->device_),
      colocation_constraints_(other.impl()->colocation_constraints_),
      disable_shape_inference_(other.impl()->disable_shape_inference_) {}
    109 
    110 Scope::Impl::Impl(const Scope& other, Tags::ControlDeps,
    111                   std::vector<Operation> control_deps, bool clear_control_deps)
    112     : graph_(other.impl()->graph_),
    113       status_(other.impl()->status_),
    114       name_map_(other.impl()->name_map_),
    115       refiner_(other.impl()->refiner_),
    116       scope_used_(other.impl()->scope_used_),
    117       control_deps_(
    118           clear_control_deps
    119               ? std::vector<Operation>()
    120               : (control_deps.insert(control_deps.begin(),
    121                                      other.impl()->control_deps_.begin(),
    122                                      other.impl()->control_deps_.end()),
    123                  control_deps)),
    124       name_(other.impl()->name_),
    125       op_name_(other.impl()->op_name_),
    126       exit_on_error_(other.impl()->exit_on_error_),
    127       kernel_label_(other.impl()->kernel_label_),
    128       device_(other.impl()->device_),
    129       colocation_constraints_(other.impl()->colocation_constraints_),
    130       disable_shape_inference_(other.impl()->disable_shape_inference_) {}
    131 
// Derives a scope identical to `other` except that device_ is replaced.
// UpdateBuilder() passes a non-empty device_ to NodeBuilder::Device().
Scope::Impl::Impl(const Scope& other, Tags::Device, const string& device)
    : graph_(other.impl()->graph_),
      status_(other.impl()->status_),
      name_map_(other.impl()->name_map_),
      refiner_(other.impl()->refiner_),
      scope_used_(other.impl()->scope_used_),
      control_deps_(other.impl()->control_deps_),
      name_(other.impl()->name_),
      op_name_(other.impl()->op_name_),
      exit_on_error_(other.impl()->exit_on_error_),
      kernel_label_(other.impl()->kernel_label_),
      device_(device),
      colocation_constraints_(other.impl()->colocation_constraints_),
      disable_shape_inference_(other.impl()->disable_shape_inference_) {}

// Derives a single-use scope that allows exactly one op, named `op_name`.
// scope_used_ points to a new flag initialised to false.
// GetUniqueName(check_single_use=true) and GetUniqueNameForOp() set the flag
// on first use and report an error on any later use.
Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope,
                  const string& op_name)
    : graph_(other.impl()->graph_),
      status_(other.impl()->status_),
      name_map_(other.impl()->name_map_),
      refiner_(other.impl()->refiner_),
      scope_used_(new bool(false)),
      control_deps_(other.impl()->control_deps_),
      name_(other.impl()->name_),
      op_name_(op_name),
      exit_on_error_(other.impl()->exit_on_error_),
      kernel_label_(other.impl()->kernel_label_),
      device_(other.impl()->device_),
      colocation_constraints_(other.impl()->colocation_constraints_),
      disable_shape_inference_(other.impl()->disable_shape_inference_) {}

// Derives a scope identical to `other` but with exit_on_error_ set.
// Scope::UpdateStatus() then LOG(FATAL)s as soon as the status is not ok.
Scope::Impl::Impl(const Scope& other, Tags::ExitOnError)
    : graph_(other.impl()->graph_),
      status_(other.impl()->status_),
      name_map_(other.impl()->name_map_),
      refiner_(other.impl()->refiner_),
      scope_used_(other.impl()->scope_used_),
      control_deps_(other.impl()->control_deps_),
      name_(other.impl()->name_),
      op_name_(other.impl()->op_name_),
      exit_on_error_(true),
      kernel_label_(other.impl()->kernel_label_),
      device_(other.impl()->device_),
      colocation_constraints_(other.impl()->colocation_constraints_),
      disable_shape_inference_(other.impl()->disable_shape_inference_) {}

// Derives a scope identical to `other` except that kernel_label_ is replaced.
// UpdateBuilder() emits a non-empty label as the "_kernel" attr.
Scope::Impl::Impl(const Scope& other, Tags::KernelLabel,
                  const string& kernel_label)
    : graph_(other.impl()->graph_),
      status_(other.impl()->status_),
      name_map_(other.impl()->name_map_),
      refiner_(other.impl()->refiner_),
      scope_used_(other.impl()->scope_used_),
      control_deps_(other.impl()->control_deps_),
      name_(other.impl()->name_),
      op_name_(other.impl()->op_name_),
      exit_on_error_(other.impl()->exit_on_error_),
      kernel_label_(kernel_label),
      device_(other.impl()->device_),
      colocation_constraints_(other.impl()->colocation_constraints_),
      disable_shape_inference_(other.impl()->disable_shape_inference_) {}

// Derives a scope whose colocation constraints are either:
//   - emptied (clear_colocations == true). colocate_with_op is then never
//     inspected, which is why ClearColocation() can pass a
//     default-constructed Operation(); or
//   - other's constraints plus those implied by colocate_with_op, as computed
//     by GetColocationConstraints().
Scope::Impl::Impl(const Scope& other, Tags::Colocate,
                  const Operation& colocate_with_op, bool clear_colocations)
    : graph_(other.impl()->graph_),
      status_(other.impl()->status_),
      name_map_(other.impl()->name_map_),
      refiner_(other.impl()->refiner_),
      scope_used_(other.impl()->scope_used_),
      control_deps_(other.impl()->control_deps_),
      name_(other.impl()->name_),
      op_name_(other.impl()->op_name_),
      exit_on_error_(other.impl()->exit_on_error_),
      kernel_label_(other.impl()->kernel_label_),
      device_(other.impl()->device_),
      colocation_constraints_(
          clear_colocations
              ? std::unordered_set<string>()
              : other.impl()->GetColocationConstraints(colocate_with_op)),
      disable_shape_inference_(other.impl()->disable_shape_inference_) {}
    212 
    213 std::unordered_set<string> Scope::Impl::GetColocationConstraints(
    214     const Operation& colocate_with_op) const {
    215   std::unordered_set<string> current_constraints(colocation_constraints_);
    216   const AttrSlice attrs = colocate_with_op.node()->attrs();
    217   std::vector<string> node_constraints;
    218   if (GetNodeAttr(attrs, kColocationAttrName, &node_constraints).ok()) {
    219     for (const string& entry : node_constraints) {
    220       StringPiece s(entry);
    221       if (s.Consume(kColocationGroupPrefix)) {
    222         current_constraints.insert(s.ToString());
    223       }
    224     }
    225   } else {
    226     current_constraints.insert(colocate_with_op.node()->name());
    227   }
    228   return current_constraints;
    229 }
    230 
    231 bool Scope::ok() const { return impl()->status_->ok(); }
    232 
    233 Graph* Scope::graph() const { return impl()->graph_.get(); }
    234 
    235 std::shared_ptr<Graph> Scope::graph_as_shared_ptr() const {
    236   return impl()->graph_;
    237 }
    238 
    239 Status Scope::status() const { return *impl()->status_; }
    240 
    241 const std::vector<Operation>& Scope::control_deps() const {
    242   return impl()->control_deps_;
    243 }
    244 
    245 void Scope::UpdateStatus(const Status s) const {
    246   impl()->status_->Update(s);
    247   if (impl()->exit_on_error_ && !ok()) {
    248     LOG(FATAL) << *impl()->status_;
    249   }
    250 }
    251 
    252 Status Scope::ToGraphDef(GraphDef* gdef) const {
    253   if (!ok()) {
    254     return *impl()->status_;
    255   }
    256   graph()->ToGraphDef(gdef);
    257   return Status::OK();
    258 }
    259 
    260 Status Scope::ToGraph(Graph* g) const {
    261   if (ok()) {
    262     GraphDef graph_def;
    263     graph()->ToGraphDef(&graph_def);
    264     GraphConstructorOptions opts;
    265     UpdateStatus(ConvertGraphDefToGraph(opts, graph_def, g));
    266   }
    267   return *impl()->status_;
    268 }
    269 
    270 void Scope::UpdateBuilder(NodeBuilder* builder) const {
    271   std::vector<Node*> control_inputs;
    272   for (const auto& op : impl()->control_deps_) {
    273     control_inputs.push_back(op.node());
    274   }
    275   builder->ControlInputs(control_inputs);
    276 
    277   if (!impl()->kernel_label_.empty()) {
    278     builder->Attr("_kernel", impl()->kernel_label_);
    279   }
    280 
    281   if (!impl()->colocation_constraints_.empty()) {
    282     std::vector<string> constraints(impl()->colocation_constraints_.begin(),
    283                                     impl()->colocation_constraints_.end());
    284     // Sort the set.
    285     std::sort(constraints.begin(), constraints.end());
    286     // Add loc:@ prefix
    287     std::transform(constraints.begin(), constraints.end(), constraints.begin(),
    288                    [](const string& s) {
    289                      return strings::StrCat(kColocationGroupPrefix, s);
    290                    });
    291     builder->Attr(kColocationAttrName, constraints);
    292   }
    293   if (!impl()->device_.empty()) {
    294     builder->Device(impl()->device_);
    295   }
    296 }
    297 
    298 string Scope::Impl::GetUniqueName(const string& prefix,
    299                                   bool check_single_use) const {
    300   if (check_single_use && single_use_scope()) {
    301     if (*scope_used_) {
    302       *status_ =
    303           errors::AlreadyExists(prefix, " already exists in the current scope");
    304       return "";
    305     }
    306     *scope_used_ = true;
    307     return prefix;
    308   }
    309   auto entry = name_map_->find(prefix);
    310   string unique_name = prefix;
    311   if (entry == name_map_->end()) {
    312     name_map_->insert({prefix, 0});
    313   } else {
    314     unique_name = strings::StrCat(unique_name, "_", ++entry->second);
    315   }
    316   return unique_name;
    317 }
    318 
    319 string Scope::Impl::GetNameForOp(const string& default_name) const {
    320   const string unique_name =
    321       GetUniqueName(default_name, true /* check_single_use */);
    322   const string sep = name_.empty() || unique_name.empty() ? "" : "/";
    323   return strings::StrCat(name_, sep, unique_name);
    324 }
    325 
    326 string Scope::GetUniqueNameForOp(const string& default_name) const {
    327   if (impl()->single_use_scope()) {
    328     if (impl()->op_name_.empty() || *impl()->scope_used_) {
    329       *impl()->status_ =
    330           errors::InvalidArgument("Cannot get a unique name in this scope");
    331       return "";
    332     }
    333     *impl()->scope_used_ = true;
    334     return impl()->op_name_;
    335   }
    336   return impl()->op_name_.empty() ? impl()->GetNameForOp(default_name)
    337                                   : impl()->GetNameForOp(impl()->op_name_);
    338 }
    339 
    340 Scope Scope::NewSubScope(const string& child_scope_name) const {
    341   if (child_scope_name.empty()) {
    342     return Scope(new Impl(*this, Impl::Tags::ScopeName(), impl()->name_,
    343                           true /* copy_names */));
    344   }
    345   const string unique_name =
    346       impl()->GetUniqueName(child_scope_name, false /* check_single_use */);
    347   const string sep = impl()->name_.empty() || unique_name.empty() ? "" : "/";
    348   return Scope(new Impl(*this, Impl::Tags::ScopeName(),
    349                         strings::StrCat(impl()->name_, sep, unique_name),
    350                         false /* copy_names */));
    351 }
    352 
    353 Scope Scope::WithOpName(const string& op_name) const {
    354   if (impl()->single_use_scope()) {
    355     UpdateStatus(errors::InvalidArgument("Cannot set op name ", op_name,
    356                                          " on this scope"));
    357     return *this;
    358   }
    359   return Scope(new Impl(*this, Impl::Tags::OpName(), impl()->name_, op_name));
    360 }
    361 
    362 Scope Scope::WithControlDependencies(
    363     const gtl::ArraySlice<Operation>& control_deps) const {
    364   return Scope(
    365       new Impl(*this, Impl::Tags::ControlDeps(),
    366                std::vector<Operation>(control_deps.begin(), control_deps.end()),
    367                /* clear_control_deps */ false));
    368 }
    369 
    370 Scope Scope::WithControlDependencies(const Output& control_dep) const {
    371   return Scope(new Impl(*this, Impl::Tags::ControlDeps(),
    372                         std::vector<Operation>(1, control_dep.op()),
    373                         /* clear_control_deps */ false));
    374 }
    375 
    376 Scope Scope::WithNoControlDependencies() const {
    377   return Scope(new Impl(*this, Impl::Tags::ControlDeps(),
    378                         std::vector<Operation>(),
    379                         /* clear_control_deps */ true));
    380 }
    381 
    382 Scope Scope::WithDevice(const string& device) const {
    383   return Scope(new Impl(*this, Impl::Tags::Device(), device));
    384 }
    385 
    386 Scope Scope::ColocateWith(const Operation& op) const {
    387   return Scope(new Impl(*this, Impl::Tags::Colocate(), op,
    388                         /* clear_colocations */ false));
    389 }
    390 
    391 Scope Scope::ClearColocation() const {
    392   return Scope(new Impl(*this, Impl::Tags::Colocate(), Operation(),
    393                         /* clear_colocations */ true));
    394 }
    395 
    396 Scope Scope::ExitOnError() const {
    397   return Scope(new Impl(*this, Impl::Tags::ExitOnError()));
    398 }
    399 
    400 Scope Scope::WithKernelLabel(const string& kernel_label) const {
    401   return Scope(new Impl(*this, Impl::Tags::KernelLabel(), kernel_label));
    402 }
    403 
// Returns the pair of scopes used to build a composite op, i.e. an op
// assembled from several ops. The pair is returned positionally as
// {scope for the internal ops, scope for the final op}. NOTE(review): the
// CompositeOpScopes field names are declared in the header and are not
// visible here.
//
// Behavior by case:
//   - No op name set via WithOpName() and `composite_op_name` empty: records
//     InvalidArgument and returns {*this, *this}.
//   - Regular scope: the first scope is NewSubScope(op_name_ if set, else
//     composite_op_name). The second is a single-use scope derived from that
//     child, whose op name is name_ + "_" + child's full name.
//   - Single-use scope: the first scope is named op_name_ and shares this
//     scope's name map; the second is *this.
CompositeOpScopes Scope::GetCompositeOpScopes(
    const string& composite_op_name) const {
  if (impl()->op_name_.empty() && composite_op_name.empty()) {
    UpdateStatus(errors::InvalidArgument(
        "Cannot create composite op scopes with empty name"));
    return {*this, *this};
  }
  if (!impl()->single_use_scope()) {
    Scope child = NewSubScope(impl()->op_name_.empty() ? composite_op_name
                                                       : impl()->op_name_);
    // NOTE(review): child.impl()->name_ already starts with impl()->name_
    // (NewSubScope prefixes it), so for a parent named "a" this appears to
    // yield "a_a/<child>". Confirm this doubled prefix is intended.
    const string child_op_sep = impl()->name_.empty() ? "" : "_";
    const string child_name =
        strings::StrCat(impl()->name_, child_op_sep, child.impl()->name_);
    return {child,
            Scope(new Impl(child, Impl::Tags::SingleUseScope(), child_name))};
  } else {
    return {Scope(new Impl(*this, Impl::Tags::ScopeName(), impl()->op_name_,
                           true /* copy_names */)),
            *this};
  }
}
    425 
    426 Status Scope::DoShapeInference(Node* node) const {
    427   if (impl_->disable_shape_inference_) return Status::OK();
    428   return impl_->refiner_->AddNode(node);
    429 }
    430 
    431 class InternalScope {
    432  public:
    433   // NewScope doesn't take ownership of the inputs.
    434   static Scope NewScope(Graph* graph, Status* status, ShapeRefiner* refiner) {
    435     Scope::Impl::NameMap* name_map = new Scope::Impl::NameMap;
    436     for (const Node* node : graph->nodes()) {
    437       (*name_map)[node->name()] = 0;
    438     }
    439     // We provide null destructors for these shared ptrs (except for name_map)
    440     // since the caller owns them and doesn't want the scope to destroy them.
    441     return Scope(new Scope::Impl(
    442         std::shared_ptr<Graph>(graph, [](Graph*) {}),
    443         std::shared_ptr<Status>(status, [](Status*) {}),
    444         std::shared_ptr<Scope::Impl::NameMap>(name_map),
    445         std::shared_ptr<ShapeRefiner>(refiner, [](ShapeRefiner*) {})));
    446   }
    447 };
    448 
    449 Scope NewInternalScope(Graph* graph, Status* status, ShapeRefiner* refiner) {
    450   return InternalScope::NewScope(graph, status, refiner);
    451 }
    452 
    453 }  // namespace tensorflow
    454