1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/io_ops.cc. 17 18 #include <string> 19 #include <vector> 20 21 #include "tensorflow/core/framework/op_kernel.h" 22 #include "tensorflow/core/framework/tensor.h" 23 #include "tensorflow/core/framework/types.h" 24 #include "tensorflow/core/framework/types.pb.h" 25 #include "tensorflow/core/kernels/bounds_check.h" 26 #include "tensorflow/core/kernels/save_restore_tensor.h" 27 #include "tensorflow/core/lib/core/status.h" 28 #include "tensorflow/core/lib/io/path.h" 29 #include "tensorflow/core/platform/env.h" 30 #include "tensorflow/core/platform/logging.h" 31 #include "tensorflow/core/platform/types.h" 32 #include "tensorflow/core/util/saved_tensor_slice_util.h" 33 #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h" 34 #include "tensorflow/core/util/tensor_slice_reader.h" 35 36 namespace tensorflow { 37 38 namespace { 39 40 // Shared validations of the inputs to the SaveV2 and RestoreV2 ops. 41 void ValidateInputs(bool is_save_op, OpKernelContext* context, 42 const Tensor& prefix, const Tensor& tensor_names, 43 const Tensor& shape_and_slices) { 44 const int kFixedInputs = 3; // Prefix, tensor names, shape_and_slices. 45 const int num_tensors = static_cast<int>(tensor_names.NumElements()); 46 OP_REQUIRES( 47 context, prefix.NumElements() == 1, 48 errors::InvalidArgument("Input prefix should have a single element, got ", 49 prefix.NumElements(), " instead.")); 50 OP_REQUIRES(context, 51 TensorShapeUtils::IsVector(tensor_names.shape()) && 52 TensorShapeUtils::IsVector(shape_and_slices.shape()), 53 errors::InvalidArgument( 54 "Input tensor_names and shape_and_slices " 55 "should be an 1-D tensors, got ", 56 tensor_names.shape().DebugString(), " and ", 57 shape_and_slices.shape().DebugString(), " instead.")); 58 OP_REQUIRES(context, 59 tensor_names.NumElements() == shape_and_slices.NumElements(), 60 errors::InvalidArgument("tensor_names and shape_and_slices " 61 "have different number of elements: ", 62 tensor_names.NumElements(), " vs. ", 63 shape_and_slices.NumElements())); 64 OP_REQUIRES(context, 65 FastBoundsCheck(tensor_names.NumElements() + kFixedInputs, 66 std::numeric_limits<int>::max()), 67 errors::InvalidArgument("Too many inputs to the op")); 68 OP_REQUIRES( 69 context, shape_and_slices.NumElements() == num_tensors, 70 errors::InvalidArgument("Expected ", num_tensors, 71 " elements in shapes_and_slices, but got ", 72 context->input(2).NumElements())); 73 if (is_save_op) { 74 OP_REQUIRES(context, context->num_inputs() == num_tensors + kFixedInputs, 75 errors::InvalidArgument( 76 "Got ", num_tensors, " tensor names but ", 77 context->num_inputs() - kFixedInputs, " tensors.")); 78 OP_REQUIRES(context, context->num_inputs() == num_tensors + kFixedInputs, 79 errors::InvalidArgument( 80 "Expected a total of ", num_tensors + kFixedInputs, 81 " inputs as input #1 (which is a string " 82 "tensor of saved names) contains ", 83 num_tensors, " names, but received ", context->num_inputs(), 84 " inputs")); 85 } 86 } 87 88 } // namespace 89 90 // Saves a list of named tensors using the tensor bundle library. 91 class SaveV2 : public OpKernel { 92 public: 93 explicit SaveV2(OpKernelConstruction* context) : OpKernel(context) {} 94 95 void Compute(OpKernelContext* context) override { 96 const Tensor& prefix = context->input(0); 97 const Tensor& tensor_names = context->input(1); 98 const Tensor& shape_and_slices = context->input(2); 99 ValidateInputs(true /* is save op */, context, prefix, tensor_names, 100 shape_and_slices); 101 102 const int kFixedInputs = 3; // Prefix, tensor names, shape_and_slices. 103 const int num_tensors = static_cast<int>(tensor_names.NumElements()); 104 const string& prefix_string = prefix.scalar<string>()(); 105 const auto& tensor_names_flat = tensor_names.flat<string>(); 106 const auto& shape_and_slices_flat = shape_and_slices.flat<string>(); 107 108 BundleWriter writer(Env::Default(), prefix_string); 109 OP_REQUIRES_OK(context, writer.status()); 110 VLOG(1) << "BundleWriter, prefix_string: " << prefix_string; 111 112 for (int i = 0; i < num_tensors; ++i) { 113 const string& tensor_name = tensor_names_flat(i); 114 const Tensor& tensor = context->input(i + kFixedInputs); 115 116 if (!shape_and_slices_flat(i).empty()) { 117 const string& shape_spec = shape_and_slices_flat(i); 118 TensorShape shape; 119 TensorSlice slice(tensor.dims()); 120 TensorShape slice_shape; 121 122 OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice( 123 shape_spec, &shape, &slice, &slice_shape)); 124 OP_REQUIRES(context, slice_shape.IsSameSize(tensor.shape()), 125 errors::InvalidArgument("Slice in shape_and_slice " 126 "specification does not match the " 127 "shape of the tensor to save: ", 128 shape_spec, ", tensor: ", 129 tensor.shape().DebugString())); 130 131 OP_REQUIRES_OK(context, 132 writer.AddSlice(tensor_name, shape, slice, tensor)); 133 } else { 134 OP_REQUIRES_OK(context, writer.Add(tensor_name, tensor)); 135 } 136 } 137 OP_REQUIRES_OK(context, writer.Finish()); 138 } 139 }; 140 REGISTER_KERNEL_BUILDER(Name("SaveV2").Device(DEVICE_CPU), SaveV2); 141 142 // Restores a list of named tensors from a tensor bundle (V2 checkpoint format). 143 class RestoreV2 : public OpKernel { 144 public: 145 explicit RestoreV2(OpKernelConstruction* context) : OpKernel(context) { 146 OP_REQUIRES_OK(context, context->GetAttr("dtypes", &dtypes_)); 147 } 148 149 void Compute(OpKernelContext* context) override { 150 const Tensor& prefix = context->input(0); 151 const Tensor& tensor_names = context->input(1); 152 const Tensor& shape_and_slices = context->input(2); 153 OP_REQUIRES(context, tensor_names.NumElements() == dtypes_.size(), 154 errors::InvalidArgument("Got ", tensor_names.NumElements(), 155 " tensor names, but ", dtypes_.size(), 156 " expected dtypes.")); 157 ValidateInputs(false /* not save op */, context, prefix, tensor_names, 158 shape_and_slices); 159 160 const string& prefix_string = prefix.scalar<string>()(); 161 162 // Intention: we plan to use the RestoreV2 op as a backward-compatible 163 // reader as we upgrade to the V2 format. This allows transparent upgrade. 164 // We here attempt to read a V1 checkpoint, if "prefix_string" does not 165 // refer to a V2 checkpoint. 166 Env* env = Env::Default(); 167 std::vector<string> paths; 168 if (!env->GetMatchingPaths(MetaFilename(prefix_string), &paths).ok() || 169 paths.empty()) { 170 // Cannot find V2's metadata file, so "prefix_string" does not point to a 171 // V2 checkpoint. Invokes the V1 read path instead. 172 for (size_t i = 0; i < tensor_names.NumElements(); ++i) { 173 RestoreTensor(context, &checkpoint::OpenTableTensorSliceReader, 174 /* preferred_shard */ -1, /* restore_slice */ true, 175 /* restore_index */ i); 176 if (!context->status().ok()) { 177 return; 178 } 179 } 180 return; 181 } 182 // If found, invokes the V2 reader. 183 OP_REQUIRES_OK(context, RestoreTensorsV2(context, prefix, tensor_names, 184 shape_and_slices, dtypes_)); 185 } 186 187 private: 188 // Expected dtypes of the to-restore tensors. 189 std::vector<DataType> dtypes_; 190 }; 191 REGISTER_KERNEL_BUILDER(Name("RestoreV2").Device(DEVICE_CPU), RestoreV2); 192 193 // The final step in saving sharded V2 checkpoints: merges metadata files. 194 class MergeV2Checkpoints : public OpKernel { 195 public: 196 explicit MergeV2Checkpoints(OpKernelConstruction* context) 197 : OpKernel(context) { 198 OP_REQUIRES_OK(context, 199 context->GetAttr("delete_old_dirs", &delete_old_dirs_)); 200 } 201 202 void Compute(OpKernelContext* context) override { 203 const Tensor& checkpoint_prefixes = context->input(0); 204 const Tensor& destination_prefix = context->input(1); 205 OP_REQUIRES(context, 206 TensorShapeUtils::IsVector(checkpoint_prefixes.shape()), 207 errors::InvalidArgument( 208 "Input checkpoint_prefixes should be an 1-D tensor, got ", 209 checkpoint_prefixes.shape().DebugString(), " instead.")); 210 OP_REQUIRES(context, TensorShapeUtils::IsScalar(destination_prefix.shape()), 211 errors::InvalidArgument( 212 "Input destination_prefix should be a scalar tensor, got ", 213 destination_prefix.shape().DebugString(), " instead.")); 214 215 const gtl::ArraySlice<string> input_prefixes = 216 gtl::ArraySlice<string>(checkpoint_prefixes.flat<string>()); 217 Env* env = Env::Default(); 218 const string& merged_prefix = destination_prefix.scalar<string>()(); 219 OP_REQUIRES_OK( 220 context, tensorflow::MergeBundles(env, input_prefixes, merged_prefix)); 221 222 if (delete_old_dirs_) { 223 const string& merged_dir = io::Dirname(merged_prefix).ToString(); 224 for (const string& input_prefix : input_prefixes) { 225 const string& dirname = io::Dirname(input_prefix).ToString(); 226 if (dirname == merged_dir) continue; 227 Status status = env->DeleteDir(dirname); 228 // For sharded save, only the first delete will go through and all 229 // others will hit NotFound. Use vlog to be less verbose. 230 if (!status.ok()) VLOG(1) << status; 231 } 232 } 233 } 234 235 private: 236 // On merge, whether or not to delete the input (temporary) directories. 237 bool delete_old_dirs_; 238 }; 239 REGISTER_KERNEL_BUILDER(Name("MergeV2Checkpoints").Device(DEVICE_CPU), 240 MergeV2Checkpoints); 241 242 } // namespace tensorflow 243