Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
#include "tensorflow/core/kernels/save_restore_tensor.h"

#include <algorithm>
#include <numeric>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
#include "tensorflow/core/util/tensor_slice_reader.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"
#include "tensorflow/core/util/tensor_slice_writer.h"
     35 
     36 namespace tensorflow {
     37 
     38 void SaveTensors(
     39     OpKernelContext* context,
     40     checkpoint::TensorSliceWriter::CreateBuilderFunction builder_func,
     41     bool save_slices) {
     42   const Tensor& filename_t = context->input(0);
     43   {
     44     const int64 size = filename_t.NumElements();
     45     OP_REQUIRES(
     46         context, size == 1,
     47         errors::InvalidArgument(
     48             "Input 0 (filename) must be a string scalar; got a tensor of ",
     49             size, "elements"));
     50   }
     51 
     52   // Path, names, and slices if save_slices is true.
     53   const int kFixedInputs = save_slices ? 3 : 2;
     54   const Tensor& tensor_names_t = context->input(1);
     55   OP_REQUIRES(context,
     56               FastBoundsCheck(tensor_names_t.NumElements() + kFixedInputs,
     57                               std::numeric_limits<int>::max()),
     58               errors::InvalidArgument("Too many inputs to SaveTensors"));
     59   const int N = static_cast<int>(tensor_names_t.NumElements());
     60   const string* tensor_shapes_and_slices_ptr = nullptr;
     61   if (save_slices) {
     62     const Tensor& tensor_shapes_and_slices_t = context->input(2);
     63     OP_REQUIRES(
     64         context,
     65         tensor_shapes_and_slices_t.NumElements() == static_cast<int64>(N),
     66         errors::InvalidArgument("Expected ", N,
     67                                 " elements for the tensor "
     68                                 "shapes and slices but got ",
     69                                 tensor_shapes_and_slices_t.NumElements()));
     70     tensor_shapes_and_slices_ptr =
     71         tensor_shapes_and_slices_t.flat<string>().data();
     72   }
     73   OP_REQUIRES(context, context->num_inputs() == N + kFixedInputs,
     74               errors::InvalidArgument("Expected totally ", N + kFixedInputs,
     75                                       " inputs as input #1 (which is a string "
     76                                       "tensor of saved names) contains ",
     77                                       N, " names, but received ",
     78                                       context->num_inputs(), " inputs"));
     79 
     80   VLOG(1) << "About to save tensors to file " << filename_t.flat<string>()(0)
     81           << "...";
     82   checkpoint::TensorSliceWriter writer(filename_t.flat<string>()(0),
     83                                        std::move(builder_func));
     84 
     85   Status s;
     86   auto tensor_names_flat = tensor_names_t.flat<string>();
     87 
     88   // Process tensors in sorted name order.  This allows us to avoid seeking
     89   // during restoration in the common case where we are restoring a full
     90   // checkpoint.
     91   std::vector<size_t> sorted_name_idx(tensor_names_flat.size());
     92   std::iota(sorted_name_idx.begin(), sorted_name_idx.end(), 0);
     93   std::sort(sorted_name_idx.begin(), sorted_name_idx.end(),
     94             [&tensor_names_flat](size_t a, size_t b) {
     95               return tensor_names_flat(a) < tensor_names_flat(b);
     96             });
     97 
     98   for (size_t i : sorted_name_idx) {
     99     const string& name = tensor_names_flat(i);
    100     const Tensor& input = context->input(i + kFixedInputs);
    101     TensorShape shape(input.shape());
    102     TensorSlice slice(input.dims());
    103     if (save_slices && !tensor_shapes_and_slices_ptr[i].empty()) {
    104       const string& shape_spec = tensor_shapes_and_slices_ptr[i];
    105       TensorShape slice_shape;
    106       OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice(
    107                                   shape_spec, &shape, &slice, &slice_shape));
    108       OP_REQUIRES(context, slice_shape.IsSameSize(input.shape()),
    109                   errors::InvalidArgument(
    110                       "Slice in shape_and_slice "
    111                       "specification does not match the "
    112                       "shape of the tensor to  save: ",
    113                       shape_spec, ", tensor: ", input.shape().DebugString()));
    114     }
    115 
    116 #define WRITER_ADD(T)                                           \
    117   case DataTypeToEnum<T>::value:                                \
    118     s = writer.Add(name, shape, slice, input.flat<T>().data()); \
    119     break;
    120 
    121     switch (input.dtype()) {
    122       TF_CALL_SAVE_RESTORE_TYPES(WRITER_ADD)
    123       default:
    124         context->SetStatus(errors::Unimplemented("Saving data type ",
    125                                                  DataTypeString(input.dtype()),
    126                                                  " not yet supported"));
    127         return;
    128     }
    129 #undef WRITER_ADD
    130     if (!s.ok()) {
    131       context->SetStatus(s);
    132       return;
    133     }
    134   }
    135 
    136   s = writer.Finish();
    137   if (!s.ok()) {
    138     context->SetStatus(s);
    139   }
    140 }
    141 
    142 void RestoreTensor(OpKernelContext* context,
    143                    checkpoint::TensorSliceReader::OpenTableFunction open_func,
    144                    int preferred_shard, bool restore_slice, int restore_index) {
    145   const Tensor& file_pattern_t = context->input(0);
    146   {
    147     const int64 size = file_pattern_t.NumElements();
    148     OP_REQUIRES(
    149         context, size == 1,
    150         errors::InvalidArgument(
    151             "Input 0 (file_pattern) must be a string scalar; got a tensor of ",
    152             size, "elements"));
    153   }
    154   const string& file_pattern = file_pattern_t.flat<string>()(0);
    155 
    156   const Tensor& tensor_name_t = context->input(1);
    157   const string& tensor_name = tensor_name_t.flat<string>()(restore_index);
    158 
    159   // If we cannot find a cached reader we will allocate our own.
    160   std::unique_ptr<checkpoint::TensorSliceReader> allocated_reader;
    161 
    162   const checkpoint::TensorSliceReader* reader =
    163       context->slice_reader_cache()->GetReader(file_pattern, open_func,
    164                                                preferred_shard);
    165   if (!reader) {
    166     allocated_reader.reset(new checkpoint::TensorSliceReader(
    167         file_pattern, open_func, preferred_shard));
    168     reader = allocated_reader.get();
    169   }
    170   OP_REQUIRES_OK(context, CHECK_NOTNULL(reader)->status());
    171 
    172   // Get the shape and type from the save file.
    173   DataType type;
    174   TensorShape saved_shape;
    175   OP_REQUIRES(
    176       context, reader->HasTensor(tensor_name, &saved_shape, &type),
    177       errors::NotFound("Tensor name \"", tensor_name,
    178                        "\" not found in checkpoint files ", file_pattern));
    179   OP_REQUIRES(
    180       context, type == context->expected_output_dtype(restore_index),
    181       errors::InvalidArgument("Expected to restore a tensor of type ",
    182                               DataTypeString(context->expected_output_dtype(0)),
    183                               ", got a tensor of type ", DataTypeString(type),
    184                               " instead: tensor_name = ", tensor_name));
    185 
    186   // Shape of the output and slice to load.
    187   TensorShape output_shape(saved_shape);
    188   TensorSlice slice_to_load(saved_shape.dims());
    189   if (restore_slice) {
    190     const string& shape_spec = context->input(2).flat<string>()(restore_index);
    191     if (!shape_spec.empty()) {
    192       TensorShape parsed_shape;
    193       OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice(
    194                                   shape_spec, &parsed_shape, &slice_to_load,
    195                                   &output_shape));
    196       OP_REQUIRES(
    197           context, parsed_shape.IsSameSize(saved_shape),
    198           errors::InvalidArgument(
    199               "Shape in shape_and_slice spec does not match the shape in the "
    200               "save file: ",
    201               parsed_shape.DebugString(),
    202               ", save file shape: ", saved_shape.DebugString()));
    203     }
    204   }
    205 
    206   Tensor* t = nullptr;
    207   OP_REQUIRES_OK(context,
    208                  context->allocate_output(restore_index, output_shape, &t));
    209 
    210   if (output_shape.num_elements() == 0) return;
    211 
    212 #define READER_COPY(T)                                                \
    213   case DataTypeToEnum<T>::value:                                      \
    214     OP_REQUIRES(context,                                              \
    215                 reader->CopySliceData(tensor_name, slice_to_load,     \
    216                                       t->flat<T>().data()),           \
    217                 errors::InvalidArgument("Error copying slice data")); \
    218     break;
    219 
    220   switch (type) {
    221     TF_CALL_SAVE_RESTORE_TYPES(READER_COPY)
    222     default:
    223       context->SetStatus(errors::Unimplemented(
    224           "Restoring data type ", DataTypeString(type), " not yet supported"));
    225   }
    226 #undef READER_COPY
    227 }
    228 
    229 Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix,
    230                         const Tensor& tensor_names,
    231                         const Tensor& shape_and_slices,
    232                         gtl::ArraySlice<DataType> dtypes) {
    233   const string& prefix_string = prefix.scalar<string>()();
    234 
    235   const auto& tensor_names_flat = tensor_names.flat<string>();
    236   const auto& shape_and_slices_flat = shape_and_slices.flat<string>();
    237 
    238   // Sort lookup keys to improve locality when reading multiple tensors.
    239   std::vector<size_t> sorted_name_idx(tensor_names_flat.size());
    240   std::iota(sorted_name_idx.begin(), sorted_name_idx.end(), 0);
    241   std::sort(sorted_name_idx.begin(), sorted_name_idx.end(),
    242             [&tensor_names_flat](size_t a, size_t b) {
    243               return tensor_names_flat(a) < tensor_names_flat(b);
    244             });
    245 
    246   BundleReader reader(Env::Default(), prefix_string);
    247   TF_RETURN_IF_ERROR(reader.status());
    248 
    249   // TODO(zongheng): potential optimization: one Seek() in first lookup.
    250   // TODO(zongheng): consider measuring speed and issuing concurrent lookups
    251   // within a fixed memory budget.
    252   TensorShape restored_full_shape;
    253   Tensor* restored_tensor = nullptr;
    254   for (auto i : sorted_name_idx) {
    255     const string& tensor_name = tensor_names_flat(i);
    256     const string& shape_and_slice = shape_and_slices_flat(i);
    257 
    258     TF_RETURN_IF_ERROR(
    259         reader.LookupTensorShape(tensor_name, &restored_full_shape));
    260 
    261     if (shape_and_slice.empty()) {
    262       // Lookup the full tensor.
    263       TF_RETURN_IF_ERROR(
    264           context->allocate_output(i, restored_full_shape, &restored_tensor));
    265       TF_RETURN_IF_ERROR(reader.Lookup(tensor_name, restored_tensor));
    266     } else {
    267       // Lookup the slice.
    268       TensorShape parsed_full_shape;
    269       TensorSlice parsed_slice;
    270       TensorShape parsed_slice_shape;
    271 
    272       TF_RETURN_IF_ERROR(
    273           checkpoint::ParseShapeAndSlice(shape_and_slice, &parsed_full_shape,
    274                                          &parsed_slice, &parsed_slice_shape));
    275       if (!restored_full_shape.IsSameSize(parsed_full_shape)) {
    276         return errors::InvalidArgument(
    277             "tensor_name = ", tensor_name, "; shape in shape_and_slice spec ",
    278             parsed_full_shape.DebugString(),
    279             " does not match the shape stored in checkpoint: ",
    280             restored_full_shape.DebugString());
    281       }
    282 
    283       TF_RETURN_IF_ERROR(
    284           context->allocate_output(i, parsed_slice_shape, &restored_tensor));
    285       TF_RETURN_IF_ERROR(
    286           reader.LookupSlice(tensor_name, parsed_slice, restored_tensor));
    287     }
    288     if (dtypes[i] != restored_tensor->dtype()) {
    289       return errors::InvalidArgument(
    290           "tensor_name = ", tensor_name, "; expected dtype ",
    291           DataTypeString(dtypes[i]), " does not equal restored dtype ",
    292           DataTypeString(restored_tensor->dtype()));
    293     }
    294   }
    295   return Status::OK();
    296 }
    297 
    298 }  // namespace tensorflow
    299