Home | History | Annotate | Download | only in util
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
#include "tensorflow/core/util/tensor_slice_set.h"

#include <algorithm>
#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/tensor_slice_util.h"
     23 
     24 namespace tensorflow {
     25 
     26 namespace checkpoint {
     27 
     28 TensorSliceSet::TensorSliceSet(const TensorShape& shape, DataType type)
     29     : shape_(shape), type_(type) {}
     30 
     31 TensorSliceSet::~TensorSliceSet() {}
     32 
     33 Status TensorSliceSet::Register(const TensorSlice& slice, const string& tag,
     34                                 const float* data) {
     35   TensorShape result_shape;
     36   TF_RETURN_IF_ERROR(slice.SliceTensorShape(shape_, &result_shape));
     37   string str = slice.DebugString();
     38 
     39   if (slices_.empty()) {
     40     slices_hull_ = slice;
     41   } else {
     42     // We check if there is any intersection between this slice and any of the
     43     // registered slices.
     44     if (slices_hull_.Overlaps(slice)) {
     45       for (const auto& x : slices_) {
     46         if (slice.Overlaps(x.second.slice)) {
     47           return errors::Internal("Overlapping slices: existing slice = ",
     48                                   x.first, ", new slice = ", str);
     49         }
     50       }
     51     }
     52     // No overlap: we can now insert the slice
     53     slices_hull_.UpdateToCover(slice);
     54   }
     55 
     56   TensorSliceSet::SliceInfo info = {slice, tag, data,
     57                                     result_shape.num_elements()};
     58   slices_.insert(std::make_pair(str, info));
     59   return Status::OK();
     60 }
     61 
     62 // TODO(yangke): merge Query() with QueryMeta()
     63 bool TensorSliceSet::Query(const TensorSlice& slice, float* data) const {
     64   Status s;
     65   string str = slice.DebugString();
     66   // First we check if there is an exactly match (this is the dominant case).
     67   const TensorSliceSet::SliceInfo* info = gtl::FindOrNull(slices_, str);
     68   if (info) {
     69     if (data) {
     70       std::copy_n(info->data, info->num_floats, data);
     71     }
     72     return true;
     73   } else {
     74     // We didn't find any exact match but there is still a possibility that
     75     // multiple existing slices can be patched together to output the slice.
     76     // We figure this out by computing the intersection of each of the existing
     77     // slices with the query slice, and check if the union of all these
     78     // intersections cover the entire slice. We rely on the fact that the
     79     // existing slices don't have any intersection among themselves.
     80     TensorShape target_shape;
     81     Status s;
     82     s = slice.SliceTensorShape(shape_, &target_shape);
     83     if (!s.ok()) {
     84       LOG(WARNING) << s;
     85       return false;
     86     }
     87     int64 total_size = target_shape.num_elements();
     88 
     89     int64 overlap_size = 0;
     90     TensorSlice intersection;
     91     TensorShape inter_shape;
     92     for (const auto& x : slices_) {
     93       if (slice.Intersect(x.second.slice, &intersection)) {
     94         s = intersection.SliceTensorShape(shape_, &inter_shape);
     95         if (!s.ok()) {
     96           LOG(WARNING) << s;
     97           return false;
     98         }
     99         overlap_size += inter_shape.num_elements();
    100       }
    101     }
    102     if (total_size == overlap_size) {
    103       // We have it!
    104       // Now we need to copy the data to "data"
    105       if (data) {
    106         for (const auto& x : slices_) {
    107           CopyDataFromTensorSliceToTensorSlice(shape_, x.second.slice, slice,
    108                                                x.second.data, data);
    109         }
    110       }
    111       return true;
    112     } else {
    113       // We don't have all the data for the asked tensor slice
    114       return false;
    115     }
    116   }
    117 }
    118 
    119 bool TensorSliceSet::QueryMeta(
    120     const TensorSlice& slice,
    121     std::vector<std::pair<TensorSlice, string>>* results) const {
    122   results->clear();
    123   Status s;
    124   string str = slice.DebugString();
    125   // First we check if there is an exactly match (this is the dominant case).
    126   const TensorSliceSet::SliceInfo* info = gtl::FindOrNull(slices_, str);
    127   if (info) {
    128     results->emplace_back(std::make_pair(info->slice, info->tag));
    129     return true;
    130   } else {
    131     // We didn't find any exact match but there is still a possibility that
    132     // multiple existing slices can be patched together to output the slice.
    133     // We figure this out by computing the intersection of each of the existing
    134     // slices with the query slice, and check if the union of all these
    135     // intersections cover the entire slice. We rely on the fact that the
    136     // existing slices don't have any intersection among themselves.
    137     TensorShape target_shape;
    138     Status s;
    139     s = slice.SliceTensorShape(shape_, &target_shape);
    140     if (!s.ok()) {
    141       LOG(WARNING) << s;
    142       return false;
    143     }
    144     int64 total_size = target_shape.num_elements();
    145 
    146     int64 overlap_size = 0;
    147     TensorSlice intersection;
    148     TensorShape inter_shape;
    149     for (const auto& x : slices_) {
    150       if (slice.Intersect(x.second.slice, &intersection)) {
    151         s = intersection.SliceTensorShape(shape_, &inter_shape);
    152         if (!s.ok()) {
    153           LOG(WARNING) << s;
    154           return false;
    155         }
    156         overlap_size += inter_shape.num_elements();
    157         results->emplace_back(std::make_pair(x.second.slice, x.second.tag));
    158       }
    159     }
    160     if (total_size == overlap_size) {
    161       // We have it!
    162       return true;
    163     } else {
    164       // We don't have all the data for the asked tensor slice
    165       results->clear();
    166       return false;
    167     }
    168   }
    169 }
    170 
    171 Status RegisterTensorSlice(
    172     const string& name, const TensorShape& shape, DataType type,
    173     const string& tag, const TensorSlice& slice,
    174     std::unordered_map<string, TensorSliceSet*>* tensor_slices) {
    175   DCHECK_NE(tensor_slices, nullptr);
    176   TensorSliceSet* tss = gtl::FindPtrOrNull(*tensor_slices, name);
    177   // Create a tensor slice set if needed
    178   if (!tss) {
    179     tss = new TensorSliceSet(shape, type);
    180     tensor_slices->insert(std::make_pair(name, tss));
    181   } else {
    182     // Check if the shapes match
    183     const TensorShape& tss_shape(tss->shape());
    184     if (!shape.IsSameSize(tss_shape)) {
    185       return errors::Internal("Incompatible tensor shapes detected for tensor ",
    186                               name, ": existing = ", tss_shape.DebugString(),
    187                               ", new = ", shape.DebugString());
    188     }
    189     if (type != tss->type()) {
    190       return errors::Internal("Incompatible tensor types detected for tensor ",
    191                               name,
    192                               ": existing = ", DataTypeString(tss->type()),
    193                               ", new = ", DataTypeString(type));
    194     }
    195   }
    196   // Register the tensor slices without the actual data.
    197   return tss->Register(slice, tag, nullptr);
    198 }
    199 
    200 }  // namespace checkpoint
    201 
    202 }  // namespace tensorflow
    203