/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/io_ops.cc.

#include <limits>
#include <string>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/save_restore_tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
#include "tensorflow/core/util/tensor_slice_reader.h"

namespace tensorflow {

namespace {

// Shared validations of the inputs to the SaveV2 and RestoreV2 ops.
void ValidateInputs(bool is_save_op, OpKernelContext* context,
                    const Tensor& prefix, const Tensor& tensor_names,
                    const Tensor& shape_and_slices) {
  const int kFixedInputs = 3;  // Prefix, tensor names, shape_and_slices.
  const int num_tensors = static_cast<int>(tensor_names.NumElements());
  OP_REQUIRES(
      context, prefix.NumElements() == 1,
      errors::InvalidArgument("Input prefix should have a single element, got ",
                              prefix.NumElements(), " instead."));
  OP_REQUIRES(context,
              TensorShapeUtils::IsVector(tensor_names.shape()) &&
                  TensorShapeUtils::IsVector(shape_and_slices.shape()),
              errors::InvalidArgument(
                  "Input tensor_names and shape_and_slices "
                  "should be 1-D tensors, got ",
                  tensor_names.shape().DebugString(), " and ",
                  shape_and_slices.shape().DebugString(), " instead."));
  OP_REQUIRES(context,
              tensor_names.NumElements() == shape_and_slices.NumElements(),
              errors::InvalidArgument("tensor_names and shape_and_slices "
                                      "have different number of elements: ",
                                      tensor_names.NumElements(), " vs. ",
                                      shape_and_slices.NumElements()));
  OP_REQUIRES(context,
              FastBoundsCheck(tensor_names.NumElements() + kFixedInputs,
                              std::numeric_limits<int>::max()),
              errors::InvalidArgument("Too many inputs to the op"));
  OP_REQUIRES(
      context, shape_and_slices.NumElements() == num_tensors,
      errors::InvalidArgument("Expected ", num_tensors,
                              " elements in shapes_and_slices, but got ",
                              context->input(2).NumElements()));
  if (is_save_op) {
    OP_REQUIRES(context, context->num_inputs() == num_tensors + kFixedInputs,
                errors::InvalidArgument(
                    "Expected a total of ", num_tensors + kFixedInputs,
                    " inputs as input #1 (which is a string "
                    "tensor of saved names) contains ",
                    num_tensors, " names, but received ", context->num_inputs(),
                    " inputs"));
  }
}

}  // namespace
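
// Input layout shared by SaveV2 and RestoreV2 (as enforced by ValidateInputs
// above):
//   input 0: "prefix", a scalar string naming the checkpoint.
//   input 1: "tensor_names", a 1-D string tensor with N entries.
//   input 2: "shape_and_slices", a 1-D string tensor with N entries.  An empty
//            entry means the whole tensor; otherwise the entry is parsed by
//            checkpoint::ParseShapeAndSlice and is expected to look roughly
//            like "4 10 0,2:-" (full shape, then one slice spec per dim).
// SaveV2 additionally takes the N tensors to write as inputs 3..N+2.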

// Saves a list of named tensors using the tensor bundle library.
class SaveV2 : public OpKernel {
 public:
  explicit SaveV2(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& prefix = context->input(0);
    const Tensor& tensor_names = context->input(1);
    const Tensor& shape_and_slices = context->input(2);
    ValidateInputs(true /* is save op */, context, prefix, tensor_names,
                   shape_and_slices);

    const int kFixedInputs = 3;  // Prefix, tensor names, shape_and_slices.
    const int num_tensors = static_cast<int>(tensor_names.NumElements());
    const string& prefix_string = prefix.scalar<string>()();
    const auto& tensor_names_flat = tensor_names.flat<string>();
    const auto& shape_and_slices_flat = shape_and_slices.flat<string>();

    BundleWriter writer(Env::Default(), prefix_string);
    OP_REQUIRES_OK(context, writer.status());
    VLOG(1) << "BundleWriter, prefix_string: " << prefix_string;

    for (int i = 0; i < num_tensors; ++i) {
      const string& tensor_name = tensor_names_flat(i);
      const Tensor& tensor = context->input(i + kFixedInputs);

      if (!shape_and_slices_flat(i).empty()) {
        const string& shape_spec = shape_and_slices_flat(i);
        TensorShape shape;
        TensorSlice slice(tensor.dims());
        TensorShape slice_shape;

        OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice(
                                    shape_spec, &shape, &slice, &slice_shape));
        OP_REQUIRES(context, slice_shape.IsSameSize(tensor.shape()),
                    errors::InvalidArgument("Slice in shape_and_slice "
                                            "specification does not match the "
                                            "shape of the tensor to save: ",
                                            shape_spec, ", tensor: ",
                                            tensor.shape().DebugString()));

        OP_REQUIRES_OK(context,
                       writer.AddSlice(tensor_name, shape, slice, tensor));
      } else {
        OP_REQUIRES_OK(context, writer.Add(tensor_name, tensor));
      }
    }
    OP_REQUIRES_OK(context, writer.Finish());
  }
};
REGISTER_KERNEL_BUILDER(Name("SaveV2").Device(DEVICE_CPU), SaveV2);
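
// Illustrative usage (comment only; the prefix and tensor names below are made
// up for this sketch).  Saving two tensors "w" and "b" in full would feed the
// op:
//
//   prefix           = "/tmp/model.ckpt"   (scalar string, input 0)
//   tensor_names     = {"w", "b"}          (1-D string, input 1)
//   shape_and_slices = {"", ""}            ("" => save each tensor whole)
//
// plus the two data tensors as inputs 3 and 4.  BundleWriter lays the result
// out as a tensor bundle under that prefix (an ".index" metadata file plus
// ".data-*" shard files).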

// Restores a list of named tensors from a tensor bundle (V2 checkpoint format).
class RestoreV2 : public OpKernel {
 public:
  explicit RestoreV2(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("dtypes", &dtypes_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& prefix = context->input(0);
    const Tensor& tensor_names = context->input(1);
    const Tensor& shape_and_slices = context->input(2);
    OP_REQUIRES(context, tensor_names.NumElements() == dtypes_.size(),
                errors::InvalidArgument("Got ", tensor_names.NumElements(),
                                        " tensor names, but ", dtypes_.size(),
                                        " expected dtypes."));
    ValidateInputs(false /* not save op */, context, prefix, tensor_names,
                   shape_and_slices);

    const string& prefix_string = prefix.scalar<string>()();

    // Intention: we plan to use the RestoreV2 op as a backward-compatible
    // reader as we upgrade to the V2 format, which allows a transparent
    // upgrade.  If "prefix_string" does not refer to a V2 checkpoint, we fall
    // back to reading it as a V1 checkpoint.
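    // MetaFilename(prefix) is the bundle's metadata file (the "<prefix>.index"
    // file in the V2 bundle layout); if no such file exists, the prefix is
    // treated as a V1 checkpoint below.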
    Env* env = Env::Default();
    std::vector<string> paths;
    if (!env->GetMatchingPaths(MetaFilename(prefix_string), &paths).ok() ||
        paths.empty()) {
      // Cannot find V2's metadata file, so "prefix_string" does not point to a
      // V2 checkpoint.  Invokes the V1 read path instead.
      for (size_t i = 0; i < tensor_names.NumElements(); ++i) {
        RestoreTensor(context, &checkpoint::OpenTableTensorSliceReader,
                      /* preferred_shard */ -1, /* restore_slice */ true,
                      /* restore_index */ i);
        if (!context->status().ok()) {
          return;
        }
      }
      return;
    }
    // If found, invokes the V2 reader.
    OP_REQUIRES_OK(context, RestoreTensorsV2(context, prefix, tensor_names,
                                             shape_and_slices, dtypes_));
  }

 private:
  // Expected dtypes of the tensors to restore.
  std::vector<DataType> dtypes_;
};
REGISTER_KERNEL_BUILDER(Name("RestoreV2").Device(DEVICE_CPU), RestoreV2);
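
// Illustrative usage (comment only; names as in the SaveV2 sketch above).
// Reading those tensors back would feed the op:
//
//   prefix           = "/tmp/model.ckpt"
//   tensor_names     = {"w", "b"}
//   shape_and_slices = {"", ""}
//   dtypes (attr)    = {DT_FLOAT, DT_FLOAT}
//
// and produce two outputs holding the restored values.  If no V2 metadata
// file is found for the prefix, Compute() above serves the same request
// through the V1 TensorSliceReader path instead.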

// The final step in saving sharded V2 checkpoints: merges metadata files.
class MergeV2Checkpoints : public OpKernel {
 public:
  explicit MergeV2Checkpoints(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context,
                   context->GetAttr("delete_old_dirs", &delete_old_dirs_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& checkpoint_prefixes = context->input(0);
    const Tensor& destination_prefix = context->input(1);
    OP_REQUIRES(context,
                TensorShapeUtils::IsVector(checkpoint_prefixes.shape()),
                errors::InvalidArgument(
                    "Input checkpoint_prefixes should be a 1-D tensor, got ",
                    checkpoint_prefixes.shape().DebugString(), " instead."));
    OP_REQUIRES(context, TensorShapeUtils::IsScalar(destination_prefix.shape()),
                errors::InvalidArgument(
                    "Input destination_prefix should be a scalar tensor, got ",
                    destination_prefix.shape().DebugString(), " instead."));

    const gtl::ArraySlice<string> input_prefixes =
        gtl::ArraySlice<string>(checkpoint_prefixes.flat<string>());
    Env* env = Env::Default();
    const string& merged_prefix = destination_prefix.scalar<string>()();
    OP_REQUIRES_OK(
        context, tensorflow::MergeBundles(env, input_prefixes, merged_prefix));

    if (delete_old_dirs_) {
      const string& merged_dir = io::Dirname(merged_prefix).ToString();
      for (const string& input_prefix : input_prefixes) {
        const string& dirname = io::Dirname(input_prefix).ToString();
        if (dirname == merged_dir) continue;
        Status status = env->DeleteDir(dirname);
        // For sharded save, only the first delete will go through and all
        // others will hit NotFound.  Use vlog to be less verbose.
        if (!status.ok()) VLOG(1) << status;
      }
    }
  }

 private:
  // On merge, whether or not to delete the input (temporary) directories.
  bool delete_old_dirs_;
};
REGISTER_KERNEL_BUILDER(Name("MergeV2Checkpoints").Device(DEVICE_CPU),
                        MergeV2Checkpoints);
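
// Illustrative flow (comment only; the paths below are made up for this
// sketch).  A sharded save first writes per-shard bundles under temporary
// prefixes, e.g.
//
//   checkpoint_prefixes = {"/tmp/ckpt_tmp/part-00000",
//                          "/tmp/ckpt_tmp/part-00001"}
//   destination_prefix  = "/tmp/model.ckpt"
//
// MergeBundles then folds the shard metadata into the single bundle named by
// destination_prefix, and with delete_old_dirs set, the temporary directories
// are deleted afterwards (see the loop above).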

}  // namespace tensorflow