Home | History | Annotate | Download | only in util
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
// A class to manage slices of a tensor. You can "register" a set of slices for
// a tensor and then "query" whether we have data for a given slice.
     18 
     19 // TODO(yangke): consider moving it to a more private place so that we don't
     20 // need to expose the API.
     21 
     22 #ifndef TENSORFLOW_UTIL_TENSOR_SLICE_SET_H_
     23 #define TENSORFLOW_UTIL_TENSOR_SLICE_SET_H_
     24 
     25 #include <string>  // for string
     26 #include <unordered_map>
     27 #include <vector>
     28 
     29 #include "tensorflow/core/framework/tensor_shape.h"
     30 #include "tensorflow/core/framework/tensor_slice.h"
     31 #include "tensorflow/core/framework/types.h"
     32 #include "tensorflow/core/lib/core/status.h"       // for Status
     33 #include "tensorflow/core/lib/core/stringpiece.h"  // for StringPiece
     34 #include "tensorflow/core/platform/types.h"        // for int64
     35 
     36 namespace tensorflow {
     37 
     38 namespace checkpoint {
     39 
     40 class TensorSliceSet {
     41  public:
     42   TensorSliceSet(const TensorShape& shape, DataType type);
     43   virtual ~TensorSliceSet();
     44 
     45   const TensorShape& shape() const { return shape_; }
     46   const DataType type() const { return type_; }
     47 
     48   // Register a new slice for the tensor. The "tag" is an arbitrary string
     49   // associated with the slice (in one application it denotes the name of the
     50   // file that contains the slice); the "data" points to the data of the tensor
     51   // slice (it can be a nullptr).
     52   // We don't take the ownership of "data" and the caller needs to make sure
     53   // the data is always available during the life time of the tensor slice set
     54   // if it is not nullptr.
     55   Status Register(const TensorSlice& slice, const string& tag,
     56                   const float* data);
     57 
     58   // Query about a new slice: checks if we have data for "slice" and if we have
     59   // the data and "data" is not nullptr, fill "data" with the slice data. The
     60   // caller needs to make sure "data" point to a large enough buffer.
     61   // TODO(yangke): avoid unnecessary copying by using a core::RefCounted
     62   // pointer.
     63   bool Query(const TensorSlice& slice, float* data) const;
     64 
     65   // Alternative way of querying about a new slice: instead of copying the
     66   // data, it returns a list of meta data about the stored slices that will
     67   // supply data for the slice.
     68   bool QueryMeta(
     69       const TensorSlice& slice,
     70       std::vector<std::pair<tensorflow::TensorSlice, string>>* results) const;
     71 
     72   struct SliceInfo {
     73     TensorSlice slice;
     74     const string tag;
     75     const float* data;
     76     int64 num_floats;
     77   };
     78 
     79   // Returns the map from slice string to SliceInfo.
     80   const std::unordered_map<string, SliceInfo>& Slices() const {
     81     return slices_;
     82   }
     83 
     84  private:
     85   const TensorShape shape_;
     86   const DataType type_;
     87   // We maintain a mapping from the slice string to the slice information.
     88   std::unordered_map<string, SliceInfo> slices_;
     89 
     90   // Minimal slice which contains all presented slices. Used for speeding up
     91   // overlap check when slices are being added consequently.
     92   TensorSlice slices_hull_;
     93 };
     94 
     95 // Registers "slice" in the TensorSliceSet stored in "tensor_slices", under key
     96 // "name".  Other arguments are used for validations.  Does not modify the map
     97 // or its values on non-OK.
     98 // REQUIRES: tensor_slices != nullptr
     99 Status RegisterTensorSlice(
    100     const string& name, const TensorShape& shape, DataType type,
    101     const string& tag, const TensorSlice& slice,
    102     std::unordered_map<string, TensorSliceSet*>* tensor_slices);
    103 
    104 }  // namespace checkpoint
    105 
    106 }  // namespace tensorflow
    107 
    108 #endif  // TENSORFLOW_UTIL_TENSOR_SLICE_SET_H_
    109