Home | History | Annotate | Download | only in compat
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/ops/compat/op_compatibility_lib.h"
     17 
     18 #include <stdio.h>
     19 #include "tensorflow/core/framework/op.h"
     20 #include "tensorflow/core/framework/op_def_util.h"
     21 #include "tensorflow/core/lib/core/errors.h"
     22 #include "tensorflow/core/lib/core/status.h"
     23 #include "tensorflow/core/lib/io/path.h"
     24 #include "tensorflow/core/lib/strings/str_util.h"
     25 #include "tensorflow/core/lib/strings/strcat.h"
     26 #include "tensorflow/core/platform/protobuf.h"
     27 
     28 namespace tensorflow {
     29 
     30 static string OpsHistoryFile(const string& ops_prefix,
     31                              const string& history_version) {
     32   return io::JoinPath(ops_prefix, strings::StrCat("compat/ops_history.",
     33                                                   history_version, ".pbtxt"));
     34 }
     35 
     36 OpCompatibilityLib::OpCompatibilityLib(const string& ops_prefix,
     37                                        const string& history_version,
     38                                        const std::set<string>* stable_ops)
     39     : ops_file_(io::JoinPath(ops_prefix, "ops.pbtxt")),
     40       op_history_file_(OpsHistoryFile(ops_prefix, history_version)),
     41       stable_ops_(stable_ops) {
     42   // Get the sorted list of all registered OpDefs.
     43   printf("Getting all registered ops...\n");
     44   OpRegistry::Global()->Export(false, &op_list_);
     45 }
     46 
     47 Status OpCompatibilityLib::ValidateCompatible(Env* env, int* changed_ops,
     48                                               int* added_ops,
     49                                               OpList* out_op_history) {
     50   *changed_ops = 0;
     51   *added_ops = 0;
     52 
     53   // Strip docs out of op_list_.
     54   RemoveDescriptionsFromOpList(&op_list_);
     55 
     56   if (stable_ops_ != nullptr) {
     57     printf("Verifying no stable ops have been removed...\n");
     58     std::vector<string> removed;
     59     // We rely on stable_ops_ and op_list_ being in sorted order.
     60     auto iter = stable_ops_->begin();
     61     for (int cur = 0; iter != stable_ops_->end() && cur < op_list_.op_size();
     62          ++cur) {
     63       const string& op_name = op_list_.op(cur).name();
     64       while (op_name > *iter) {
     65         removed.push_back(*iter);
     66         ++iter;
     67       }
     68       if (op_name == *iter) {
     69         ++iter;
     70       }
     71     }
     72     for (; iter != stable_ops_->end(); ++iter) {
     73       removed.push_back(*iter);
     74     }
     75     if (!removed.empty()) {
     76       return errors::InvalidArgument("Error, stable op(s) removed: ",
     77                                      str_util::Join(removed, ", "));
     78     }
     79   }
     80 
     81   OpList in_op_history;
     82   {  // Read op history.
     83     printf("Reading op history from %s...\n", op_history_file_.c_str());
     84     string op_history_str;
     85     TF_RETURN_IF_ERROR(
     86         ReadFileToString(env, op_history_file_, &op_history_str));
     87     protobuf::TextFormat::ParseFromString(op_history_str, &in_op_history);
     88   }
     89 
     90   int cur = 0;
     91   int start = 0;
     92 
     93   printf("Verifying updates are compatible...\n");
     94   // Note: Op history is in (alphabetical, oldest-first) order.
     95   while (cur < op_list_.op_size() && start < in_op_history.op_size()) {
     96     const string& op_name = op_list_.op(cur).name();
     97     if (stable_ops_ != nullptr && stable_ops_->count(op_name) == 0) {
     98       // Ignore unstable op.
     99       for (++cur; cur < op_list_.op_size(); ++cur) {
    100         if (op_list_.op(cur).name() != op_name) break;
    101       }
    102     } else if (op_name < in_op_history.op(start).name()) {
    103       // New op: add it.
    104       if (out_op_history != nullptr) {
    105         *out_op_history->add_op() = op_list_.op(cur);
    106       }
    107       ++*added_ops;
    108       ++cur;
    109     } else if (op_name > in_op_history.op(start).name()) {
    110       if (stable_ops_ != nullptr) {
    111         // Okay to remove ops from the history that have been made unstable.
    112         for (++start; start < in_op_history.op_size(); ++start) {
    113           if (op_name <= in_op_history.op(start).name()) break;
    114         }
    115       } else {
    116         // Op removed: error.
    117         return errors::InvalidArgument("Error, removed op: ",
    118                                        SummarizeOpDef(in_op_history.op(start)));
    119       }
    120     } else {
    121       // Op match.
    122 
    123       // Find all historical version of this op.
    124       int end = start + 1;
    125       for (; end < in_op_history.op_size(); ++end) {
    126         if (in_op_history.op(end).name() != op_name) break;
    127       }
    128 
    129       if (out_op_history != nullptr) {
    130         // Copy from in_op_history to *out_op_history.
    131         for (int i = start; i < end; ++i) {
    132           *out_op_history->add_op() = in_op_history.op(i);
    133         }
    134       }
    135 
    136       // Is the last op in the history the same as the current op?
    137       // Compare using their serialized representations.
    138       string history_str, cur_str;
    139       in_op_history.op(end - 1).SerializeToString(&history_str);
    140       op_list_.op(cur).SerializeToString(&cur_str);
    141 
    142       if (history_str != cur_str) {
    143         // Op changed, verify the change is compatible.
    144         for (int i = start; i < end; ++i) {
    145           TF_RETURN_IF_ERROR(
    146               OpDefCompatible(in_op_history.op(i), op_list_.op(cur)));
    147         }
    148 
    149         // Verify default value of attrs has not been added/removed/modified
    150         // as compared to only the last historical version.
    151         TF_RETURN_IF_ERROR(OpDefAttrDefaultsUnchanged(in_op_history.op(end - 1),
    152                                                       op_list_.op(cur)));
    153 
    154         // Check that attrs missing from in_op_history.op(start) don't
    155         // change their defaults.
    156         if (start < end - 1) {
    157           TF_RETURN_IF_ERROR(OpDefAddedDefaultsUnchanged(
    158               in_op_history.op(start), in_op_history.op(end - 1),
    159               op_list_.op(cur)));
    160         }
    161 
    162         // Compatible! Add changed op to the end of the history.
    163         if (out_op_history != nullptr) {
    164           *out_op_history->add_op() = op_list_.op(cur);
    165         }
    166         ++*changed_ops;
    167       }
    168 
    169       // Advance past this op.
    170       start = end;
    171       ++cur;
    172     }
    173   }
    174 
    175   // Error if missing ops.
    176   if (stable_ops_ == nullptr && start < in_op_history.op_size()) {
    177     return errors::InvalidArgument("Error, removed op: ",
    178                                    SummarizeOpDef(in_op_history.op(start)));
    179   }
    180 
    181   // Add remaining new ops.
    182   for (; cur < op_list_.op_size(); ++cur) {
    183     const string& op_name = op_list_.op(cur).name();
    184     if (stable_ops_ != nullptr && stable_ops_->count(op_name) == 0) {
    185       // Ignore unstable op.
    186     } else {
    187       if (out_op_history) {
    188         *out_op_history->add_op() = op_list_.op(cur);
    189       }
    190       ++*added_ops;
    191     }
    192   }
    193 
    194   return Status::OK();
    195 }
    196 
    197 }  // namespace tensorflow
    198