1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/kernels/save_restore_tensor.h" 17 #include <numeric> 18 #include <unordered_map> 19 #include <utility> 20 #include <vector> 21 22 #include "tensorflow/core/framework/op_kernel.h" 23 #include "tensorflow/core/framework/register_types.h" 24 #include "tensorflow/core/framework/types.h" 25 #include "tensorflow/core/kernels/bounds_check.h" 26 #include "tensorflow/core/lib/gtl/array_slice.h" 27 #include "tensorflow/core/lib/strings/strcat.h" 28 #include "tensorflow/core/lib/strings/stringprintf.h" 29 #include "tensorflow/core/platform/logging.h" 30 #include "tensorflow/core/platform/types.h" 31 #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h" 32 #include "tensorflow/core/util/tensor_slice_reader.h" 33 #include "tensorflow/core/util/tensor_slice_reader_cache.h" 34 #include "tensorflow/core/util/tensor_slice_writer.h" 35 36 namespace tensorflow { 37 38 void SaveTensors( 39 OpKernelContext* context, 40 checkpoint::TensorSliceWriter::CreateBuilderFunction builder_func, 41 bool save_slices) { 42 const Tensor& filename_t = context->input(0); 43 { 44 const int64 size = filename_t.NumElements(); 45 OP_REQUIRES( 46 context, size == 1, 47 errors::InvalidArgument( 48 "Input 0 (filename) must be a string scalar; got a tensor of ", 49 size, "elements")); 50 } 51 52 // Path, names, and slices if save_slices is true. 53 const int kFixedInputs = save_slices ? 3 : 2; 54 const Tensor& tensor_names_t = context->input(1); 55 OP_REQUIRES(context, 56 FastBoundsCheck(tensor_names_t.NumElements() + kFixedInputs, 57 std::numeric_limits<int>::max()), 58 errors::InvalidArgument("Too many inputs to SaveTensors")); 59 const int N = static_cast<int>(tensor_names_t.NumElements()); 60 const string* tensor_shapes_and_slices_ptr = nullptr; 61 if (save_slices) { 62 const Tensor& tensor_shapes_and_slices_t = context->input(2); 63 OP_REQUIRES( 64 context, 65 tensor_shapes_and_slices_t.NumElements() == static_cast<int64>(N), 66 errors::InvalidArgument("Expected ", N, 67 " elements for the tensor " 68 "shapes and slices but got ", 69 tensor_shapes_and_slices_t.NumElements())); 70 tensor_shapes_and_slices_ptr = 71 tensor_shapes_and_slices_t.flat<string>().data(); 72 } 73 OP_REQUIRES(context, context->num_inputs() == N + kFixedInputs, 74 errors::InvalidArgument("Expected totally ", N + kFixedInputs, 75 " inputs as input #1 (which is a string " 76 "tensor of saved names) contains ", 77 N, " names, but received ", 78 context->num_inputs(), " inputs")); 79 80 VLOG(1) << "About to save tensors to file " << filename_t.flat<string>()(0) 81 << "..."; 82 checkpoint::TensorSliceWriter writer(filename_t.flat<string>()(0), 83 std::move(builder_func)); 84 85 Status s; 86 auto tensor_names_flat = tensor_names_t.flat<string>(); 87 88 // Process tensors in sorted name order. This allows us to avoid seeking 89 // during restoration in the common case where we are restoring a full 90 // checkpoint. 91 std::vector<size_t> sorted_name_idx(tensor_names_flat.size()); 92 std::iota(sorted_name_idx.begin(), sorted_name_idx.end(), 0); 93 std::sort(sorted_name_idx.begin(), sorted_name_idx.end(), 94 [&tensor_names_flat](size_t a, size_t b) { 95 return tensor_names_flat(a) < tensor_names_flat(b); 96 }); 97 98 for (size_t i : sorted_name_idx) { 99 const string& name = tensor_names_flat(i); 100 const Tensor& input = context->input(i + kFixedInputs); 101 TensorShape shape(input.shape()); 102 TensorSlice slice(input.dims()); 103 if (save_slices && !tensor_shapes_and_slices_ptr[i].empty()) { 104 const string& shape_spec = tensor_shapes_and_slices_ptr[i]; 105 TensorShape slice_shape; 106 OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice( 107 shape_spec, &shape, &slice, &slice_shape)); 108 OP_REQUIRES(context, slice_shape.IsSameSize(input.shape()), 109 errors::InvalidArgument( 110 "Slice in shape_and_slice " 111 "specification does not match the " 112 "shape of the tensor to save: ", 113 shape_spec, ", tensor: ", input.shape().DebugString())); 114 } 115 116 #define WRITER_ADD(T) \ 117 case DataTypeToEnum<T>::value: \ 118 s = writer.Add(name, shape, slice, input.flat<T>().data()); \ 119 break; 120 121 switch (input.dtype()) { 122 TF_CALL_SAVE_RESTORE_TYPES(WRITER_ADD) 123 default: 124 context->SetStatus(errors::Unimplemented("Saving data type ", 125 DataTypeString(input.dtype()), 126 " not yet supported")); 127 return; 128 } 129 #undef WRITER_ADD 130 if (!s.ok()) { 131 context->SetStatus(s); 132 return; 133 } 134 } 135 136 s = writer.Finish(); 137 if (!s.ok()) { 138 context->SetStatus(s); 139 } 140 } 141 142 void RestoreTensor(OpKernelContext* context, 143 checkpoint::TensorSliceReader::OpenTableFunction open_func, 144 int preferred_shard, bool restore_slice, int restore_index) { 145 const Tensor& file_pattern_t = context->input(0); 146 { 147 const int64 size = file_pattern_t.NumElements(); 148 OP_REQUIRES( 149 context, size == 1, 150 errors::InvalidArgument( 151 "Input 0 (file_pattern) must be a string scalar; got a tensor of ", 152 size, "elements")); 153 } 154 const string& file_pattern = file_pattern_t.flat<string>()(0); 155 156 const Tensor& tensor_name_t = context->input(1); 157 const string& tensor_name = tensor_name_t.flat<string>()(restore_index); 158 159 // If we cannot find a cached reader we will allocate our own. 160 std::unique_ptr<checkpoint::TensorSliceReader> allocated_reader; 161 162 const checkpoint::TensorSliceReader* reader = 163 context->slice_reader_cache()->GetReader(file_pattern, open_func, 164 preferred_shard); 165 if (!reader) { 166 allocated_reader.reset(new checkpoint::TensorSliceReader( 167 file_pattern, open_func, preferred_shard)); 168 reader = allocated_reader.get(); 169 } 170 OP_REQUIRES_OK(context, CHECK_NOTNULL(reader)->status()); 171 172 // Get the shape and type from the save file. 173 DataType type; 174 TensorShape saved_shape; 175 OP_REQUIRES( 176 context, reader->HasTensor(tensor_name, &saved_shape, &type), 177 errors::NotFound("Tensor name \"", tensor_name, 178 "\" not found in checkpoint files ", file_pattern)); 179 OP_REQUIRES( 180 context, type == context->expected_output_dtype(restore_index), 181 errors::InvalidArgument("Expected to restore a tensor of type ", 182 DataTypeString(context->expected_output_dtype(0)), 183 ", got a tensor of type ", DataTypeString(type), 184 " instead: tensor_name = ", tensor_name)); 185 186 // Shape of the output and slice to load. 187 TensorShape output_shape(saved_shape); 188 TensorSlice slice_to_load(saved_shape.dims()); 189 if (restore_slice) { 190 const string& shape_spec = context->input(2).flat<string>()(restore_index); 191 if (!shape_spec.empty()) { 192 TensorShape parsed_shape; 193 OP_REQUIRES_OK(context, checkpoint::ParseShapeAndSlice( 194 shape_spec, &parsed_shape, &slice_to_load, 195 &output_shape)); 196 OP_REQUIRES( 197 context, parsed_shape.IsSameSize(saved_shape), 198 errors::InvalidArgument( 199 "Shape in shape_and_slice spec does not match the shape in the " 200 "save file: ", 201 parsed_shape.DebugString(), 202 ", save file shape: ", saved_shape.DebugString())); 203 } 204 } 205 206 Tensor* t = nullptr; 207 OP_REQUIRES_OK(context, 208 context->allocate_output(restore_index, output_shape, &t)); 209 210 if (output_shape.num_elements() == 0) return; 211 212 #define READER_COPY(T) \ 213 case DataTypeToEnum<T>::value: \ 214 OP_REQUIRES(context, \ 215 reader->CopySliceData(tensor_name, slice_to_load, \ 216 t->flat<T>().data()), \ 217 errors::InvalidArgument("Error copying slice data")); \ 218 break; 219 220 switch (type) { 221 TF_CALL_SAVE_RESTORE_TYPES(READER_COPY) 222 default: 223 context->SetStatus(errors::Unimplemented( 224 "Restoring data type ", DataTypeString(type), " not yet supported")); 225 } 226 #undef READER_COPY 227 } 228 229 Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix, 230 const Tensor& tensor_names, 231 const Tensor& shape_and_slices, 232 gtl::ArraySlice<DataType> dtypes) { 233 const string& prefix_string = prefix.scalar<string>()(); 234 235 const auto& tensor_names_flat = tensor_names.flat<string>(); 236 const auto& shape_and_slices_flat = shape_and_slices.flat<string>(); 237 238 // Sort lookup keys to improve locality when reading multiple tensors. 239 std::vector<size_t> sorted_name_idx(tensor_names_flat.size()); 240 std::iota(sorted_name_idx.begin(), sorted_name_idx.end(), 0); 241 std::sort(sorted_name_idx.begin(), sorted_name_idx.end(), 242 [&tensor_names_flat](size_t a, size_t b) { 243 return tensor_names_flat(a) < tensor_names_flat(b); 244 }); 245 246 BundleReader reader(Env::Default(), prefix_string); 247 TF_RETURN_IF_ERROR(reader.status()); 248 249 // TODO(zongheng): potential optimization: one Seek() in first lookup. 250 // TODO(zongheng): consider measuring speed and issuing concurrent lookups 251 // within a fixed memory budget. 252 TensorShape restored_full_shape; 253 Tensor* restored_tensor = nullptr; 254 for (auto i : sorted_name_idx) { 255 const string& tensor_name = tensor_names_flat(i); 256 const string& shape_and_slice = shape_and_slices_flat(i); 257 258 TF_RETURN_IF_ERROR( 259 reader.LookupTensorShape(tensor_name, &restored_full_shape)); 260 261 if (shape_and_slice.empty()) { 262 // Lookup the full tensor. 263 TF_RETURN_IF_ERROR( 264 context->allocate_output(i, restored_full_shape, &restored_tensor)); 265 TF_RETURN_IF_ERROR(reader.Lookup(tensor_name, restored_tensor)); 266 } else { 267 // Lookup the slice. 268 TensorShape parsed_full_shape; 269 TensorSlice parsed_slice; 270 TensorShape parsed_slice_shape; 271 272 TF_RETURN_IF_ERROR( 273 checkpoint::ParseShapeAndSlice(shape_and_slice, &parsed_full_shape, 274 &parsed_slice, &parsed_slice_shape)); 275 if (!restored_full_shape.IsSameSize(parsed_full_shape)) { 276 return errors::InvalidArgument( 277 "tensor_name = ", tensor_name, "; shape in shape_and_slice spec ", 278 parsed_full_shape.DebugString(), 279 " does not match the shape stored in checkpoint: ", 280 restored_full_shape.DebugString()); 281 } 282 283 TF_RETURN_IF_ERROR( 284 context->allocate_output(i, parsed_slice_shape, &restored_tensor)); 285 TF_RETURN_IF_ERROR( 286 reader.LookupSlice(tensor_name, parsed_slice, restored_tensor)); 287 } 288 if (dtypes[i] != restored_tensor->dtype()) { 289 return errors::InvalidArgument( 290 "tensor_name = ", tensor_name, "; expected dtype ", 291 DataTypeString(dtypes[i]), " does not equal restored dtype ", 292 DataTypeString(restored_tensor->dtype())); 293 } 294 } 295 return Status::OK(); 296 } 297 298 } // namespace tensorflow 299