1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/ops/compat/op_compatibility_lib.h" 17 18 #include <stdio.h> 19 #include "tensorflow/core/framework/op.h" 20 #include "tensorflow/core/framework/op_def_util.h" 21 #include "tensorflow/core/lib/core/errors.h" 22 #include "tensorflow/core/lib/core/status.h" 23 #include "tensorflow/core/lib/io/path.h" 24 #include "tensorflow/core/lib/strings/str_util.h" 25 #include "tensorflow/core/lib/strings/strcat.h" 26 #include "tensorflow/core/platform/protobuf.h" 27 28 namespace tensorflow { 29 30 static string OpsHistoryFile(const string& ops_prefix, 31 const string& history_version) { 32 return io::JoinPath(ops_prefix, strings::StrCat("compat/ops_history.", 33 history_version, ".pbtxt")); 34 } 35 36 OpCompatibilityLib::OpCompatibilityLib(const string& ops_prefix, 37 const string& history_version, 38 const std::set<string>* stable_ops) 39 : ops_file_(io::JoinPath(ops_prefix, "ops.pbtxt")), 40 op_history_file_(OpsHistoryFile(ops_prefix, history_version)), 41 stable_ops_(stable_ops) { 42 // Get the sorted list of all registered OpDefs. 43 printf("Getting all registered ops...\n"); 44 OpRegistry::Global()->Export(false, &op_list_); 45 } 46 47 Status OpCompatibilityLib::ValidateCompatible(Env* env, int* changed_ops, 48 int* added_ops, 49 OpList* out_op_history) { 50 *changed_ops = 0; 51 *added_ops = 0; 52 53 // Strip docs out of op_list_. 54 RemoveDescriptionsFromOpList(&op_list_); 55 56 if (stable_ops_ != nullptr) { 57 printf("Verifying no stable ops have been removed...\n"); 58 std::vector<string> removed; 59 // We rely on stable_ops_ and op_list_ being in sorted order. 60 auto iter = stable_ops_->begin(); 61 for (int cur = 0; iter != stable_ops_->end() && cur < op_list_.op_size(); 62 ++cur) { 63 const string& op_name = op_list_.op(cur).name(); 64 while (op_name > *iter) { 65 removed.push_back(*iter); 66 ++iter; 67 } 68 if (op_name == *iter) { 69 ++iter; 70 } 71 } 72 for (; iter != stable_ops_->end(); ++iter) { 73 removed.push_back(*iter); 74 } 75 if (!removed.empty()) { 76 return errors::InvalidArgument("Error, stable op(s) removed: ", 77 str_util::Join(removed, ", ")); 78 } 79 } 80 81 OpList in_op_history; 82 { // Read op history. 83 printf("Reading op history from %s...\n", op_history_file_.c_str()); 84 string op_history_str; 85 TF_RETURN_IF_ERROR( 86 ReadFileToString(env, op_history_file_, &op_history_str)); 87 protobuf::TextFormat::ParseFromString(op_history_str, &in_op_history); 88 } 89 90 int cur = 0; 91 int start = 0; 92 93 printf("Verifying updates are compatible...\n"); 94 // Note: Op history is in (alphabetical, oldest-first) order. 95 while (cur < op_list_.op_size() && start < in_op_history.op_size()) { 96 const string& op_name = op_list_.op(cur).name(); 97 if (stable_ops_ != nullptr && stable_ops_->count(op_name) == 0) { 98 // Ignore unstable op. 99 for (++cur; cur < op_list_.op_size(); ++cur) { 100 if (op_list_.op(cur).name() != op_name) break; 101 } 102 } else if (op_name < in_op_history.op(start).name()) { 103 // New op: add it. 104 if (out_op_history != nullptr) { 105 *out_op_history->add_op() = op_list_.op(cur); 106 } 107 ++*added_ops; 108 ++cur; 109 } else if (op_name > in_op_history.op(start).name()) { 110 if (stable_ops_ != nullptr) { 111 // Okay to remove ops from the history that have been made unstable. 112 for (++start; start < in_op_history.op_size(); ++start) { 113 if (op_name <= in_op_history.op(start).name()) break; 114 } 115 } else { 116 // Op removed: error. 117 return errors::InvalidArgument("Error, removed op: ", 118 SummarizeOpDef(in_op_history.op(start))); 119 } 120 } else { 121 // Op match. 122 123 // Find all historical version of this op. 124 int end = start + 1; 125 for (; end < in_op_history.op_size(); ++end) { 126 if (in_op_history.op(end).name() != op_name) break; 127 } 128 129 if (out_op_history != nullptr) { 130 // Copy from in_op_history to *out_op_history. 131 for (int i = start; i < end; ++i) { 132 *out_op_history->add_op() = in_op_history.op(i); 133 } 134 } 135 136 // Is the last op in the history the same as the current op? 137 // Compare using their serialized representations. 138 string history_str, cur_str; 139 in_op_history.op(end - 1).SerializeToString(&history_str); 140 op_list_.op(cur).SerializeToString(&cur_str); 141 142 if (history_str != cur_str) { 143 // Op changed, verify the change is compatible. 144 for (int i = start; i < end; ++i) { 145 TF_RETURN_IF_ERROR( 146 OpDefCompatible(in_op_history.op(i), op_list_.op(cur))); 147 } 148 149 // Verify default value of attrs has not been added/removed/modified 150 // as compared to only the last historical version. 151 TF_RETURN_IF_ERROR(OpDefAttrDefaultsUnchanged(in_op_history.op(end - 1), 152 op_list_.op(cur))); 153 154 // Check that attrs missing from in_op_history.op(start) don't 155 // change their defaults. 156 if (start < end - 1) { 157 TF_RETURN_IF_ERROR(OpDefAddedDefaultsUnchanged( 158 in_op_history.op(start), in_op_history.op(end - 1), 159 op_list_.op(cur))); 160 } 161 162 // Compatible! Add changed op to the end of the history. 163 if (out_op_history != nullptr) { 164 *out_op_history->add_op() = op_list_.op(cur); 165 } 166 ++*changed_ops; 167 } 168 169 // Advance past this op. 170 start = end; 171 ++cur; 172 } 173 } 174 175 // Error if missing ops. 176 if (stable_ops_ == nullptr && start < in_op_history.op_size()) { 177 return errors::InvalidArgument("Error, removed op: ", 178 SummarizeOpDef(in_op_history.op(start))); 179 } 180 181 // Add remaining new ops. 182 for (; cur < op_list_.op_size(); ++cur) { 183 const string& op_name = op_list_.op(cur).name(); 184 if (stable_ops_ != nullptr && stable_ops_->count(op_name) == 0) { 185 // Ignore unstable op. 186 } else { 187 if (out_op_history) { 188 *out_op_history->add_op() = op_list_.op(cur); 189 } 190 ++*added_ops; 191 } 192 } 193 194 return Status::OK(); 195 } 196 197 } // namespace tensorflow 198