1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <complex> 17 #include <string> 18 19 #include "tensorflow/core/framework/fake_input.h" 20 #include "tensorflow/core/framework/node_def_builder.h" 21 #include "tensorflow/core/framework/tensor.h" 22 #include "tensorflow/core/framework/tensor_shape.h" 23 #include "tensorflow/core/framework/types.h" 24 #include "tensorflow/core/framework/types.pb.h" 25 #include "tensorflow/core/kernels/ops_testutil.h" 26 #include "tensorflow/core/lib/core/status.h" 27 #include "tensorflow/core/lib/io/path.h" 28 #include "tensorflow/core/platform/env.h" 29 #include "tensorflow/core/platform/test.h" 30 #include "tensorflow/core/platform/types.h" 31 #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h" 32 33 namespace tensorflow { 34 namespace { 35 36 class SaveV2OpTest : public OpsTestBase { 37 protected: 38 void MakeOp() { 39 TF_ASSERT_OK(NodeDefBuilder("myop", "SaveV2") 40 .Input(FakeInput()) // prefix 41 .Input(FakeInput()) // tensor_names 42 .Input(FakeInput()) // shape_and_slices 43 .Input(FakeInput({DT_BOOL, DT_INT32, DT_FLOAT, DT_DOUBLE, 44 DT_QINT8, DT_QINT32, DT_UINT8, DT_INT8, 45 DT_INT16, DT_INT64, DT_COMPLEX64, 46 DT_COMPLEX128, DT_HALF})) // tensors 47 .Finalize(node_def())); 48 TF_ASSERT_OK(InitOp()); 49 } 50 }; 51 52 TEST_F(SaveV2OpTest, Simple) { 53 const string prefix = io::JoinPath(testing::TmpDir(), "tensor_simple"); 54 const string tensornames[] = { 55 "tensor_bool", "tensor_int", "tensor_float", "tensor_double", 56 "tensor_qint8", "tensor_qint32", "tensor_uint8", "tensor_int8", 57 "tensor_int16", "tensor_int64", "tensor_complex64", "tensor_complex128", 58 "tensor_half"}; 59 60 MakeOp(); 61 // Add a file name 62 AddInput<string>(TensorShape({}), 63 [&prefix](int x) -> string { return prefix; }); 64 65 // Add the tensor names 66 AddInput<string>(TensorShape({13}), 67 [&tensornames](int x) -> string { return tensornames[x]; }); 68 69 // Add the slice specs 70 AddInput<string>(TensorShape({13}), [&tensornames](int x) -> string { 71 return "" /* saves in full */; 72 }); 73 74 // Add a 1-d bool tensor 75 AddInput<bool>(TensorShape({2}), [](int x) -> bool { return x != 0; }); 76 77 // Add a 1-d integer tensor 78 AddInput<int32>(TensorShape({10}), [](int x) -> int32 { return x + 1; }); 79 80 // Add a 2-d float tensor 81 AddInput<float>(TensorShape({2, 4}), 82 [](int x) -> float { return static_cast<float>(x) / 10; }); 83 84 // Add a 2-d double tensor 85 AddInput<double>(TensorShape({2, 4}), 86 [](int x) -> double { return static_cast<double>(x) / 20; }); 87 88 // Add a 2-d qint8 tensor 89 AddInput<qint8>(TensorShape({3, 2}), 90 [](int x) -> qint8 { return *reinterpret_cast<qint8*>(&x); }); 91 92 // Add a 2-d qint32 tensor 93 AddInput<qint32>(TensorShape({2, 3}), [](int x) -> qint32 { 94 return *reinterpret_cast<qint32*>(&x) * qint8(2); 95 }); 96 97 // Add a 1-d uint8 tensor 98 AddInput<uint8>(TensorShape({11}), [](int x) -> uint8 { return x + 1; }); 99 100 // Add a 1-d int8 tensor 101 AddInput<int8>(TensorShape({7}), [](int x) -> int8 { return x - 7; }); 102 103 // Add a 1-d int16 tensor 104 AddInput<int16>(TensorShape({7}), [](int x) -> int16 { return x - 8; }); 105 106 // Add a 1-d int64 tensor 107 AddInput<int64>(TensorShape({9}), [](int x) -> int64 { return x - 9; }); 108 109 // Add a 2-d complex64 tensor 110 AddInput<complex64>(TensorShape({2, 3}), [](int x) -> complex64 { 111 return complex64(100 + x, 200 + x); 112 }); 113 114 // Add a 2-d complex128 tensor 115 AddInput<complex128>(TensorShape({2, 3}), [](int x) -> complex128 { 116 return complex128(100 + x, 200 + x); 117 }); 118 119 // Add a 2-d half tensor 120 AddInput<Eigen::half>(TensorShape({2, 4}), [](int x) -> Eigen::half { 121 return static_cast<Eigen::half>(x) / Eigen::half(2); 122 }); 123 TF_ASSERT_OK(RunOpKernel()); 124 125 // Check that the checkpoint file is properly written 126 BundleReader reader(Env::Default(), prefix); 127 TF_EXPECT_OK(reader.status()); 128 129 // We expect to find all saved tensors 130 { 131 // The 1-d bool tensor 132 TensorShape shape; 133 TF_EXPECT_OK(reader.LookupTensorShape("tensor_bool", &shape)); 134 TensorShape expected({2}); 135 EXPECT_TRUE(shape.IsSameSize(expected)); 136 137 // We expect the tensor value to be correct. 138 Tensor val; 139 TF_EXPECT_OK(reader.Lookup("tensor_bool", &val)); 140 EXPECT_EQ(DT_BOOL, val.dtype()); 141 for (int i = 0; i < 2; ++i) { 142 EXPECT_EQ((i != 0), val.template flat<bool>()(i)); 143 } 144 } 145 146 { 147 // The 1-d integer tensor 148 TensorShape shape; 149 TF_EXPECT_OK(reader.LookupTensorShape("tensor_int", &shape)); 150 TensorShape expected({10}); 151 EXPECT_TRUE(shape.IsSameSize(expected)); 152 153 // We expect the tensor value to be correct. 154 Tensor val; 155 TF_EXPECT_OK(reader.Lookup("tensor_int", &val)); 156 EXPECT_EQ(DT_INT32, val.dtype()); 157 for (int i = 0; i < 10; ++i) { 158 EXPECT_EQ(i + 1, val.template flat<int>()(i)); 159 } 160 } 161 162 { 163 // The 2-d float tensor 164 TensorShape shape; 165 TF_EXPECT_OK(reader.LookupTensorShape("tensor_float", &shape)); 166 TensorShape expected({2, 4}); 167 EXPECT_TRUE(shape.IsSameSize(expected)); 168 169 // We expect the tensor value to be correct. 170 Tensor val; 171 TF_EXPECT_OK(reader.Lookup("tensor_float", &val)); 172 EXPECT_EQ(DT_FLOAT, val.dtype()); 173 for (int i = 0; i < 8; ++i) { 174 EXPECT_EQ(static_cast<float>(i) / 10, val.template flat<float>()(i)); 175 } 176 } 177 178 { 179 // The 2-d double tensor 180 TensorShape shape; 181 TF_EXPECT_OK(reader.LookupTensorShape("tensor_double", &shape)); 182 TensorShape expected({2, 4}); 183 EXPECT_TRUE(shape.IsSameSize(expected)); 184 185 // We expect the tensor value to be correct. 186 Tensor val; 187 TF_EXPECT_OK(reader.Lookup("tensor_double", &val)); 188 EXPECT_EQ(DT_DOUBLE, val.dtype()); 189 for (int i = 0; i < 8; ++i) { 190 EXPECT_EQ(static_cast<double>(i) / 20, val.template flat<double>()(i)); 191 } 192 } 193 194 { 195 // The 2-d qint8 tensor 196 TensorShape shape; 197 TF_EXPECT_OK(reader.LookupTensorShape("tensor_qint8", &shape)); 198 TensorShape expected({3, 2}); 199 EXPECT_TRUE(shape.IsSameSize(expected)); 200 201 // We expect the tensor value to be correct. 202 Tensor val; 203 TF_EXPECT_OK(reader.Lookup("tensor_qint8", &val)); 204 EXPECT_EQ(DT_QINT8, val.dtype()); 205 for (int i = 0; i < 6; ++i) { 206 EXPECT_EQ(*reinterpret_cast<qint8*>(&i), val.template flat<qint8>()(i)); 207 } 208 } 209 210 { 211 // The 2-d qint32 tensor 212 TensorShape shape; 213 TF_EXPECT_OK(reader.LookupTensorShape("tensor_qint32", &shape)); 214 TensorShape expected({2, 3}); 215 EXPECT_TRUE(shape.IsSameSize(expected)); 216 217 // We expect the tensor value to be correct. 218 Tensor val; 219 TF_EXPECT_OK(reader.Lookup("tensor_qint32", &val)); 220 EXPECT_EQ(DT_QINT32, val.dtype()); 221 for (int i = 0; i < 6; ++i) { 222 EXPECT_EQ(*reinterpret_cast<qint32*>(&i) * qint8(2), 223 val.template flat<qint32>()(i)); 224 } 225 } 226 227 { 228 // The 1-d uint8 tensor 229 TensorShape shape; 230 TF_EXPECT_OK(reader.LookupTensorShape("tensor_uint8", &shape)); 231 TensorShape expected({11}); 232 EXPECT_TRUE(shape.IsSameSize(expected)); 233 234 // We expect the tensor value to be correct. 235 Tensor val; 236 TF_EXPECT_OK(reader.Lookup("tensor_uint8", &val)); 237 EXPECT_EQ(DT_UINT8, val.dtype()); 238 for (int i = 0; i < 11; ++i) { 239 EXPECT_EQ(i + 1, val.template flat<uint8>()(i)); 240 } 241 } 242 243 { 244 // The 1-d int8 tensor 245 TensorShape shape; 246 TF_EXPECT_OK(reader.LookupTensorShape("tensor_int8", &shape)); 247 TensorShape expected({7}); 248 EXPECT_TRUE(shape.IsSameSize(expected)); 249 250 // We expect the tensor value to be correct. 251 Tensor val; 252 TF_EXPECT_OK(reader.Lookup("tensor_int8", &val)); 253 EXPECT_EQ(DT_INT8, val.dtype()); 254 for (int i = 0; i < 7; ++i) { 255 EXPECT_EQ(i - 7, val.template flat<int8>()(i)); 256 } 257 } 258 259 { 260 // The 1-d int16 tensor 261 TensorShape shape; 262 TF_EXPECT_OK(reader.LookupTensorShape("tensor_int16", &shape)); 263 TensorShape expected({7}); 264 EXPECT_TRUE(shape.IsSameSize(expected)); 265 266 // We expect the tensor value to be correct. 267 Tensor val; 268 TF_EXPECT_OK(reader.Lookup("tensor_int16", &val)); 269 EXPECT_EQ(DT_INT16, val.dtype()); 270 for (int i = 0; i < 7; ++i) { 271 EXPECT_EQ(i - 8, val.template flat<int16>()(i)); 272 } 273 } 274 275 { 276 // The 1-d int64 tensor 277 TensorShape shape; 278 TF_EXPECT_OK(reader.LookupTensorShape("tensor_int64", &shape)); 279 TensorShape expected({9}); 280 EXPECT_TRUE(shape.IsSameSize(expected)); 281 282 // We expect the tensor value to be correct. 283 Tensor val; 284 TF_EXPECT_OK(reader.Lookup("tensor_int64", &val)); 285 EXPECT_EQ(DT_INT64, val.dtype()); 286 for (int i = 0; i < 9; ++i) { 287 EXPECT_EQ(i - 9, val.template flat<int64>()(i)); 288 } 289 } 290 291 { 292 // The 2-d complex64 tensor 293 TensorShape shape; 294 TF_EXPECT_OK(reader.LookupTensorShape("tensor_complex64", &shape)); 295 TensorShape expected({2, 3}); 296 EXPECT_TRUE(shape.IsSameSize(expected)); 297 298 // We expect the tensor value to be correct. 299 Tensor val; 300 TF_EXPECT_OK(reader.Lookup("tensor_complex64", &val)); 301 EXPECT_EQ(DT_COMPLEX64, val.dtype()); 302 for (int i = 0; i < 6; ++i) { 303 EXPECT_EQ(100 + i, val.template flat<complex64>()(i).real()); 304 EXPECT_EQ(200 + i, val.template flat<complex64>()(i).imag()); 305 } 306 } 307 308 { 309 // The 2-d complex128 tensor 310 TensorShape shape; 311 TF_EXPECT_OK(reader.LookupTensorShape("tensor_complex128", &shape)); 312 TensorShape expected({2, 3}); 313 EXPECT_TRUE(shape.IsSameSize(expected)); 314 315 // We expect the tensor value to be correct. 316 Tensor val; 317 TF_EXPECT_OK(reader.Lookup("tensor_complex128", &val)); 318 EXPECT_EQ(DT_COMPLEX128, val.dtype()); 319 for (int i = 0; i < 6; ++i) { 320 EXPECT_EQ(100 + i, val.template flat<complex128>()(i).real()); 321 EXPECT_EQ(200 + i, val.template flat<complex128>()(i).imag()); 322 } 323 } 324 { 325 // The 2-d half tensor 326 TensorShape shape; 327 TF_EXPECT_OK(reader.LookupTensorShape("tensor_half", &shape)); 328 TensorShape expected({2, 4}); 329 EXPECT_TRUE(shape.IsSameSize(expected)); 330 331 // We expect the tensor value to be correct. 332 Tensor val; 333 TF_EXPECT_OK(reader.Lookup("tensor_half", &val)); 334 EXPECT_EQ(DT_HALF, val.dtype()); 335 for (int i = 0; i < 8; ++i) { 336 EXPECT_EQ(static_cast<Eigen::half>(i) / Eigen::half(2), 337 val.template flat<Eigen::half>()(i)); 338 } 339 } 340 } 341 342 } // namespace 343 } // namespace tensorflow 344