Home | History | Annotate | Download | only in framework
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/framework/fake_input.h"
     17 
     18 #include <vector>
     19 #include "tensorflow/core/framework/attr_value.pb.h"
     20 #include "tensorflow/core/framework/node_def_util.h"
     21 #include "tensorflow/core/framework/op_def.pb.h"
     22 #include "tensorflow/core/framework/op_def_util.h"
     23 #include "tensorflow/core/lib/core/errors.h"
     24 #include "tensorflow/core/lib/core/status.h"
     25 
     26 namespace tensorflow {
     27 namespace {
     28 
     29 class FakeInputImpl {
     30  public:
     31   FakeInputImpl(const OpDef* op_def, int in_index, const NodeDef* node_def,
     32                 NodeDefBuilder* builder);
     33   void SetN(int n);
     34   void SetDataType(DataType dt);
     35   void SetTypeList(DataTypeSlice dts);
     36   Status AddInputToBuilder();
     37 
     38  private:
     39   static string FakeNodeName(int in_index);
     40   Status GetN(int* n) const;
     41   Status GetDataType(DataType* dt) const;
     42   void NSources(int n, DataType dt) const;
     43   void SourceList(DataTypeSlice dts) const;
     44 
     45   const OpDef* const op_def_;
     46   const OpDef::ArgDef* const arg_;
     47   const string in_node_;
     48   const NodeDef* const node_def_;
     49   NodeDefBuilder* const builder_;
     50 
     51   bool n_specified_;
     52   int n_;
     53   bool dt_specified_;
     54   DataType dt_;
     55   bool dts_specified_;
     56   DataTypeSlice dts_;
     57 };
     58 
     59 FakeInputImpl::FakeInputImpl(const OpDef* op_def, int in_index,
     60                              const NodeDef* node_def, NodeDefBuilder* builder)
     61     : op_def_(op_def),
     62       arg_(&op_def->input_arg(in_index)),
     63       in_node_(FakeNodeName(in_index)),
     64       node_def_(node_def),
     65       builder_(builder),
     66       n_specified_(false),
     67       dt_specified_(false),
     68       dts_specified_(false) {}
     69 
     70 void FakeInputImpl::SetN(int n) {
     71   n_specified_ = true;
     72   n_ = n;
     73 }
     74 
     75 void FakeInputImpl::SetDataType(DataType dt) {
     76   dt_specified_ = true;
     77   dt_ = dt;
     78 }
     79 
     80 void FakeInputImpl::SetTypeList(DataTypeSlice dts) {
     81   dts_specified_ = true;
     82   dts_ = dts;
     83 }
     84 
     85 Status FakeInputImpl::AddInputToBuilder() {
     86   if (dts_specified_) {
     87     SourceList(dts_);
     88 
     89   } else if (n_specified_ || !arg_->number_attr().empty()) {
     90     int n;
     91     TF_RETURN_IF_ERROR(GetN(&n));
     92 
     93     DataType dt;
     94     if (n > 0) {
     95       TF_RETURN_IF_ERROR(GetDataType(&dt));
     96     } else {
     97       dt = DT_FLOAT;
     98     }
     99 
    100     NSources(n, dt);
    101   } else {
    102     if (!dt_specified_ && !arg_->type_list_attr().empty()) {
    103       DataTypeVector dts;
    104       Status status = GetNodeAttr(*node_def_, arg_->type_list_attr(), &dts);
    105       if (!status.ok()) {
    106         return errors::InvalidArgument(
    107             "Could not infer list of types for input '", arg_->name(),
    108             "': ", status.error_message());
    109       }
    110       SourceList(dts);
    111       return Status::OK();
    112     }
    113 
    114     DataType dt;
    115     TF_RETURN_IF_ERROR(GetDataType(&dt));
    116     builder_->Input(in_node_, 0, dt);
    117   }
    118   return Status::OK();
    119 }
    120 
    121 // static
    122 string FakeInputImpl::FakeNodeName(int in_index) {
    123   char c = 'a' + (in_index % 26);
    124   return string(&c, 1);
    125 }
    126 
    127 Status FakeInputImpl::GetN(int* n) const {
    128   if (n_specified_) {
    129     *n = n_;
    130   } else {
    131     Status status = GetNodeAttr(*node_def_, arg_->number_attr(), n);
    132     if (!status.ok()) {
    133       return errors::InvalidArgument("Could not infer length of input '",
    134                                      arg_->name(),
    135                                      "': ", status.error_message());
    136     }
    137   }
    138   return Status::OK();
    139 }
    140 
    141 Status FakeInputImpl::GetDataType(DataType* dt) const {
    142   if (dt_specified_) {
    143     *dt = dt_;
    144     return Status::OK();  // Ignore is_ref field of arg_.
    145   } else if (arg_->type() != DT_INVALID) {
    146     *dt = arg_->type();
    147   } else if (!arg_->type_attr().empty()) {
    148     Status status = GetNodeAttr(*node_def_, arg_->type_attr(), dt);
    149     if (!status.ok()) {
    150       // Check if the type attr has a default
    151       const OpDef::AttrDef* attr = FindAttr(arg_->type_attr(), *op_def_);
    152       if (attr && attr->has_default_value()) {
    153         *dt = attr->default_value().type();
    154       } else {
    155         return errors::InvalidArgument("Could not infer type for input '",
    156                                        arg_->name(),
    157                                        "': ", status.error_message());
    158       }
    159     }
    160   } else {
    161     return errors::InvalidArgument("No type or type_attr field in arg '",
    162                                    arg_->name(), "'");
    163   }
    164   if (arg_->is_ref()) {
    165     *dt = MakeRefType(*dt);
    166   }
    167   return Status::OK();
    168 }
    169 
    170 void FakeInputImpl::NSources(int n, DataType dt) const {
    171   std::vector<NodeDefBuilder::NodeOut> srcs;
    172   srcs.reserve(n);
    173   for (int i = 0; i < n; ++i) {
    174     srcs.emplace_back(in_node_, i, dt);
    175   }
    176   builder_->Input(gtl::ArraySlice<NodeDefBuilder::NodeOut>(srcs));
    177 }
    178 
    179 void FakeInputImpl::SourceList(DataTypeSlice dts) const {
    180   std::vector<NodeDefBuilder::NodeOut> srcs;
    181   srcs.reserve(dts.size());
    182   for (size_t i = 0; i < dts.size(); ++i) {
    183     srcs.emplace_back(in_node_, i, dts[i]);
    184   }
    185   builder_->Input(gtl::ArraySlice<NodeDefBuilder::NodeOut>(srcs));
    186 }
    187 
    188 }  // namespace
    189 
    190 // Public interface ------------------------------------------------------------
    191 
    192 FakeInputFunctor FakeInput() {
    193   return [](const OpDef& op_def, int in_index, const NodeDef& node_def,
    194             NodeDefBuilder* builder) {
    195     FakeInputImpl impl(&op_def, in_index, &node_def, builder);
    196     return impl.AddInputToBuilder();
    197   };
    198 }
    199 
    200 FakeInputFunctor FakeInput(DataType dt) {
    201   return [dt](const OpDef& op_def, int in_index, const NodeDef& node_def,
    202               NodeDefBuilder* builder) {
    203     FakeInputImpl impl(&op_def, in_index, &node_def, builder);
    204     impl.SetDataType(dt);
    205     return impl.AddInputToBuilder();
    206   };
    207 }
    208 
    209 FakeInputFunctor FakeInput(int n) {
    210   return [n](const OpDef& op_def, int in_index, const NodeDef& node_def,
    211              NodeDefBuilder* builder) {
    212     FakeInputImpl impl(&op_def, in_index, &node_def, builder);
    213     impl.SetN(n);
    214     return impl.AddInputToBuilder();
    215   };
    216 }
    217 
    218 FakeInputFunctor FakeInput(int n, DataType dt) {
    219   return [n, dt](const OpDef& op_def, int in_index, const NodeDef& node_def,
    220                  NodeDefBuilder* builder) {
    221     FakeInputImpl impl(&op_def, in_index, &node_def, builder);
    222     impl.SetN(n);
    223     impl.SetDataType(dt);
    224     return impl.AddInputToBuilder();
    225   };
    226 }
    227 
    228 FakeInputFunctor FakeInput(DataTypeSlice dts) {
    229   // Make a copy to ensure the data will still be around when the lambda is
    230   // called.
    231   DataTypeVector dtv(dts.begin(), dts.end());
    232   return [dtv](const OpDef& op_def, int in_index, const NodeDef& node_def,
    233                NodeDefBuilder* builder) {
    234     FakeInputImpl impl(&op_def, in_index, &node_def, builder);
    235     impl.SetTypeList(dtv);
    236     return impl.AddInputToBuilder();
    237   };
    238 }
    239 
    240 }  // namespace tensorflow
    241