Home | History | Annotate | Download | only in framework
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/framework/attr_value_util.h"
     17 
     18 #include <string>
     19 #include <vector>
     20 
     21 #include "tensorflow/core/framework/attr_value.pb_text.h"
     22 #include "tensorflow/core/framework/tensor.pb_text.h"
     23 #include "tensorflow/core/framework/tensor_shape.pb.h"
     24 #include "tensorflow/core/framework/types.h"
     25 #include "tensorflow/core/framework/types.pb_text.h"
     26 #include "tensorflow/core/lib/core/errors.h"
     27 #include "tensorflow/core/lib/core/stringpiece.h"
     28 #include "tensorflow/core/lib/hash/hash.h"
     29 #include "tensorflow/core/lib/strings/str_util.h"
     30 #include "tensorflow/core/platform/protobuf.h"
     31 
     32 namespace tensorflow {
     33 namespace {
     34 
     35 string SummarizeString(const string& str) {
     36   string escaped = str_util::CEscape(str);
     37 
     38   // If the string is long, replace the middle with ellipses.
     39   constexpr int kMaxStringSummarySize = 80;
     40   if (escaped.size() >= kMaxStringSummarySize) {
     41     StringPiece prefix(escaped);
     42     StringPiece suffix = prefix;
     43     prefix.remove_suffix(escaped.size() - 10);
     44     suffix.remove_prefix(escaped.size() - 10);
     45     return strings::StrCat("\"", prefix, "...", suffix, "\"");
     46   } else {
     47     return strings::StrCat("\"", escaped, "\"");
     48   }
     49 }
     50 
     51 string SummarizeTensor(const TensorProto& tensor_proto) {
     52   Tensor t;
     53   if (!t.FromProto(tensor_proto)) {
     54     return strings::StrCat(
     55         "<Invalid TensorProto: ", ProtoShortDebugString(tensor_proto), ">");
     56   }
     57   return t.DebugString();
     58 }
     59 
     60 string SummarizeFunc(const NameAttrList& func) {
     61   std::vector<string> entries;
     62   for (auto p : func.attr()) {
     63     entries.push_back(
     64         strings::StrCat(p.first, "=", SummarizeAttrValue(p.second)));
     65   }
     66   std::sort(entries.begin(), entries.end());
     67   return strings::StrCat(func.name(), "[", str_util::Join(entries, ", "), "]");
     68 }
     69 
     70 }  // namespace
     71 
     72 string SummarizeAttrValue(const AttrValue& attr_value) {
     73   switch (attr_value.value_case()) {
     74     case AttrValue::kS:
     75       return SummarizeString(attr_value.s());
     76     case AttrValue::kI:
     77       return strings::StrCat(attr_value.i());
     78     case AttrValue::kF:
     79       return strings::StrCat(attr_value.f());
     80     case AttrValue::kB:
     81       return attr_value.b() ? "true" : "false";
     82     case AttrValue::kType:
     83       return EnumName_DataType(attr_value.type());
     84     case AttrValue::kShape:
     85       return PartialTensorShape::DebugString(attr_value.shape());
     86     case AttrValue::kTensor:
     87       return SummarizeTensor(attr_value.tensor());
     88     case AttrValue::kList: {
     89       std::vector<string> pieces;
     90       if (attr_value.list().s_size() > 0) {
     91         for (int i = 0; i < attr_value.list().s_size(); ++i) {
     92           pieces.push_back(SummarizeString(attr_value.list().s(i)));
     93         }
     94       } else if (attr_value.list().i_size() > 0) {
     95         for (int i = 0; i < attr_value.list().i_size(); ++i) {
     96           pieces.push_back(strings::StrCat(attr_value.list().i(i)));
     97         }
     98       } else if (attr_value.list().f_size() > 0) {
     99         for (int i = 0; i < attr_value.list().f_size(); ++i) {
    100           pieces.push_back(strings::StrCat(attr_value.list().f(i)));
    101         }
    102       } else if (attr_value.list().b_size() > 0) {
    103         for (int i = 0; i < attr_value.list().b_size(); ++i) {
    104           pieces.push_back(attr_value.list().b(i) ? "true" : "false");
    105         }
    106       } else if (attr_value.list().type_size() > 0) {
    107         for (int i = 0; i < attr_value.list().type_size(); ++i) {
    108           pieces.push_back(EnumName_DataType(attr_value.list().type(i)));
    109         }
    110       } else if (attr_value.list().shape_size() > 0) {
    111         for (int i = 0; i < attr_value.list().shape_size(); ++i) {
    112           pieces.push_back(
    113               TensorShape::DebugString(attr_value.list().shape(i)));
    114         }
    115       } else if (attr_value.list().tensor_size() > 0) {
    116         for (int i = 0; i < attr_value.list().tensor_size(); ++i) {
    117           pieces.push_back(SummarizeTensor(attr_value.list().tensor(i)));
    118         }
    119       } else if (attr_value.list().func_size() > 0) {
    120         for (int i = 0; i < attr_value.list().func_size(); ++i) {
    121           pieces.push_back(SummarizeFunc(attr_value.list().func(i)));
    122         }
    123       }
    124       constexpr int kMaxListSummarySize = 15;
    125       if (pieces.size() >= kMaxListSummarySize) {
    126         pieces.erase(pieces.begin() + 5, pieces.begin() + (pieces.size() - 6));
    127         pieces[5] = "...";
    128       }
    129       return strings::StrCat("[", str_util::Join(pieces, ", "), "]");
    130     }
    131     case AttrValue::kFunc: {
    132       return SummarizeFunc(attr_value.func());
    133     }
    134     case AttrValue::kPlaceholder:
    135       return strings::StrCat("$", attr_value.placeholder());
    136     case AttrValue::VALUE_NOT_SET:
    137       return "<Unknown AttrValue type>";
    138   }
    139   return "<Unknown AttrValue type>";  // Prevent missing return warning
    140 }
    141 
    142 Status AttrValueHasType(const AttrValue& attr_value, StringPiece type) {
    143   int num_set = 0;
    144 
    145 #define VALIDATE_FIELD(name, type_string, oneof_case)                         \
    146   do {                                                                        \
    147     if (attr_value.has_list()) {                                              \
    148       if (attr_value.list().name##_size() > 0) {                              \
    149         if (type != "list(" type_string ")") {                                \
    150           return errors::InvalidArgument(                                     \
    151               "AttrValue had value with type 'list(" type_string ")' when '", \
    152               type, "' expected");                                            \
    153         }                                                                     \
    154         ++num_set;                                                            \
    155       }                                                                       \
    156     } else if (attr_value.value_case() == AttrValue::oneof_case) {            \
    157       if (type != type_string) {                                              \
    158         return errors::InvalidArgument(                                       \
    159             "AttrValue had value with type '" type_string "' when '", type,   \
    160             "' expected");                                                    \
    161       }                                                                       \
    162       ++num_set;                                                              \
    163     }                                                                         \
    164   } while (false)
    165 
    166   VALIDATE_FIELD(s, "string", kS);
    167   VALIDATE_FIELD(i, "int", kI);
    168   VALIDATE_FIELD(f, "float", kF);
    169   VALIDATE_FIELD(b, "bool", kB);
    170   VALIDATE_FIELD(type, "type", kType);
    171   VALIDATE_FIELD(shape, "shape", kShape);
    172   VALIDATE_FIELD(tensor, "tensor", kTensor);
    173   VALIDATE_FIELD(func, "func", kFunc);
    174 
    175 #undef VALIDATE_FIELD
    176 
    177   if (attr_value.value_case() == AttrValue::kPlaceholder) {
    178     return errors::InvalidArgument(
    179         "AttrValue had value with unexpected type 'placeholder'");
    180   }
    181 
    182   // If the attr type is 'list', we expect attr_value.has_list() to be
    183   // true.  However, proto3's attr_value.has_list() can be false when
    184   // set to an empty list for GraphDef versions <= 4. So we simply
    185   // check if has_list is false and some other field in attr_value is
    186   // set to flag the error.  This test can be made more strict once
    187   // support for GraphDef versions <= 4 is dropped.
    188   if (StringPiece(type).starts_with("list(") && !attr_value.has_list()) {
    189     if (num_set) {
    190       return errors::InvalidArgument(
    191           "AttrValue missing value with expected type '", type, "'");
    192     } else {
    193       // Indicate that we have a list, but an empty one.
    194       ++num_set;
    195     }
    196   }
    197 
    198   // Okay to have an empty list, but not to be missing a non-list value.
    199   if (num_set == 0 && !StringPiece(type).starts_with("list(")) {
    200     return errors::InvalidArgument(
    201         "AttrValue missing value with expected type '", type, "'");
    202   }
    203 
    204   // Ref types and DT_INVALID are illegal, and DataTypes must
    205   // be a valid enum type.
    206   if (type == "type") {
    207     if (!DataType_IsValid(attr_value.type())) {
    208       return errors::InvalidArgument("AttrValue has invalid DataType enum: ",
    209                                      attr_value.type());
    210     }
    211     if (IsRefType(attr_value.type())) {
    212       return errors::InvalidArgument(
    213           "AttrValue must not have reference type value of ",
    214           DataTypeString(attr_value.type()));
    215     }
    216     if (attr_value.type() == DT_INVALID) {
    217       return errors::InvalidArgument("AttrValue has invalid DataType");
    218     }
    219   } else if (type == "list(type)") {
    220     for (auto as_int : attr_value.list().type()) {
    221       const DataType dtype = static_cast<DataType>(as_int);
    222       if (!DataType_IsValid(dtype)) {
    223         return errors::InvalidArgument("AttrValue has invalid DataType enum: ",
    224                                        as_int);
    225       }
    226       if (IsRefType(dtype)) {
    227         return errors::InvalidArgument(
    228             "AttrValue must not have reference type value of ",
    229             DataTypeString(dtype));
    230       }
    231       if (dtype == DT_INVALID) {
    232         return errors::InvalidArgument("AttrValue contains invalid DataType");
    233       }
    234     }
    235   }
    236 
    237   return Status::OK();
    238 }
    239 
    240 bool ParseAttrValue(StringPiece type, StringPiece text, AttrValue* out) {
    241   // Parse type.
    242   string field_name;
    243   bool is_list = type.Consume("list(");
    244   if (type.Consume("string")) {
    245     field_name = "s";
    246   } else if (type.Consume("int")) {
    247     field_name = "i";
    248   } else if (type.Consume("float")) {
    249     field_name = "f";
    250   } else if (type.Consume("bool")) {
    251     field_name = "b";
    252   } else if (type.Consume("type")) {
    253     field_name = "type";
    254   } else if (type.Consume("shape")) {
    255     field_name = "shape";
    256   } else if (type.Consume("tensor")) {
    257     field_name = "tensor";
    258   } else if (type.Consume("func")) {
    259     field_name = "func";
    260   } else if (type.Consume("placeholder")) {
    261     field_name = "placeholder";
    262   } else {
    263     return false;
    264   }
    265   if (is_list && !type.Consume(")")) {
    266     return false;
    267   }
    268 
    269   // Construct a valid text proto message to parse.
    270   string to_parse;
    271   if (is_list) {
    272     // TextFormat parser considers "i: 7" to be the same as "i: [7]",
    273     // but we only want to allow list values with [].
    274     StringPiece cleaned = text;
    275     str_util::RemoveLeadingWhitespace(&cleaned);
    276     str_util::RemoveTrailingWhitespace(&cleaned);
    277     if (cleaned.size() < 2 || cleaned[0] != '[' ||
    278         cleaned[cleaned.size() - 1] != ']') {
    279       return false;
    280     }
    281     cleaned.remove_prefix(1);
    282     str_util::RemoveLeadingWhitespace(&cleaned);
    283     if (cleaned.size() == 1) {
    284       // User wrote "[]", so return empty list without invoking the TextFormat
    285       // parse which returns an error for "i: []".
    286       out->Clear();
    287       out->mutable_list();
    288       return true;
    289     }
    290     to_parse = strings::StrCat("list { ", field_name, ": ", text, " }");
    291   } else {
    292     to_parse = strings::StrCat(field_name, ": ", text);
    293   }
    294 
    295   return ProtoParseFromString(to_parse, out);
    296 }
    297 
    298 void SetAttrValue(const AttrValue& value, AttrValue* out) { *out = value; }
    299 
    300 #define DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \
    301   void SetAttrValue(ARG_TYPE value, AttrValue* out) { out->set_##FIELD(value); }
    302 
    303 #define DEFINE_SET_ATTR_VALUE_LIST(ARG_TYPE, FIELD)                       \
    304   void SetAttrValue(ARG_TYPE value, AttrValue* out) {                     \
    305     out->mutable_list()->Clear(); /* create list() even if value empty */ \
    306     for (const auto& v : value) {                                         \
    307       out->mutable_list()->add_##FIELD(v);                                \
    308     }                                                                     \
    309   }
    310 
    311 #define DEFINE_SET_ATTR_VALUE_BOTH(ARG_TYPE, FIELD) \
    312   DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD)        \
    313   DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice<ARG_TYPE>, FIELD)
    314 
    315 DEFINE_SET_ATTR_VALUE_ONE(const string&, s)
    316 DEFINE_SET_ATTR_VALUE_LIST(gtl::ArraySlice<string>, s)
    317 DEFINE_SET_ATTR_VALUE_BOTH(const char*, s)
    318 DEFINE_SET_ATTR_VALUE_BOTH(int64, i)
    319 DEFINE_SET_ATTR_VALUE_BOTH(int32, i)
    320 DEFINE_SET_ATTR_VALUE_BOTH(float, f)
    321 DEFINE_SET_ATTR_VALUE_BOTH(double, f)
    322 DEFINE_SET_ATTR_VALUE_BOTH(bool, b)
    323 DEFINE_SET_ATTR_VALUE_LIST(const std::vector<bool>&, b)
    324 DEFINE_SET_ATTR_VALUE_LIST(std::initializer_list<bool>, b)
    325 DEFINE_SET_ATTR_VALUE_BOTH(DataType, type)
    326 
    327 void SetAttrValue(StringPiece value, AttrValue* out) {
    328   out->set_s(value.data(), value.size());
    329 }
    330 
    331 void SetAttrValue(const gtl::ArraySlice<StringPiece> value, AttrValue* out) {
    332   out->mutable_list()->Clear();  // Create list() even if value empty.
    333   for (const auto& v : value) {
    334     out->mutable_list()->add_s(v.data(), v.size());
    335   }
    336 }
    337 
    338 void SetAttrValue(const TensorShape& value, AttrValue* out) {
    339   value.AsProto(out->mutable_shape());
    340 }
    341 
    342 void SetAttrValue(const TensorShapeProto& value, AttrValue* out) {
    343   *out->mutable_shape() = value;
    344 }
    345 
    346 void SetAttrValue(const PartialTensorShape& value, AttrValue* out) {
    347   value.AsProto(out->mutable_shape());
    348 }
    349 
    350 void SetAttrValue(const gtl::ArraySlice<TensorShape> value, AttrValue* out) {
    351   out->mutable_list()->Clear();  // Create list() even if value empty.
    352   for (const auto& v : value) {
    353     v.AsProto(out->mutable_list()->add_shape());
    354   }
    355 }
    356 
    357 void SetAttrValue(gtl::ArraySlice<TensorShapeProto> value, AttrValue* out) {
    358   out->mutable_list()->Clear();  // Create list() even if value empty.
    359   for (const auto& v : value) {
    360     *out->mutable_list()->add_shape() = v;
    361   }
    362 }
    363 
    364 void SetAttrValue(const gtl::ArraySlice<PartialTensorShape> value,
    365                   AttrValue* out) {
    366   out->mutable_list()->Clear();  // Create list() even if value empty.
    367   for (const auto& v : value) {
    368     v.AsProto(out->mutable_list()->add_shape());
    369   }
    370 }
    371 
    372 void SetAttrValue(const Tensor& value, AttrValue* out) {
    373   if (value.NumElements() > 1) {
    374     value.AsProtoTensorContent(out->mutable_tensor());
    375   } else {
    376     value.AsProtoField(out->mutable_tensor());
    377   }
    378 }
    379 
    380 void SetAttrValue(const gtl::ArraySlice<Tensor> value, AttrValue* out) {
    381   out->mutable_list()->Clear();  // Create list() even if value empty.
    382   for (const auto& v : value) {
    383     if (v.NumElements() > 1) {
    384       v.AsProtoTensorContent(out->mutable_list()->add_tensor());
    385     } else {
    386       v.AsProtoField(out->mutable_list()->add_tensor());
    387     }
    388   }
    389 }
    390 
    391 void SetAttrValue(const TensorProto& value, AttrValue* out) {
    392   *out->mutable_tensor() = value;
    393 }
    394 
    395 void SetAttrValue(const gtl::ArraySlice<TensorProto> value, AttrValue* out) {
    396   out->mutable_list()->Clear();  // Create list() even if value empty.
    397   for (const auto& v : value) {
    398     *out->mutable_list()->add_tensor() = v;
    399   }
    400 }
    401 
    402 void SetAttrValue(const NameAttrList& value, AttrValue* out) {
    403   *out->mutable_func() = value;
    404 }
    405 
    406 void SetAttrValue(gtl::ArraySlice<NameAttrList> value, AttrValue* out) {
    407   out->mutable_list()->Clear();  // Create list() even if value empty.
    408   for (const auto& v : value) {
    409     *out->mutable_list()->add_func() = v;
    410   }
    411 }
    412 
    413 bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b) {
    414   // There are multiple equivalent representations of attr values containing
    415   // TensorProtos. Compare them by constructing Tensors and serializing them
    416   // back. Comparing Tensor objects is pretty tricky.
    417   if (a.has_tensor() != b.has_tensor()) {
    418     return false;
    419   } else if (a.has_tensor() && b.has_tensor()) {
    420     Tensor at(a.tensor().dtype());
    421     bool success = at.FromProto(a.tensor());
    422     DCHECK(success);
    423 
    424     Tensor bt(b.tensor().dtype());
    425     success = bt.FromProto(b.tensor());
    426     DCHECK(success);
    427 
    428     TensorProto ap;
    429     at.AsProtoTensorContent(&ap);
    430 
    431     TensorProto bp;
    432     bt.AsProtoTensorContent(&bp);
    433 
    434     string a_str, b_str;
    435     SerializeToStringDeterministic(ap, &a_str);
    436     SerializeToStringDeterministic(bp, &b_str);
    437     return a_str == b_str;
    438   }
    439 
    440   // `func` field contains a nested AttrValue. Compare such AttrValues
    441   // recursively.
    442   if (a.has_func() != b.has_func()) {
    443     return false;
    444   } else if (a.has_func() && b.has_func()) {
    445     const NameAttrList& af = a.func();
    446     const NameAttrList& bf = b.func();
    447     if (af.name() != bf.name()) return false;
    448     std::unordered_map<string, AttrValue> am(af.attr().begin(),
    449                                              af.attr().end());
    450     for (const auto& bm_pair : bf.attr()) {
    451       const auto& iter = am.find(bm_pair.first);
    452       if (iter == am.end()) return false;
    453       if (!AreAttrValuesEqual(iter->second, bm_pair.second)) return false;
    454       am.erase(iter);
    455     }
    456     if (!am.empty()) return false;
    457     return true;
    458   }
    459 
    460   // All other fields in AttrValue have deterministic representations.
    461   // It is safe to compare their serialized strings.
    462   string a_str, b_str;
    463   SerializeToStringDeterministic(a, &a_str);
    464   SerializeToStringDeterministic(b, &b_str);
    465   return a_str == b_str;
    466 }
    467 
    468 uint64 AttrValueHash(const AttrValue& a) {
    469   if (a.has_tensor()) {
    470     // Deal with multiple representations by parsing TensorProto to
    471     // Tensor and serializing it back. This is slow, but current use case
    472     // don't need high efficiency.
    473     Tensor tensor(a.tensor().dtype());
    474     bool success = tensor.FromProto(a.tensor());
    475     DCHECK(success);
    476     TensorProto p;
    477     tensor.AsProtoTensorContent(&p);
    478     string s;
    479     SerializeToStringDeterministic(p, &s);
    480     return Hash64(s);
    481   }
    482   if (a.has_func()) {
    483     const NameAttrList& func = a.func();
    484     uint64 h = Hash64(func.name());
    485     std::map<string, AttrValue> map(func.attr().begin(), func.attr().end());
    486     for (const auto& pair : map) {
    487       h = Hash64(pair.first.data(), pair.first.size(), h);
    488       h = Hash64Combine(AttrValueHash(pair.second), h);
    489     }
    490     return h;
    491   }
    492 
    493   // If `a` is not a tensor or func, get a hash of serialized string.
    494   string s;
    495   SerializeToStringDeterministic(a, &s);
    496   return Hash64(s);
    497 }
    498 
    499 bool HasPlaceHolder(const AttrValue& val) {
    500   switch (val.value_case()) {
    501     case AttrValue::kList: {
    502       for (const NameAttrList& func : val.list().func()) {
    503         for (const auto& p : func.attr()) {
    504           if (HasPlaceHolder(p.second)) {
    505             return true;
    506           }
    507         }
    508       }
    509       break;
    510     }
    511     case AttrValue::kFunc:
    512       for (const auto& p : val.func().attr()) {
    513         if (HasPlaceHolder(p.second)) {
    514           return true;
    515         }
    516       }
    517       break;
    518     case AttrValue::kPlaceholder:
    519       return true;
    520     default:
    521       break;
    522   }
    523   return false;
    524 }
    525 
    526 bool SubstitutePlaceholders(const SubstituteFunc& substitute,
    527                             AttrValue* value) {
    528   switch (value->value_case()) {
    529     case AttrValue::kList: {
    530       for (NameAttrList& func : *value->mutable_list()->mutable_func()) {
    531         for (auto& p : *func.mutable_attr()) {
    532           if (!SubstitutePlaceholders(substitute, &p.second)) {
    533             return false;
    534           }
    535         }
    536       }
    537       break;
    538     }
    539     case AttrValue::kFunc:
    540       for (auto& p : *(value->mutable_func()->mutable_attr())) {
    541         if (!SubstitutePlaceholders(substitute, &p.second)) {
    542           return false;
    543         }
    544       }
    545       break;
    546     case AttrValue::kPlaceholder:
    547       return substitute(value->placeholder(), value);
    548     case AttrValue::VALUE_NOT_SET:
    549       return false;
    550     default:
    551       break;
    552   }
    553   return true;
    554 }
    555 
    556 }  // namespace tensorflow
    557