1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/python/client/tf_session_helper.h" 17 18 #include <cstring> 19 20 #include "tensorflow/c/c_api.h" 21 #include "tensorflow/c/c_api_internal.h" 22 #include "tensorflow/c/tf_status_helper.h" 23 #include "tensorflow/core/framework/allocator.h" 24 #include "tensorflow/core/framework/attr_value.pb.h" 25 #include "tensorflow/core/framework/attr_value_util.h" 26 #include "tensorflow/core/framework/log_memory.h" 27 #include "tensorflow/core/framework/op_kernel.h" 28 #include "tensorflow/core/graph/tensor_id.h" 29 #include "tensorflow/core/lib/core/coding.h" 30 #include "tensorflow/core/lib/strings/stringprintf.h" 31 #include "tensorflow/core/platform/types.h" 32 #include "tensorflow/core/util/equal_graph_def.h" 33 #include "tensorflow/python/lib/core/ndarray_tensor.h" 34 #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h" 35 #include "tensorflow/python/lib/core/safe_ptr.h" 36 37 namespace tensorflow { 38 39 namespace { 40 41 static const char* kFeedDictErrorMsg = 42 "feed_dict must be a dictionary mapping strings to NumPy arrays."; 43 } // end namespace 44 45 void TF_Run_wrapper_helper(TF_DeprecatedSession* session, const char* handle, 46 const TF_Buffer* run_options, PyObject* feed_dict, 47 const NameVector& output_names, 48 const NameVector& target_nodes, 49 TF_Status* out_status, PyObjectVector* out_values, 50 TF_Buffer* run_outputs) { 51 // 1. Convert the feed inputs to the appropriate form for TF_Run. 52 if (!PyDict_Check(feed_dict)) { 53 Set_TF_Status_from_Status(out_status, 54 errors::InvalidArgument(kFeedDictErrorMsg)); 55 return; 56 } 57 58 NameVector input_names; 59 std::vector<Safe_TF_TensorPtr> inputs_safe; // Used to delete tensors. 60 TF_TensorVector inputs_unsafe; // Used to contain the arg to TF_Run. 61 62 PyObject* key; 63 PyObject* value; 64 Py_ssize_t pos = 0; 65 int index = 0; 66 Status s; 67 68 while (PyDict_Next(feed_dict, &pos, &key, &value)) { 69 char* key_string = PyBytes_AsString(key); 70 if (!key_string) { 71 Set_TF_Status_from_Status(out_status, 72 errors::InvalidArgument(kFeedDictErrorMsg)); 73 return; 74 } 75 input_names.push_back(key_string); 76 77 inputs_safe.emplace_back(make_safe(static_cast<TF_Tensor*>(nullptr))); 78 s = PyArrayToTF_Tensor(value, &inputs_safe.back()); 79 if (!s.ok()) { 80 Set_TF_Status_from_Status(out_status, s); 81 return; 82 } 83 inputs_unsafe.push_back(inputs_safe.back().get()); 84 ++index; 85 } 86 87 // 2. Allocate a container for the output data. 88 TF_TensorVector outputs(output_names.size()); 89 90 // In case any tensors were leftover from previous runs we might as well clear 91 // them here. 92 ClearDecrefCache(); 93 94 // 3. Actually call TF_Run(). 95 Py_BEGIN_ALLOW_THREADS; 96 if (handle == nullptr) { 97 TF_Run(session, run_options, input_names.data(), inputs_unsafe.data(), 98 input_names.size(), const_cast<const char**>(output_names.data()), 99 outputs.data(), output_names.size(), 100 const_cast<const char**>(target_nodes.data()), target_nodes.size(), 101 run_outputs, out_status); 102 } else { 103 TF_PRun(session, handle, input_names.data(), inputs_unsafe.data(), 104 input_names.size(), const_cast<const char**>(output_names.data()), 105 outputs.data(), output_names.size(), 106 const_cast<const char**>(target_nodes.data()), target_nodes.size(), 107 out_status); 108 } 109 110 Py_END_ALLOW_THREADS; 111 112 // Decref any numpy arrays we are not using anymore. 113 ClearDecrefCache(); 114 115 if (TF_GetCode(out_status) != TF_OK) { 116 return; 117 } 118 119 // 4. We now own the fetched tensors, so set up a safe container to 120 // delete them when we exit this scope. 121 std::vector<Safe_TF_TensorPtr> tf_outputs_safe; 122 for (const auto& output : outputs) { 123 tf_outputs_safe.emplace_back(make_safe(output)); 124 } 125 126 // 5. Convert the fetched tensors into numpy ndarrays. Store them in a safe 127 // container so that we do not leak 128 std::vector<Safe_PyObjectPtr> py_outputs_safe; 129 for (size_t i = 0; i < output_names.size(); ++i) { 130 PyObject* py_array; 131 s = TF_TensorToPyArray(std::move(tf_outputs_safe[i]), &py_array); 132 if (!s.ok()) { 133 Set_TF_Status_from_Status(out_status, s); 134 return; 135 } 136 py_outputs_safe.emplace_back(make_safe(py_array)); 137 } 138 139 // 6. If we reach this point, we have successfully built a list of objects 140 // so we can release them from the safe container. 141 for (auto& output : py_outputs_safe) { 142 out_values->push_back(output.release()); 143 } 144 } 145 146 // Wrapper for TF_Run that converts the arguments to appropriate types. 147 // If *out_status is OK, the caller becomes the owner of the PyObjects 148 // in *out_values. 149 void TF_Run_wrapper(TF_DeprecatedSession* session, const TF_Buffer* run_options, 150 PyObject* feed_dict, const NameVector& output_names, 151 const NameVector& target_nodes, TF_Status* out_status, 152 PyObjectVector* out_values, TF_Buffer* run_outputs) { 153 TF_Run_wrapper_helper(session, nullptr, run_options, feed_dict, output_names, 154 target_nodes, out_status, out_values, run_outputs); 155 ClearDecrefCache(); 156 } 157 158 // Wrapper for TF_PRunSetup that converts the arguments to appropriate types. 159 // If *out_status is OK, the caller becomes the owner of *out_handle. 160 void TF_PRunSetup_wrapper(TF_DeprecatedSession* session, 161 const NameVector& input_names, 162 const NameVector& output_names, 163 const NameVector& target_nodes, TF_Status* out_status, 164 const char** out_handle) { 165 Py_BEGIN_ALLOW_THREADS; 166 TF_PRunSetup( 167 session, const_cast<const char**>(input_names.data()), input_names.size(), 168 const_cast<const char**>(output_names.data()), output_names.size(), 169 const_cast<const char**>(target_nodes.data()), target_nodes.size(), 170 out_handle, out_status); 171 Py_END_ALLOW_THREADS; 172 } 173 174 // Wrapper for TF_PRun that converts the arguments to appropriate types. 175 // If *out_status is OK, the caller becomes the owner of the PyObjects 176 // in *out_values. 177 void TF_PRun_wrapper(TF_DeprecatedSession* session, const char* handle, 178 PyObject* feed_dict, const NameVector& output_names, 179 TF_Status* out_status, PyObjectVector* out_values) { 180 TF_Run_wrapper_helper(session, handle, nullptr, feed_dict, output_names, 181 NameVector(), out_status, out_values, nullptr); 182 ClearDecrefCache(); 183 } 184 185 // Wrapper for TF_Reset that converts the string vectors to character arrays. 186 void TF_Reset_wrapper(const TF_SessionOptions* opt, 187 const NameVector& containers, TF_Status* out_status) { 188 TF_Reset(opt, const_cast<const char**>(containers.data()), containers.size(), 189 out_status); 190 } 191 192 void TF_SessionRun_wrapper_helper(TF_Session* session, const char* handle, 193 const TF_Buffer* run_options, 194 const std::vector<TF_Output>& inputs, 195 const std::vector<PyObject*>& input_ndarrays, 196 const std::vector<TF_Output>& outputs, 197 const std::vector<TF_Operation*>& targets, 198 TF_Buffer* run_metadata, 199 TF_Status* out_status, 200 std::vector<PyObject*>* py_outputs) { 201 DCHECK_EQ(inputs.size(), input_ndarrays.size()); 202 DCHECK(py_outputs != nullptr); 203 DCHECK(py_outputs->empty()); 204 Status s; 205 206 // Convert input ndarray PyObjects to TF_Tensors. We maintain a continuous 207 // array of TF_Tensor*s as well as scoped containers to make sure they're 208 // cleaned up properly. 209 // 210 // Memory management: 211 // PyArrayToTF_Tensor() creates a new ndarray PyObject from the input 212 // ndarray. We manage the new ndarray's lifetime in order to keep the 213 // underlying data buffer alive (the new ndarray also guarantees a contiguous 214 // data buffer). The new ndarray's data buffer is used to create the 215 // corresponding TF_Tensor. The TF_Tensor's deallocator will queue the new 216 // ndarray to be decref'd by the next ClearDecrefCache() call (we can't call 217 // Py_DECREF in the deallocator directly because the GIL must be held). 218 // 219 // Note that TF_Tensor may directly delegate its data and deallocator to a 220 // TensorBuffer, which may outlive the TF_Tensor (e.g. if the tensor gets 221 // queued or assigned to a variable). 222 TF_TensorVector input_vals; 223 std::vector<Safe_TF_TensorPtr> input_vals_safe; 224 for (PyObject* ndarray : input_ndarrays) { 225 input_vals_safe.emplace_back(make_safe(static_cast<TF_Tensor*>(nullptr))); 226 s = PyArrayToTF_Tensor(ndarray, &input_vals_safe.back()); 227 if (!s.ok()) { 228 Set_TF_Status_from_Status(out_status, s); 229 return; 230 } 231 input_vals.push_back(input_vals_safe.back().get()); 232 } 233 234 // Allocate space for output TF_Tensor*s 235 TF_TensorVector output_vals(outputs.size()); 236 237 // Clear up any unused memory leftover from previous runs 238 ClearDecrefCache(); 239 240 // Call TF_SessionRun() (and release GIL during execution) 241 Py_BEGIN_ALLOW_THREADS; 242 if (handle == nullptr) { 243 TF_SessionRun(session, run_options, inputs.data(), input_vals.data(), 244 inputs.size(), outputs.data(), output_vals.data(), 245 outputs.size(), targets.data(), targets.size(), run_metadata, 246 out_status); 247 } else { 248 TF_SessionPRun(session, handle, inputs.data(), input_vals.data(), 249 inputs.size(), outputs.data(), output_vals.data(), 250 outputs.size(), targets.data(), targets.size(), out_status); 251 } 252 Py_END_ALLOW_THREADS; 253 254 // Create scoped containers for output tensors 255 std::vector<Safe_TF_TensorPtr> output_vals_safe; 256 for (TF_Tensor* output : output_vals) { 257 output_vals_safe.emplace_back(make_safe(output)); 258 } 259 260 // Convert outputs to ndarrays (in scoped containers) 261 std::vector<Safe_PyObjectPtr> py_outputs_safe; 262 for (size_t i = 0; i < outputs.size(); ++i) { 263 PyObject* py_array; 264 s = TF_TensorToPyArray(std::move(output_vals_safe[i]), &py_array); 265 if (!s.ok()) { 266 Set_TF_Status_from_Status(out_status, s); 267 return; 268 } 269 py_outputs_safe.emplace_back(make_safe(py_array)); 270 } 271 272 // If we reach this point, we have successfully built a list of objects so we 273 // can release them from the safe container into the return vector. 274 for (size_t i = 0; i < outputs.size(); ++i) { 275 py_outputs->push_back(py_outputs_safe[i].release()); 276 } 277 } 278 279 void TF_SessionRun_wrapper(TF_Session* session, const TF_Buffer* run_options, 280 const std::vector<TF_Output>& inputs, 281 const std::vector<PyObject*>& input_ndarrays, 282 const std::vector<TF_Output>& outputs, 283 const std::vector<TF_Operation*>& targets, 284 TF_Buffer* run_metadata, TF_Status* out_status, 285 std::vector<PyObject*>* py_outputs) { 286 TF_SessionRun_wrapper_helper(session, nullptr, run_options, inputs, 287 input_ndarrays, outputs, targets, run_metadata, 288 out_status, py_outputs); 289 // Release any unused ndarray references (see memory management comment in 290 // TF_SessionRun_wrapper_helper) 291 ClearDecrefCache(); 292 } 293 294 string EqualGraphDefWrapper(const string& actual, const string& expected) { 295 GraphDef actual_def; 296 if (!actual_def.ParseFromString(actual)) { 297 return "actual is not a valid serialized GraphDef"; 298 } 299 GraphDef expected_def; 300 if (!expected_def.ParseFromString(expected)) { 301 return "expected is not a valid serialized GraphDef"; 302 } 303 string diff; 304 return EqualGraphDef(actual_def, expected_def, &diff) ? "" : diff; 305 } 306 307 string EqualAttrValueWrapper(const string& actual, const string& expected) { 308 AttrValue actual_attr_value; 309 if (!actual_attr_value.ParseFromString(actual)) { 310 return "actual is not a valid serialized AttrValue"; 311 } 312 313 AttrValue expected_attr_value; 314 if (!expected_attr_value.ParseFromString(expected)) { 315 return "expected is not a valid serialized AttrValue"; 316 } 317 318 string diff; 319 if (!AreAttrValuesEqual(actual_attr_value, expected_attr_value)) { 320 diff = strings::Printf( 321 "Actual AttrValue %s does not match Expected AttrValue %s.", 322 SummarizeAttrValue(actual_attr_value).c_str(), 323 SummarizeAttrValue(expected_attr_value).c_str()); 324 } 325 return diff; 326 } 327 328 // Return value set to 6 inlined elements so it fits in a 64-byte cache line. 329 tensorflow::gtl::InlinedVector<int64_t, 6> TF_GraphGetTensorShapeHelper( 330 TF_Graph* graph, TF_Output output, TF_Status* out_status, 331 bool* unknown_shape) { 332 // Allocate a single variable for holding the result for RVO. 333 tensorflow::gtl::InlinedVector<int64_t, 6> result; 334 *unknown_shape = false; 335 int num_dims = TF_GraphGetTensorNumDims(graph, output, out_status); 336 if (TF_GetCode(out_status) != TF_OK) { 337 return result; 338 } 339 // If shape is unknown, set boolean and return. 340 if (num_dims == -1) { 341 *unknown_shape = true; 342 return result; 343 } 344 345 // If shape is a scalar, avoid another C call and just return {}. 346 if (num_dims == 0) { 347 return result; 348 } 349 350 result.resize(num_dims); 351 TF_GraphGetTensorShape(graph, output, result.data(), num_dims, out_status); 352 return result; 353 } 354 355 void TF_SessionPRunSetup_wrapper(TF_Session* session, 356 const std::vector<TF_Output>& inputs, 357 const std::vector<TF_Output>& outputs, 358 const std::vector<TF_Operation*>& targets, 359 const char** out_handle, 360 TF_Status* out_status) { 361 // Call TF_SessionPRunSetup() (and release GIL during execution) 362 Py_BEGIN_ALLOW_THREADS; 363 TF_SessionPRunSetup(session, inputs.data(), inputs.size(), outputs.data(), 364 outputs.size(), targets.data(), targets.size(), 365 out_handle, out_status); 366 Py_END_ALLOW_THREADS; 367 } 368 369 void TF_SessionPRun_wrapper(TF_Session* session, const char* handle, 370 const std::vector<TF_Output>& inputs, 371 const std::vector<PyObject*>& input_ndarrays, 372 const std::vector<TF_Output>& outputs, 373 TF_Status* out_status, 374 std::vector<PyObject*>* py_outputs) { 375 const std::vector<TF_Operation*> targets; 376 TF_SessionRun_wrapper_helper(session, handle, 377 nullptr, // run_options 378 inputs, input_ndarrays, outputs, targets, 379 nullptr, // run_metadata 380 out_status, py_outputs); 381 // Release any unused ndarray references (see memory management comment in 382 // TF_SessionRun_wrapper_helper) 383 ClearDecrefCache(); 384 } 385 386 std::vector<TF_Output> GetOperationInputs(TF_Operation* oper) { 387 int num_inputs = TF_OperationNumInputs(oper); 388 std::vector<TF_Output> inputs(num_inputs); 389 for (int i = 0; i < num_inputs; ++i) { 390 inputs[i] = TF_OperationInput({oper, i}); 391 } 392 return inputs; 393 } 394 395 std::vector<TF_Operation*> TF_OperationGetControlInputs_wrapper( 396 TF_Operation* oper) { 397 std::vector<TF_Operation*> control_inputs(TF_OperationNumControlInputs(oper)); 398 TF_OperationGetControlInputs(oper, control_inputs.data(), 399 control_inputs.size()); 400 return control_inputs; 401 } 402 403 std::vector<const char*> TF_OperationOutputConsumers_wrapper( 404 TF_Output oper_out) { 405 int num_consumers = TF_OperationOutputNumConsumers(oper_out); 406 std::vector<TF_Input> consumers(num_consumers); 407 TF_OperationOutputConsumers(oper_out, consumers.data(), num_consumers); 408 409 std::vector<const char*> consumer_names(num_consumers); 410 for (int i = 0; i < num_consumers; ++i) { 411 consumer_names[i] = TF_OperationName(consumers[i].oper); 412 } 413 return consumer_names; 414 } 415 416 TF_Function* TF_GraphToFunction_wrapper( 417 const TF_Graph* fn_body, const char* fn_name, bool append_hash_to_fn_name, 418 const std::vector<TF_Operation*>* opers, 419 const std::vector<TF_Output>& inputs, const std::vector<TF_Output>& outputs, 420 const NameVector& output_names, const TF_FunctionOptions* opts, 421 const char* description, TF_Status* out_status) { 422 if (!output_names.empty() && output_names.size() != outputs.size()) { 423 Set_TF_Status_from_Status( 424 out_status, 425 errors::InvalidArgument( 426 "output names must be either empty or equal in size to outputs. ", 427 "output names size = ", output_names.size(), 428 " outputs size = ", outputs.size())); 429 return nullptr; 430 } 431 432 int nopers = -1; 433 const TF_Operation* const* opers_array = nullptr; 434 if (opers != nullptr) { 435 nopers = opers->size(); 436 opers_array = opers->data(); 437 } 438 439 const char** output_names_ptr = 440 output_names.empty() ? nullptr 441 : const_cast<const char**>(output_names.data()); 442 443 return TF_GraphToFunction(fn_body, fn_name, append_hash_to_fn_name, nopers, 444 opers_array, inputs.size(), inputs.data(), 445 outputs.size(), outputs.data(), output_names_ptr, 446 opts, description, out_status); 447 } 448 449 void TF_GraphSetOutputHandleShapesAndTypes_wrapper( 450 TF_Graph* graph, TF_Output output, 451 const std::vector<std::vector<int64_t>>& shapes, 452 const std::vector<int>& ranks, const std::vector<TF_DataType>& types, 453 TF_Status* status) { 454 std::vector<const int64_t*> shapes_pointers(shapes.size()); 455 for (int i = 0; i < shapes.size(); ++i) { 456 shapes_pointers[i] = ranks[i] <= 0 ? nullptr : &shapes[i][0]; 457 } 458 TF_GraphSetOutputHandleShapesAndTypes(graph, output, shapes.size(), 459 shapes_pointers.data(), ranks.data(), 460 types.data(), status); 461 } 462 463 void TF_GraphSetTensorShape_wrapper(TF_Graph* graph, TF_Output output, 464 const std::vector<int64_t>& dims, 465 bool unknown_shape, TF_Status* status) { 466 if (unknown_shape) { 467 TF_GraphSetTensorShape(graph, output, nullptr, -1, status); 468 return; 469 } 470 TF_GraphSetTensorShape(graph, output, dims.data(), dims.size(), status); 471 } 472 473 std::vector<int64_t> TF_GraphGetTensorShape_wrapper(TF_Graph* graph, 474 TF_Output output, 475 int num_dims, 476 TF_Status* status) { 477 std::vector<int64_t> dims(num_dims); 478 TF_GraphGetTensorShape(graph, output, dims.data(), num_dims, status); 479 return dims; 480 } 481 482 std::vector<string> TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper( 483 TF_ImportGraphDefResults* results) { 484 int num_missing_unused_input_mappings; 485 const char** src_names; 486 int* src_indexes; 487 TF_ImportGraphDefResultsMissingUnusedInputMappings( 488 results, &num_missing_unused_input_mappings, &src_names, &src_indexes); 489 std::vector<string> input_strs(num_missing_unused_input_mappings); 490 for (int i = 0; i < num_missing_unused_input_mappings; ++i) { 491 input_strs[i] = TensorId(src_names[i], src_indexes[i]).ToString(); 492 } 493 return input_strs; 494 } 495 496 } // namespace tensorflow 497