Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_KERNELS_SAVE_RESTORE_TENSOR_H_
     17 #define TENSORFLOW_KERNELS_SAVE_RESTORE_TENSOR_H_
     18 
     19 #include "tensorflow/core/util/tensor_slice_reader.h"
     20 #include "tensorflow/core/util/tensor_slice_writer.h"
     21 
     22 namespace tensorflow {
     23 
     24 class OpKernelContext;  // Forward declaration; used below only via pointer.
     25 
     26 // Legacy / V1 checkpoint format.
     27 
     28 // Saves the input tensors in *context to a writer built from builder_func().
     29 // "context" must have the following inputs:
     30 //  0: a single-element string tensor that contains the file name.
     31 //  1: names for the remaining tensors.
     32 // If "save_slices" is true:
     33 //  2: shape and slice specifications.
     34 //  rest: the tensors to save.
     35 void SaveTensors(
     36     OpKernelContext* context,
     37     checkpoint::TensorSliceWriter::CreateBuilderFunction builder_func,
     38     bool save_slices);
     39 
     40 // Reads a single tensor from the reader built from open_func() and produces
     41 // it as context->output(restore_index).  "preferred_shard" has the same
     42 // meaning as the TensorSliceReader "preferred_shard" parameter.
     43 //
     44 // "context" must have the following inputs:
     45 //  0: a single-element string tensor that contains the file name.
     46 //  1: a string tensor that names the outputs to be restored.
     47 // If "restore_slice" is true:
     48 //  2: shape and slice specification of the tensors to restore.
     49 //
     50 // "restore_index" indicates which variable name and slice to look up
     51 // in inputs 1 and 2 of "context".
     52 void RestoreTensor(OpKernelContext* context,
     53                    checkpoint::TensorSliceReader::OpenTableFunction open_func,
     54                    int preferred_shard, bool restore_slice, int restore_index);
     55 
     56 // V2 checkpoint format.
     57 
     58 // Reads tensors through the V2 checkpoint read path and returns a Status.
     59 //
     60 // "context" is only used for allocating outputs.  In particular, the inputs
     61 // are passed explicitly and are not accessed via the "input(i)" methods.
     62 // REQUIRES:
     63 //   * "prefix" has 1 element, DT_STRING.
     64 //   * "tensor_names" and "shape_and_slices" are shaped {N}, both DT_STRING.
     65 //   * "dtypes" has N elements, the data types of the tensors to restore.
     66 Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix,
     67                         const Tensor& tensor_names,
     68                         const Tensor& shape_and_slices,
     69                         gtl::ArraySlice<DataType> dtypes);
     70 
     71 }  // namespace tensorflow
     72 
     73 #endif  // TENSORFLOW_KERNELS_SAVE_RESTORE_TENSOR_H_
     74