1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <functional> 17 #include <memory> 18 19 #include "tensorflow/cc/ops/const_op.h" 20 #include "tensorflow/cc/ops/io_ops.h" 21 #include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" 22 #include "tensorflow/core/framework/allocator.h" 23 #include "tensorflow/core/framework/fake_input.h" 24 #include "tensorflow/core/framework/node_def_builder.h" 25 #include "tensorflow/core/framework/op_kernel.h" 26 #include "tensorflow/core/framework/tensor.h" 27 #include "tensorflow/core/framework/types.h" 28 #include "tensorflow/core/framework/types.pb.h" 29 #include "tensorflow/core/graph/graph_def_builder.h" 30 #include "tensorflow/core/kernels/ops_testutil.h" 31 #include "tensorflow/core/kernels/ops_util.h" 32 #include "tensorflow/core/lib/io/path.h" 33 #include "tensorflow/core/lib/strings/strcat.h" 34 #include "tensorflow/core/platform/test.h" 35 #include "tensorflow/core/platform/test_benchmark.h" 36 #include "tensorflow/core/platform/types.h" 37 #include "tensorflow/core/protobuf/config.pb.h" 38 #include "tensorflow/core/util/tensor_slice_reader.h" 39 40 namespace tensorflow { 41 namespace { 42 43 class SaveOpTest : public OpsTestBase { 44 protected: 45 void MakeOp() { 46 TF_ASSERT_OK( 47 NodeDefBuilder("myop", "Save") 48 .Input(FakeInput()) 49 .Input(FakeInput()) 50 .Input(FakeInput({DT_BOOL, DT_INT32, DT_FLOAT, DT_DOUBLE, DT_QINT8, 51 DT_QINT32, DT_UINT8, DT_INT8, DT_INT16, DT_INT64, 52 DT_STRING, DT_COMPLEX64, DT_COMPLEX128, DT_HALF})) 53 .Finalize(node_def())); 54 TF_ASSERT_OK(InitOp()); 55 } 56 }; 57 58 TEST_F(SaveOpTest, Simple) { 59 const string filename = io::JoinPath(testing::TmpDir(), "tensor_simple"); 60 const string tensornames[] = { 61 "tensor_bool", "tensor_int", "tensor_float", "tensor_double", 62 "tensor_qint8", "tensor_qint32", "tensor_uint8", "tensor_int8", 63 "tensor_int16", "tensor_int64", "tensor_string", "tensor_complex64", 64 "tensor_complex128", "tensor_half"}; 65 66 MakeOp(); 67 // Add a file name 68 AddInput<string>(TensorShape({}), 69 [&filename](int x) -> string { return filename; }); 70 71 // Add the tensor names 72 AddInput<string>(TensorShape({14}), 73 [&tensornames](int x) -> string { return tensornames[x]; }); 74 75 // Add a 1-d bool tensor 76 AddInput<bool>(TensorShape({2}), [](int x) -> bool { return x != 0; }); 77 78 // Add a 1-d integer tensor 79 AddInput<int32>(TensorShape({10}), [](int x) -> int32 { return x + 1; }); 80 81 // Add a 2-d float tensor 82 AddInput<float>(TensorShape({2, 4}), 83 [](int x) -> float { return static_cast<float>(x) / 10; }); 84 85 // Add a 2-d double tensor 86 AddInput<double>(TensorShape({2, 4}), 87 [](int x) -> double { return static_cast<double>(x) / 20; }); 88 89 // Add a 2-d qint8 tensor 90 AddInput<qint8>(TensorShape({3, 2}), 91 [](int x) -> qint8 { return *reinterpret_cast<qint8*>(&x); }); 92 93 // Add a 2-d qint32 tensor 94 AddInput<qint32>(TensorShape({2, 3}), [](int x) -> qint32 { 95 return *reinterpret_cast<qint32*>(&x) * qint8(2); 96 }); 97 98 // Add a 1-d uint8 tensor 99 AddInput<uint8>(TensorShape({11}), [](int x) -> uint8 { return x + 1; }); 100 101 // Add a 1-d int8 tensor 102 AddInput<int8>(TensorShape({7}), [](int x) -> int8 { return x - 7; }); 103 104 // Add a 1-d int16 tensor 105 AddInput<int16>(TensorShape({7}), [](int x) -> int16 { return x - 8; }); 106 107 // Add a 1-d int64 tensor 108 AddInput<int64>(TensorShape({9}), [](int x) -> int64 { return x - 9; }); 109 110 // Add a 1-d string tensor 111 AddInput<string>(TensorShape({2}), 112 [](int x) -> string { return x ? "yes" : "no"; }); 113 114 // Add a 2-d complex64 tensor 115 AddInput<complex64>(TensorShape({2, 3}), [](int x) -> complex64 { 116 return complex64(100 + x, 200 + x); 117 }); 118 119 // Add a 2-d complex128 tensor 120 AddInput<complex128>(TensorShape({2, 3}), [](int x) -> complex128 { 121 return complex128(100 + x, 200 + x); 122 }); 123 124 // Add a 2-d half tensor 125 AddInput<Eigen::half>(TensorShape({2, 4}), [](int x) -> Eigen::half { 126 return static_cast<Eigen::half>(x) / Eigen::half(2); 127 }); 128 TF_ASSERT_OK(RunOpKernel()); 129 130 // Check that the checkpoint file is properly written 131 checkpoint::TensorSliceReader reader(filename, 132 checkpoint::OpenTableTensorSliceReader); 133 TF_EXPECT_OK(reader.status()); 134 135 // We expect to find all saved tensors 136 { 137 // The 1-d bool tensor 138 TensorShape shape; 139 DataType type; 140 EXPECT_TRUE(reader.HasTensor("tensor_bool", &shape, &type)); 141 TensorShape expected({2}); 142 EXPECT_TRUE(shape.IsSameSize(expected)); 143 EXPECT_EQ(DT_BOOL, type); 144 145 // We expect the tensor value to be correct. 146 TensorSlice s = TensorSlice::ParseOrDie("-"); 147 bool data[2]; 148 std::fill_n(data, 2, false); 149 EXPECT_TRUE(reader.CopySliceData("tensor_bool", s, data)); 150 for (int i = 0; i < 2; ++i) { 151 EXPECT_EQ((i != 0), data[i]); 152 } 153 } 154 155 { 156 // The 1-d integer tensor 157 TensorShape shape; 158 DataType type; 159 EXPECT_TRUE(reader.HasTensor("tensor_int", &shape, &type)); 160 TensorShape expected({10}); 161 EXPECT_TRUE(shape.IsSameSize(expected)); 162 EXPECT_EQ(DT_INT32, type); 163 164 // We expect the tensor value to be correct. 165 TensorSlice s = TensorSlice::ParseOrDie("-"); 166 int data[10]; 167 std::fill_n(data, 10, 0); 168 EXPECT_TRUE(reader.CopySliceData("tensor_int", s, data)); 169 for (int i = 0; i < 10; ++i) { 170 EXPECT_EQ(i + 1, data[i]); 171 } 172 } 173 174 { 175 // The 2-d float tensor 176 TensorShape shape; 177 DataType type; 178 EXPECT_TRUE(reader.HasTensor("tensor_float", &shape, &type)); 179 TensorShape expected({2, 4}); 180 EXPECT_TRUE(shape.IsSameSize(expected)); 181 EXPECT_EQ(DT_FLOAT, type); 182 183 // We expect the tensor value to be correct. 184 TensorSlice s = TensorSlice::ParseOrDie("-:-"); 185 float data[8]; 186 std::fill_n(data, 8, 0); 187 EXPECT_TRUE(reader.CopySliceData("tensor_float", s, data)); 188 for (int i = 0; i < 8; ++i) { 189 EXPECT_EQ(static_cast<float>(i) / 10, data[i]); 190 } 191 } 192 193 { 194 // The 2-d double tensor 195 TensorShape shape; 196 DataType type; 197 EXPECT_TRUE(reader.HasTensor("tensor_double", &shape, &type)); 198 TensorShape expected({2, 4}); 199 EXPECT_TRUE(shape.IsSameSize(expected)); 200 EXPECT_EQ(DT_DOUBLE, type); 201 202 // We expect the tensor value to be correct. 203 TensorSlice s = TensorSlice::ParseOrDie("-:-"); 204 double data[8]; 205 std::fill_n(data, 8, 0); 206 EXPECT_TRUE(reader.CopySliceData("tensor_double", s, data)); 207 for (int i = 0; i < 8; ++i) { 208 EXPECT_EQ(static_cast<double>(i) / 20, data[i]); 209 } 210 } 211 212 { 213 // The 2-d qint8 tensor 214 TensorShape shape; 215 DataType type; 216 EXPECT_TRUE(reader.HasTensor("tensor_qint8", &shape, &type)); 217 TensorShape expected({3, 2}); 218 EXPECT_TRUE(shape.IsSameSize(expected)); 219 EXPECT_EQ(DT_QINT8, type); 220 221 // We expect the tensor value to be correct. 222 TensorSlice s = TensorSlice::ParseOrDie("-:-"); 223 qint8 data[6]; 224 EXPECT_TRUE(reader.CopySliceData("tensor_qint8", s, data)); 225 for (int i = 0; i < 6; ++i) { 226 EXPECT_EQ(*reinterpret_cast<qint8*>(&i), data[i]); 227 } 228 } 229 230 { 231 // The 2-d qint32 tensor 232 TensorShape shape; 233 DataType type; 234 EXPECT_TRUE(reader.HasTensor("tensor_qint32", &shape, &type)); 235 TensorShape expected({2, 3}); 236 EXPECT_TRUE(shape.IsSameSize(expected)); 237 EXPECT_EQ(DT_QINT32, type); 238 239 // We expect the tensor value to be correct. 240 TensorSlice s = TensorSlice::ParseOrDie("-:-"); 241 qint32 data[6]; 242 EXPECT_TRUE(reader.CopySliceData("tensor_qint32", s, data)); 243 for (int i = 0; i < 6; ++i) { 244 EXPECT_EQ(*reinterpret_cast<qint32*>(&i) * qint8(2), data[i]); 245 } 246 } 247 248 { 249 // The 1-d uint8 tensor 250 TensorShape shape; 251 DataType type; 252 EXPECT_TRUE(reader.HasTensor("tensor_uint8", &shape, &type)); 253 TensorShape expected({11}); 254 EXPECT_TRUE(shape.IsSameSize(expected)); 255 EXPECT_EQ(DT_UINT8, type); 256 257 // We expect the tensor value to be correct. 258 TensorSlice s = TensorSlice::ParseOrDie("-"); 259 uint8 data[11]; 260 EXPECT_TRUE(reader.CopySliceData("tensor_uint8", s, data)); 261 for (int i = 0; i < 11; ++i) { 262 EXPECT_EQ(i + 1, data[i]); 263 } 264 } 265 266 { 267 // The 1-d int8 tensor 268 TensorShape shape; 269 DataType type; 270 EXPECT_TRUE(reader.HasTensor("tensor_int8", &shape, &type)); 271 TensorShape expected({7}); 272 EXPECT_TRUE(shape.IsSameSize(expected)); 273 EXPECT_EQ(DT_INT8, type); 274 275 // We expect the tensor value to be correct. 276 TensorSlice s = TensorSlice::ParseOrDie("-"); 277 int8 data[7]; 278 EXPECT_TRUE(reader.CopySliceData("tensor_int8", s, data)); 279 for (int i = 0; i < 7; ++i) { 280 EXPECT_EQ(i - 7, data[i]); 281 } 282 } 283 284 { 285 // The 1-d int16 tensor 286 TensorShape shape; 287 DataType type; 288 EXPECT_TRUE(reader.HasTensor("tensor_int16", &shape, &type)); 289 TensorShape expected({7}); 290 EXPECT_TRUE(shape.IsSameSize(expected)); 291 EXPECT_EQ(DT_INT16, type); 292 293 // We expect the tensor value to be correct. 294 TensorSlice s = TensorSlice::ParseOrDie("-"); 295 int16 data[7]; 296 EXPECT_TRUE(reader.CopySliceData("tensor_int16", s, data)); 297 for (int i = 0; i < 7; ++i) { 298 EXPECT_EQ(i - 8, data[i]); 299 } 300 } 301 302 { 303 // The 1-d int64 tensor 304 TensorShape shape; 305 DataType type; 306 EXPECT_TRUE(reader.HasTensor("tensor_int64", &shape, &type)); 307 TensorShape expected({9}); 308 EXPECT_TRUE(shape.IsSameSize(expected)); 309 EXPECT_EQ(DT_INT64, type); 310 311 // We expect the tensor value to be correct. 312 TensorSlice s = TensorSlice::ParseOrDie("-"); 313 int64 data[9]; 314 EXPECT_TRUE(reader.CopySliceData("tensor_int64", s, data)); 315 for (int i = 0; i < 9; ++i) { 316 EXPECT_EQ(i - 9, data[i]); 317 } 318 } 319 320 { 321 // The 1-d string tensor 322 TensorShape shape; 323 DataType type; 324 EXPECT_TRUE(reader.HasTensor("tensor_string", &shape, &type)); 325 TensorShape expected({2}); 326 EXPECT_TRUE(shape.IsSameSize(expected)); 327 EXPECT_EQ(DT_STRING, type); 328 329 // We expect the tensor value to be correct. 330 TensorSlice s = TensorSlice::ParseOrDie("-"); 331 string data[2]; 332 EXPECT_TRUE(reader.CopySliceData("tensor_string", s, data)); 333 EXPECT_EQ("no", data[0]); 334 EXPECT_EQ("yes", data[1]); 335 } 336 337 { 338 // The 2-d complex64 tensor 339 TensorShape shape; 340 DataType type; 341 EXPECT_TRUE(reader.HasTensor("tensor_complex64", &shape, &type)); 342 TensorShape expected({2, 3}); 343 EXPECT_TRUE(shape.IsSameSize(expected)); 344 EXPECT_EQ(DT_COMPLEX64, type); 345 346 // We expect the tensor value to be correct. 347 TensorSlice s = TensorSlice::ParseOrDie("-:-"); 348 complex64 data[6]; 349 EXPECT_TRUE(reader.CopySliceData("tensor_complex64", s, data)); 350 for (int i = 0; i < 6; ++i) { 351 EXPECT_EQ(100 + i, data[i].real()); 352 EXPECT_EQ(200 + i, data[i].imag()); 353 } 354 } 355 356 { 357 // The 2-d complex128 tensor 358 TensorShape shape; 359 DataType type; 360 EXPECT_TRUE(reader.HasTensor("tensor_complex128", &shape, &type)); 361 TensorShape expected({2, 3}); 362 EXPECT_TRUE(shape.IsSameSize(expected)); 363 EXPECT_EQ(DT_COMPLEX128, type); 364 365 // We expect the tensor value to be correct. 366 TensorSlice s = TensorSlice::ParseOrDie("-:-"); 367 complex128 data[6]; 368 EXPECT_TRUE(reader.CopySliceData("tensor_complex128", s, data)); 369 for (int i = 0; i < 6; ++i) { 370 EXPECT_EQ(100 + i, data[i].real()); 371 EXPECT_EQ(200 + i, data[i].imag()); 372 } 373 } 374 { 375 // The 2-d half tensor 376 TensorShape shape; 377 DataType type; 378 EXPECT_TRUE(reader.HasTensor("tensor_half", &shape, &type)); 379 TensorShape expected({2, 4}); 380 EXPECT_TRUE(shape.IsSameSize(expected)); 381 EXPECT_EQ(DT_HALF, type); 382 383 // We expect the tensor value to be correct. 384 TensorSlice s = TensorSlice::ParseOrDie("-:-"); 385 Eigen::half data[8]; 386 std::fill_n(data, 8, Eigen::half(0)); 387 EXPECT_TRUE(reader.CopySliceData("tensor_half", s, data)); 388 for (int i = 0; i < 8; ++i) { 389 EXPECT_EQ(static_cast<Eigen::half>(i) / Eigen::half(2), data[i]); 390 } 391 } 392 } 393 394 class SaveSlicesOpTest : public OpsTestBase { 395 protected: 396 void MakeOp() { 397 TF_ASSERT_OK(NodeDefBuilder("myop", "SaveSlices") 398 .Input(FakeInput()) 399 .Input(FakeInput()) 400 .Input(FakeInput()) 401 .Input(FakeInput( 402 {DT_INT32, DT_FLOAT, DT_DOUBLE, DT_QINT8, DT_QINT32})) 403 .Finalize(node_def())); 404 TF_ASSERT_OK(InitOp()); 405 } 406 }; 407 408 // Here we save only slices. We restore them in a larger tensor and we check 409 // that the right slice is restored. It is quite tricky to check that the 410 // right slices are actually restored so instead we just check that 411 // CopySliceData() return true/false depending on the slice we ask for. 412 TEST_F(SaveSlicesOpTest, Slices) { 413 const string filename = io::JoinPath(testing::TmpDir(), "tensor_slices"); 414 const string tensornames[] = {"tensor_int", "tensor_float", "tensor_double", 415 "tensor_qint8", "tensor_qint32"}; 416 // Specifies that the data we save are slices of larger tensors. 417 // See core/framework/tensor_slice.h for the slice syntax. 418 const string tensorshapes[] = { 419 "10 -", // Full contents of a 10 element vector. 420 "2 4 -:0,2", // A 2x2 slice of a 2x4 tensor. 421 "2 4 0,1:2,2", // A 1x2 slice of a 2x4 tensor. 422 "3 2 -:-", // Full contents of a 3x2 tensor. 423 "2 3 1,1:2,1" // Another 1x1 slice of a2x3 tensor. 424 }; 425 426 MakeOp(); 427 // Add a file name 428 AddInput<string>(TensorShape({}), 429 [&filename](int x) -> string { return filename; }); 430 431 // Add the tensor names 432 AddInput<string>(TensorShape({5}), 433 [&tensornames](int x) -> string { return tensornames[x]; }); 434 435 // Add the tensor shapes and slices 436 AddInput<string>(TensorShape({5}), [&tensorshapes](int x) -> string { 437 return tensorshapes[x]; 438 }); 439 440 // Add a 1-d integer tensor 441 AddInput<int32>(TensorShape({10}), [](int x) -> int32 { return x + 1; }); 442 443 // Add a 2-d float tensor 444 AddInput<float>(TensorShape({2, 2}), 445 [](int x) -> float { return static_cast<float>(x) / 10; }); 446 447 // Add a 2-d double tensor 448 AddInput<double>(TensorShape({1, 2}), 449 [](int x) -> double { return static_cast<double>(x) / 20; }); 450 451 // Add a 2-d qint8 tensor 452 AddInput<qint8>(TensorShape({3, 2}), 453 [](int x) -> qint8 { return *reinterpret_cast<qint8*>(&x); }); 454 455 // Add a 2-d qint32 tensor 456 AddInput<qint32>(TensorShape({1, 1}), [](int x) -> qint32 { 457 return *reinterpret_cast<qint32*>(&x) * qint8(2); 458 }); 459 460 TF_ASSERT_OK(RunOpKernel()); 461 462 // Check that the checkpoint file is properly written 463 checkpoint::TensorSliceReader reader(filename, 464 checkpoint::OpenTableTensorSliceReader); 465 TF_EXPECT_OK(reader.status()); 466 467 // We expect to find all saved tensors 468 { 469 // The 1-d integer tensor 470 TensorShape shape; 471 DataType type; 472 EXPECT_TRUE(reader.HasTensor("tensor_int", &shape, &type)); 473 TensorShape expected({10}); 474 EXPECT_TRUE(shape.IsSameSize(expected)); 475 EXPECT_EQ(DT_INT32, type); 476 477 // We saved the full tensor so we should be able to read it all. 478 TensorSlice s = TensorSlice::ParseOrDie("-"); 479 int data[10]; 480 EXPECT_TRUE(reader.CopySliceData("tensor_int", s, data)); 481 } 482 483 { 484 // The 2-d float tensor 485 TensorShape shape; 486 DataType type; 487 EXPECT_TRUE(reader.HasTensor("tensor_float", &shape, &type)); 488 TensorShape expected({2, 4}); 489 EXPECT_TRUE(shape.IsSameSize(expected)); 490 EXPECT_EQ(DT_FLOAT, type); 491 492 // We saved the slice "-:0,2" so we should not be able to read the full 493 // tensor. 494 TensorSlice full_slice = TensorSlice::ParseOrDie("-:-"); 495 TensorSlice saved_slice = TensorSlice::ParseOrDie("-:0,2"); 496 float data[8]; 497 EXPECT_FALSE(reader.CopySliceData("tensor_float", full_slice, data)); 498 EXPECT_TRUE(reader.CopySliceData("tensor_float", saved_slice, data)); 499 } 500 501 { 502 // The 2-d double tensor 503 TensorShape shape; 504 DataType type; 505 EXPECT_TRUE(reader.HasTensor("tensor_double", &shape, &type)); 506 TensorShape expected({2, 4}); 507 EXPECT_TRUE(shape.IsSameSize(expected)); 508 EXPECT_EQ(DT_DOUBLE, type); 509 510 // We saved the slice "0,1:2,2" so we should not be able to read the full 511 // tensor. 512 TensorSlice full_slice = TensorSlice::ParseOrDie("-:-"); 513 TensorSlice saved_slice = TensorSlice::ParseOrDie("0,1:2,2"); 514 double data[8]; 515 EXPECT_FALSE(reader.CopySliceData("tensor_double", full_slice, data)); 516 EXPECT_TRUE(reader.CopySliceData("tensor_double", saved_slice, data)); 517 } 518 519 { 520 // The 2-d qint8 tensor 521 TensorShape shape; 522 DataType type; 523 EXPECT_TRUE(reader.HasTensor("tensor_qint8", &shape, &type)); 524 TensorShape expected({3, 2}); 525 EXPECT_TRUE(shape.IsSameSize(expected)); 526 EXPECT_EQ(DT_QINT8, type); 527 528 // We saved the full slice. 529 TensorSlice s = TensorSlice::ParseOrDie("-:-"); 530 qint8 data[6]; 531 EXPECT_TRUE(reader.CopySliceData("tensor_qint8", s, data)); 532 } 533 534 { 535 // The 2-d qint32 tensor 536 TensorShape shape; 537 DataType type; 538 EXPECT_TRUE(reader.HasTensor("tensor_qint32", &shape, &type)); 539 TensorShape expected({2, 3}); 540 EXPECT_TRUE(shape.IsSameSize(expected)); 541 EXPECT_EQ(DT_QINT32, type); 542 543 // We expect the tensor value to be correct. 544 TensorSlice s = TensorSlice::ParseOrDie("1,1:2,1"); 545 TensorSlice full_slice = TensorSlice::ParseOrDie("-:-"); 546 TensorSlice saved_slice = TensorSlice::ParseOrDie("1,1:2,1"); 547 qint32 data[6]; 548 EXPECT_FALSE(reader.CopySliceData("tensor_qint32", full_slice, data)); 549 EXPECT_TRUE(reader.CopySliceData("tensor_qint32", saved_slice, data)); 550 } 551 } 552 553 class SaveOpSlices2Test : public OpsTestBase { 554 protected: 555 void MakeOp() { 556 TF_ASSERT_OK(NodeDefBuilder("myop", "SaveSlices") 557 .Input(FakeInput()) 558 .Input(FakeInput()) 559 .Input(FakeInput()) 560 .Input(FakeInput({DT_INT32, DT_INT32, DT_FLOAT})) 561 .Finalize(node_def())); 562 TF_ASSERT_OK(InitOp()); 563 } 564 }; 565 566 TEST_F(SaveOpSlices2Test, TwoSlices) { 567 const string filename = io::JoinPath(testing::TmpDir(), "three_slices"); 568 // We will save 2 slices of the tensor named "four_by_sixteen" which is 4x16, 569 // and one slice of the "small" tensor. 570 const string tensornames[] = {"four_by_sixteen", "four_by_sixteen", "small"}; 571 const string tensorshapes[] = { 572 // Slice specifications for the 2 slices of "four_by_sixteen" 573 "4 16 0,2:-", // 1st slice covers indices 0 and 1 in the first dim. 574 "4 16 2,2:-", // 2nd slice covers indices 2 and 3 in the first dim. 575 "" // We save the full "small" tensors. 576 }; 577 578 MakeOp(); 579 // Add a file name 580 AddInput<string>(TensorShape({}), 581 [&filename](int x) -> string { return filename; }); 582 583 // Add the tensor names 584 AddInput<string>(TensorShape({3}), 585 [&tensornames](int x) -> string { return tensornames[x]; }); 586 587 // Add the tensor shapes and slices 588 AddInput<string>(TensorShape({3}), [&tensorshapes](int x) -> string { 589 return tensorshapes[x]; 590 }); 591 592 // Add an integer tensor for slice 0,2:- of a 4x16 tensor: It is 2x16. 593 AddInput<int32>(TensorShape({2, 16}), [](int x) -> int32 { return x + 1; }); 594 595 // Add an integer tensor for slice 2,2:- of a 4x16 tensor: It is 2x16. 596 AddInput<int32>(TensorShape({2, 16}), 597 [](int x) -> int32 { return 10 * (x + 1); }); 598 599 // Add a float tensor for "small" 600 AddInput<float>(TensorShape({2, 4}), 601 [](int x) -> float { return static_cast<float>(x) / 10; }); 602 603 TF_ASSERT_OK(RunOpKernel()); 604 605 // Check that the checkpoint file is properly written 606 checkpoint::TensorSliceReader reader(filename, 607 checkpoint::OpenTableTensorSliceReader); 608 TF_EXPECT_OK(reader.status()); 609 610 { 611 // Reload the two slices of "four_by_sixteen" into that tensor. 612 Tensor reloaded(DT_INT32, {4, 16}); 613 614 // We expect to find all slices 615 TensorShape shape; 616 DataType type; 617 EXPECT_TRUE(reader.HasTensor("four_by_sixteen", &shape, &type)); 618 EXPECT_TRUE(shape.IsSameSize(reloaded.shape())); 619 EXPECT_EQ(type, reloaded.dtype()); 620 621 // Reload the whole tensor. 622 EXPECT_TRUE(reader.CopySliceData("four_by_sixteen", 623 TensorSlice(reloaded.dims()), 624 reloaded.flat<int>().data())); 625 626 { 627 auto slice = reloaded.Slice(0, 2).flat<int>(); 628 for (int i = 0; i < slice.size(); ++i) { 629 EXPECT_EQ(i + 1, slice(i)); 630 } 631 } 632 { 633 auto slice = reloaded.Slice(2, 4).flat<int>(); 634 for (int i = 0; i < slice.size(); ++i) { 635 EXPECT_EQ(10 * (i + 1), slice(i)); 636 } 637 } 638 } 639 640 { 641 // Reload the small float tensor. 642 Tensor reloaded(DT_FLOAT, {2, 4}); 643 644 TensorShape shape; 645 DataType type; 646 EXPECT_TRUE(reader.HasTensor("small", &shape, &type)); 647 EXPECT_TRUE(shape.IsSameSize(reloaded.shape())); 648 EXPECT_EQ(DT_FLOAT, reloaded.dtype()); 649 650 EXPECT_TRUE(reader.CopySliceData("small", TensorSlice(reloaded.dims()), 651 reloaded.flat<float>().data())); 652 653 for (int64 i = 0; i < reloaded.NumElements(); ++i) { 654 EXPECT_EQ(static_cast<float>(i) / 10, reloaded.flat<float>().data()[i]); 655 } 656 } 657 } 658 659 // Benchmark-related code below. 660 661 static void BM_LargeTensorWrite(int iters, int num_elements) { 662 testing::StopTiming(); 663 664 // 4 * num_elements bytes total , since sizeof(float) == 4. 665 Tensor tensor(DT_FLOAT, TensorShape({num_elements})); 666 tensor.flat<float>().setZero(); 667 668 // Builds the graph. 669 const string temp_filename = 670 io::JoinPath(testing::TmpDir(), "benchmark_checkpoint"); 671 auto root = Scope::NewRootScope().ExitOnError(); 672 const string tensor_name = "my_tensor"; 673 ops::Save(root, temp_filename, {tensor_name}, {{tensor}}); 674 675 // Disables optimizations. 676 SessionOptions session_options; 677 session_options.config.mutable_graph_options() 678 ->mutable_optimizer_options() 679 ->set_opt_level(tensorflow::OptimizerOptions_Level_L0); 680 681 TF_CHECK_OK(root.status()); 682 Graph* g = new Graph(OpRegistry::Global()); 683 TF_CHECK_OK(root.ToGraph(g)); 684 VLOG(1) << "Save op's output path: " << temp_filename; 685 VLOG(1) << "# nodes in Graph: " << g->num_nodes(); 686 687 testing::StartTiming(); 688 test::Benchmark("cpu", g, &session_options).Run(iters); 689 } 690 BENCHMARK(BM_LargeTensorWrite)->Arg((1 << 30) / 4 /* 1GB float tensor */); 691 692 } // namespace 693 } // namespace tensorflow 694