1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/util/tensor_slice_set.h" 17 18 #include <vector> 19 #include "tensorflow/core/lib/core/errors.h" 20 #include "tensorflow/core/lib/gtl/map_util.h" 21 #include "tensorflow/core/platform/logging.h" 22 #include "tensorflow/core/util/tensor_slice_util.h" 23 24 namespace tensorflow { 25 26 namespace checkpoint { 27 28 TensorSliceSet::TensorSliceSet(const TensorShape& shape, DataType type) 29 : shape_(shape), type_(type) {} 30 31 TensorSliceSet::~TensorSliceSet() {} 32 33 Status TensorSliceSet::Register(const TensorSlice& slice, const string& tag, 34 const float* data) { 35 TensorShape result_shape; 36 TF_RETURN_IF_ERROR(slice.SliceTensorShape(shape_, &result_shape)); 37 string str = slice.DebugString(); 38 39 if (slices_.empty()) { 40 slices_hull_ = slice; 41 } else { 42 // We check if there is any intersection between this slice and any of the 43 // registered slices. 44 if (slices_hull_.Overlaps(slice)) { 45 for (const auto& x : slices_) { 46 if (slice.Overlaps(x.second.slice)) { 47 return errors::Internal("Overlapping slices: existing slice = ", 48 x.first, ", new slice = ", str); 49 } 50 } 51 } 52 // No overlap: we can now insert the slice 53 slices_hull_.UpdateToCover(slice); 54 } 55 56 TensorSliceSet::SliceInfo info = {slice, tag, data, 57 result_shape.num_elements()}; 58 slices_.insert(std::make_pair(str, info)); 59 return Status::OK(); 60 } 61 62 // TODO(yangke): merge Query() with QueryMeta() 63 bool TensorSliceSet::Query(const TensorSlice& slice, float* data) const { 64 Status s; 65 string str = slice.DebugString(); 66 // First we check if there is an exactly match (this is the dominant case). 67 const TensorSliceSet::SliceInfo* info = gtl::FindOrNull(slices_, str); 68 if (info) { 69 if (data) { 70 std::copy_n(info->data, info->num_floats, data); 71 } 72 return true; 73 } else { 74 // We didn't find any exact match but there is still a possibility that 75 // multiple existing slices can be patched together to output the slice. 76 // We figure this out by computing the intersection of each of the existing 77 // slices with the query slice, and check if the union of all these 78 // intersections cover the entire slice. We rely on the fact that the 79 // existing slices don't have any intersection among themselves. 80 TensorShape target_shape; 81 Status s; 82 s = slice.SliceTensorShape(shape_, &target_shape); 83 if (!s.ok()) { 84 LOG(WARNING) << s; 85 return false; 86 } 87 int64 total_size = target_shape.num_elements(); 88 89 int64 overlap_size = 0; 90 TensorSlice intersection; 91 TensorShape inter_shape; 92 for (const auto& x : slices_) { 93 if (slice.Intersect(x.second.slice, &intersection)) { 94 s = intersection.SliceTensorShape(shape_, &inter_shape); 95 if (!s.ok()) { 96 LOG(WARNING) << s; 97 return false; 98 } 99 overlap_size += inter_shape.num_elements(); 100 } 101 } 102 if (total_size == overlap_size) { 103 // We have it! 104 // Now we need to copy the data to "data" 105 if (data) { 106 for (const auto& x : slices_) { 107 CopyDataFromTensorSliceToTensorSlice(shape_, x.second.slice, slice, 108 x.second.data, data); 109 } 110 } 111 return true; 112 } else { 113 // We don't have all the data for the asked tensor slice 114 return false; 115 } 116 } 117 } 118 119 bool TensorSliceSet::QueryMeta( 120 const TensorSlice& slice, 121 std::vector<std::pair<TensorSlice, string>>* results) const { 122 results->clear(); 123 Status s; 124 string str = slice.DebugString(); 125 // First we check if there is an exactly match (this is the dominant case). 126 const TensorSliceSet::SliceInfo* info = gtl::FindOrNull(slices_, str); 127 if (info) { 128 results->emplace_back(std::make_pair(info->slice, info->tag)); 129 return true; 130 } else { 131 // We didn't find any exact match but there is still a possibility that 132 // multiple existing slices can be patched together to output the slice. 133 // We figure this out by computing the intersection of each of the existing 134 // slices with the query slice, and check if the union of all these 135 // intersections cover the entire slice. We rely on the fact that the 136 // existing slices don't have any intersection among themselves. 137 TensorShape target_shape; 138 Status s; 139 s = slice.SliceTensorShape(shape_, &target_shape); 140 if (!s.ok()) { 141 LOG(WARNING) << s; 142 return false; 143 } 144 int64 total_size = target_shape.num_elements(); 145 146 int64 overlap_size = 0; 147 TensorSlice intersection; 148 TensorShape inter_shape; 149 for (const auto& x : slices_) { 150 if (slice.Intersect(x.second.slice, &intersection)) { 151 s = intersection.SliceTensorShape(shape_, &inter_shape); 152 if (!s.ok()) { 153 LOG(WARNING) << s; 154 return false; 155 } 156 overlap_size += inter_shape.num_elements(); 157 results->emplace_back(std::make_pair(x.second.slice, x.second.tag)); 158 } 159 } 160 if (total_size == overlap_size) { 161 // We have it! 162 return true; 163 } else { 164 // We don't have all the data for the asked tensor slice 165 results->clear(); 166 return false; 167 } 168 } 169 } 170 171 Status RegisterTensorSlice( 172 const string& name, const TensorShape& shape, DataType type, 173 const string& tag, const TensorSlice& slice, 174 std::unordered_map<string, TensorSliceSet*>* tensor_slices) { 175 DCHECK_NE(tensor_slices, nullptr); 176 TensorSliceSet* tss = gtl::FindPtrOrNull(*tensor_slices, name); 177 // Create a tensor slice set if needed 178 if (!tss) { 179 tss = new TensorSliceSet(shape, type); 180 tensor_slices->insert(std::make_pair(name, tss)); 181 } else { 182 // Check if the shapes match 183 const TensorShape& tss_shape(tss->shape()); 184 if (!shape.IsSameSize(tss_shape)) { 185 return errors::Internal("Incompatible tensor shapes detected for tensor ", 186 name, ": existing = ", tss_shape.DebugString(), 187 ", new = ", shape.DebugString()); 188 } 189 if (type != tss->type()) { 190 return errors::Internal("Incompatible tensor types detected for tensor ", 191 name, 192 ": existing = ", DataTypeString(tss->type()), 193 ", new = ", DataTypeString(type)); 194 } 195 } 196 // Register the tensor slices without the actual data. 197 return tss->Register(slice, tag, nullptr); 198 } 199 200 } // namespace checkpoint 201 202 } // namespace tensorflow 203