Home | History | Annotate | Download | only in framework
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/framework/tensor_util.h"
     17 
     18 #include <vector>
     19 #include "tensorflow/core/framework/tensor.h"
     20 #include "tensorflow/core/framework/tensor_testutil.h"
     21 #include "tensorflow/core/framework/types.h"
     22 #include "tensorflow/core/framework/variant.h"
     23 #include "tensorflow/core/framework/variant_encode_decode.h"
     24 #include "tensorflow/core/framework/variant_tensor_data.h"
     25 #include "tensorflow/core/lib/core/status_test_util.h"
     26 #include "tensorflow/core/platform/test.h"
     27 
     28 namespace tensorflow {
     29 namespace {
     30 
     31 TEST(TensorUtil, DeepCopy0d) {
     32   Tensor x(DT_FLOAT, TensorShape({}));
     33   x.scalar<float>()() = 10.0;
     34 
     35   // Make y a deep copy of x and then change it.
     36   Tensor y = tensor::DeepCopy(x);
     37   y.scalar<float>()() = 20.0;
     38 
     39   // x doesn't change
     40   EXPECT_EQ(10.0, x.scalar<float>()());
     41 
     42   // Change x.
     43   x.scalar<float>()() = 30.0;
     44 
     45   // Y doesn't change.
     46   EXPECT_EQ(20.0, y.scalar<float>()());
     47 
     48   Tensor z = tensor::DeepCopy(y);
     49 
     50   // Change y.
     51   y.scalar<float>()() = 40.0;
     52 
     53   // The final states should all be different.
     54   EXPECT_EQ(20.0, z.scalar<float>()());
     55   EXPECT_EQ(30.0, x.scalar<float>()());
     56   EXPECT_EQ(40.0, y.scalar<float>()());
     57 
     58   // Should have the same shape and type.
     59   EXPECT_EQ(TensorShape({}), x.shape());
     60   EXPECT_EQ(TensorShape({}), y.shape());
     61   EXPECT_EQ(TensorShape({}), z.shape());
     62 
     63   EXPECT_EQ(DT_FLOAT, x.dtype());
     64   EXPECT_EQ(DT_FLOAT, y.dtype());
     65   EXPECT_EQ(DT_FLOAT, z.dtype());
     66 }
     67 
     68 TEST(TensorUtil, DeepCopyZeroElements) {
     69   Tensor x;
     70   Tensor y = tensor::DeepCopy(x);
     71   EXPECT_EQ(TensorShape({0}), y.shape());
     72   EXPECT_EQ(DT_FLOAT, y.dtype());
     73   EXPECT_EQ(0, y.NumElements());
     74 }
     75 
     76 TEST(TensorUtil, DeepCopy) {
     77   Tensor x(DT_FLOAT, TensorShape({1}));
     78   x.flat<float>()(0) = 10.0;
     79 
     80   // Make y a deep copy of x and then change it.
     81   Tensor y = tensor::DeepCopy(x);
     82   y.flat<float>()(0) = 20.0;
     83 
     84   // x doesn't change
     85   EXPECT_EQ(10.0, x.flat<float>()(0));
     86 
     87   // Change x.
     88   x.flat<float>()(0) = 30.0;
     89 
     90   // Y doesn't change.
     91   EXPECT_EQ(20.0, y.flat<float>()(0));
     92 
     93   Tensor z = tensor::DeepCopy(y);
     94 
     95   // Change y.
     96   y.flat<float>()(0) = 40.0;
     97 
     98   // The final states should all be different.
     99   EXPECT_EQ(20.0, z.flat<float>()(0));
    100   EXPECT_EQ(30.0, x.flat<float>()(0));
    101   EXPECT_EQ(40.0, y.flat<float>()(0));
    102 
    103   // Should have the same shape and type.
    104   EXPECT_EQ(TensorShape({1}), x.shape());
    105   EXPECT_EQ(TensorShape({1}), y.shape());
    106   EXPECT_EQ(TensorShape({1}), z.shape());
    107 
    108   EXPECT_EQ(DT_FLOAT, x.dtype());
    109   EXPECT_EQ(DT_FLOAT, y.dtype());
    110   EXPECT_EQ(DT_FLOAT, z.dtype());
    111 
    112   // Test string deep copy
    113   Tensor str1(DT_STRING, TensorShape({2}));
    114   str1.flat<string>()(0) = "foo1";
    115   str1.flat<string>()(1) = "foo2";
    116   Tensor str2 = tensor::DeepCopy(str1);
    117   str2.flat<string>()(0) = "bar1";
    118   str2.flat<string>()(1) = "bar2";
    119   EXPECT_NE(str2.flat<string>()(0), str1.flat<string>()(0));
    120 }
    121 
    122 TEST(TensorUtil, DeepCopySlice) {
    123   Tensor x(DT_INT32, TensorShape({10}));
    124   x.flat<int32>().setConstant(1);
    125 
    126   // Slice 'x' -- y still refers to the same buffer.
    127   Tensor y = x.Slice(2, 6);
    128 
    129   // Do a deep copy of y, which is a slice.
    130   Tensor z = tensor::DeepCopy(y);
    131 
    132   // Set x to be different.
    133   x.flat<int32>().setConstant(2);
    134 
    135   EXPECT_EQ(TensorShape({10}), x.shape());
    136   EXPECT_EQ(TensorShape({4}), y.shape());
    137   EXPECT_EQ(TensorShape({4}), z.shape());
    138   EXPECT_EQ(DT_INT32, x.dtype());
    139   EXPECT_EQ(DT_INT32, y.dtype());
    140   EXPECT_EQ(DT_INT32, z.dtype());
    141 
    142   // x and y should now all be '2', but z should be '1'.
    143   for (int i = 0; i < 10; ++i) {
    144     EXPECT_EQ(2, x.flat<int32>()(i));
    145   }
    146   for (int i = 0; i < 4; ++i) {
    147     EXPECT_EQ(2, y.unaligned_flat<int32>()(i));
    148     EXPECT_EQ(1, z.flat<int32>()(i));
    149   }
    150 }
    151 
    152 TEST(TensorUtil, DeepCopySliceString) {
    153   Tensor x(DT_STRING, TensorShape({10}));
    154   x.flat<string>().setConstant("hello");
    155 
    156   // Slice 'x' -- y still refers to the same buffer.
    157   Tensor y = x.Slice(3, 7);
    158 
    159   // Do a deep copy of y, which is a slice.
    160   Tensor z = tensor::DeepCopy(y);
    161 
    162   // Set x to be different.
    163   x.flat<string>().setConstant("goodbye");
    164 
    165   EXPECT_EQ(TensorShape({10}), x.shape());
    166   EXPECT_EQ(TensorShape({4}), y.shape());
    167   EXPECT_EQ(TensorShape({4}), z.shape());
    168   EXPECT_EQ(DT_STRING, x.dtype());
    169   EXPECT_EQ(DT_STRING, y.dtype());
    170   EXPECT_EQ(DT_STRING, z.dtype());
    171 
    172   // x and y should now all be 'goodbye', but z should be 'hello'.
    173   for (int i = 0; i < 10; ++i) {
    174     EXPECT_EQ("goodbye", x.flat<string>()(i));
    175   }
    176   for (int i = 0; i < 4; ++i) {
    177     EXPECT_EQ("goodbye", y.unaligned_flat<string>()(i));
    178     EXPECT_EQ("hello", z.flat<string>()(i));
    179   }
    180 }
    181 
    182 TEST(TensorUtil, DeepCopySliceVariant) {
    183   Tensor x(DT_VARIANT, TensorShape({10}));
    184   x.flat<Variant>().setConstant(Tensor(42.0f));
    185 
    186   // Slice 'x' -- y still refers to the same buffer.
    187   Tensor y = x.Slice(3, 7);
    188 
    189   // Do a deep copy of y, which is a slice.
    190   Tensor z = tensor::DeepCopy(y);
    191 
    192   // Set x to be different.
    193   x.flat<Variant>().setConstant(Tensor("foo"));
    194 
    195   EXPECT_EQ(TensorShape({10}), x.shape());
    196   EXPECT_EQ(TensorShape({4}), y.shape());
    197   EXPECT_EQ(TensorShape({4}), z.shape());
    198   EXPECT_EQ(DT_VARIANT, x.dtype());
    199   EXPECT_EQ(DT_VARIANT, y.dtype());
    200   EXPECT_EQ(DT_VARIANT, z.dtype());
    201 
    202   // Each element of x and y should now be a DT_STRING Tensor containing "foo",
    203   // but each element of z should be a DT_FLOAT tensor containing 42.0.
    204   for (int i = 0; i < 10; ++i) {
    205     EXPECT_EQ("foo", x.flat<Variant>()(i).get<Tensor>()->scalar<string>()());
    206   }
    207   for (int i = 0; i < 4; ++i) {
    208     EXPECT_EQ("foo",
    209               y.unaligned_flat<Variant>()(i).get<Tensor>()->scalar<string>()());
    210     EXPECT_EQ(42.0, z.flat<Variant>()(i).get<Tensor>()->scalar<float>()());
    211   }
    212 }
    213 
    214 TEST(TensorUtil, Concat) {
    215   std::vector<int64> sizes = {1, 4, 5};
    216   std::vector<Tensor> to_concat;
    217   int64 total_size = 0;
    218   int offset = 0;
    219   for (size_t entry = 0; entry < sizes.size(); ++entry) {
    220     const int64 size = sizes[entry];
    221     Tensor tensor(DT_INT32, TensorShape({size, 2}));
    222     for (int i = offset; i < offset + size; ++i) {
    223       for (int j = 0; j < 2; ++j) {
    224         tensor.matrix<int32>()(i - offset, j) = 2 * i + j;
    225       }
    226     }
    227     to_concat.push_back(tensor);
    228     total_size += size;
    229     offset += size;
    230   }
    231 
    232   Tensor concated;
    233   TF_ASSERT_OK(tensor::Concat(to_concat, &concated));
    234   ASSERT_EQ(TensorShape({total_size, 2}), concated.shape());
    235   for (int i = 0; i < total_size; ++i) {
    236     for (int j = 0; j < 2; ++j) {
    237       EXPECT_EQ(2 * i + j, concated.matrix<int32>()(i, j));
    238     }
    239   }
    240 }
    241 
    242 TEST(TensorUtil, Split) {
    243   Tensor to_split(DT_INT64, TensorShape({10, 2}));
    244   for (int i = 0; i < 10; ++i) {
    245     for (int j = 0; j < 2; ++j) {
    246       to_split.matrix<int64>()(i, j) = 2 * i + j;
    247     }
    248   }
    249 
    250   std::vector<int64> sizes = {1, 4, 5};
    251   std::vector<Tensor> splits;
    252   TF_ASSERT_OK(tensor::Split(to_split, sizes, &splits));
    253   ASSERT_EQ(sizes.size(), splits.size());
    254 
    255   int offset = 0;
    256   for (size_t entry = 0; entry < splits.size(); ++entry) {
    257     const int64 size = sizes[entry];
    258     const Tensor& split = splits[entry];
    259 
    260     ASSERT_EQ(TensorShape({size, 2}), split.shape());
    261     for (int i = offset; i < offset + size; ++i) {
    262       for (int j = 0; j < 2; ++j) {
    263         EXPECT_EQ(2 * i + j, split.matrix<int64>()(i - offset, j));
    264       }
    265     }
    266 
    267     offset += size;
    268   }
    269 }
    270 
    271 TEST(TensorUtil, ConcatSplitStrings) {
    272   Tensor x(DT_STRING, TensorShape({4, 3}));
    273   for (int i = 0; i < 4 * 3; ++i) {
    274     x.flat<string>()(i) = strings::StrCat("foo_", i);
    275   }
    276 
    277   std::vector<Tensor> split;
    278   TF_ASSERT_OK(tensor::Split(x, {2, 1, 1}, &split));
    279   Tensor x_round_tripped;
    280   TF_ASSERT_OK(tensor::Concat(split, &x_round_tripped));
    281   ASSERT_EQ(x.shape(), x_round_tripped.shape());
    282   for (int i = 0; i < 4 * 3; ++i) {
    283     EXPECT_EQ(x.flat<string>()(i), x_round_tripped.flat<string>()(i));
    284   }
    285 
    286   // Ensure that no memory is being shared between 'x' and 'x_round_tripped'.
    287   for (int i = 0; i < 4 * 3; ++i) {
    288     x_round_tripped.flat<string>()(i) = strings::StrCat("bar_", i);
    289   }
    290   for (int i = 0; i < 4 * 3; ++i) {
    291     EXPECT_NE(x.flat<string>()(i), x_round_tripped.flat<string>()(i));
    292   }
    293 }
    294 
    295 TEST(TensorProtoUtil, CreatesStringTensorProto) {
    296   std::vector<string> values{"a", "b", "c"};
    297   std::vector<size_t> shape{1, 3};
    298 
    299   auto proto = tensor::CreateTensorProto(values, shape);
    300 
    301   EXPECT_EQ(proto.DebugString(),
    302             "dtype: DT_STRING\n"
    303             "tensor_shape {\n"
    304             "  dim {\n"
    305             "    size: 1\n"
    306             "  }\n"
    307             "  dim {\n"
    308             "    size: 3\n"
    309             "  }\n"
    310             "}\n"
    311             "string_val: \"a\"\n"
    312             "string_val: \"b\"\n"
    313             "string_val: \"c\"\n");
    314 }
    315 
    316 TEST(TensorProtoUtil, CreatesInt32TensorProto) {
    317   std::vector<int32> values{1, 2};
    318   std::vector<size_t> shape{2};
    319 
    320   auto proto = tensor::CreateTensorProto(values, shape);
    321 
    322   EXPECT_EQ(proto.DebugString(),
    323             "dtype: DT_INT32\n"
    324             "tensor_shape {\n"
    325             "  dim {\n"
    326             "    size: 2\n"
    327             "  }\n"
    328             "}\n"
    329             "int_val: 1\n"
    330             "int_val: 2\n");
    331 }
    332 
    333 TEST(TensorProtoUtil, CreatesInt64TensorProto) {
    334   std::vector<int64> values{1, 2};
    335   std::vector<size_t> shape{2};
    336 
    337   auto proto = tensor::CreateTensorProto(values, shape);
    338 
    339   EXPECT_EQ(proto.DebugString(),
    340             "dtype: DT_INT64\n"
    341             "tensor_shape {\n"
    342             "  dim {\n"
    343             "    size: 2\n"
    344             "  }\n"
    345             "}\n"
    346             "int64_val: 1\n"
    347             "int64_val: 2\n");
    348 }
    349 
    350 TEST(TensorProtoUtil, CreatesUInt32TensorProto) {
    351   std::vector<uint32> values{1, 2};
    352   std::vector<size_t> shape{2};
    353 
    354   auto proto = tensor::CreateTensorProto(values, shape);
    355 
    356   EXPECT_EQ(proto.DebugString(),
    357             "dtype: DT_UINT32\n"
    358             "tensor_shape {\n"
    359             "  dim {\n"
    360             "    size: 2\n"
    361             "  }\n"
    362             "}\n"
    363             "uint32_val: 1\n"
    364             "uint32_val: 2\n");
    365 }
    366 
    367 TEST(TensorProtoUtil, CreatesUInt64TensorProto) {
    368   std::vector<uint64> values{1, 2};
    369   std::vector<size_t> shape{2};
    370 
    371   auto proto = tensor::CreateTensorProto(values, shape);
    372 
    373   EXPECT_EQ(proto.DebugString(),
    374             "dtype: DT_UINT64\n"
    375             "tensor_shape {\n"
    376             "  dim {\n"
    377             "    size: 2\n"
    378             "  }\n"
    379             "}\n"
    380             "uint64_val: 1\n"
    381             "uint64_val: 2\n");
    382 }
    383 
    384 TEST(TensorProtoUtil, CreatesFloatTensorProto) {
    385   std::vector<float> values{1.1, 2.2};
    386   std::vector<size_t> shape{2};
    387 
    388   auto proto = tensor::CreateTensorProto(values, shape);
    389 
    390   EXPECT_EQ(proto.DebugString(),
    391             "dtype: DT_FLOAT\n"
    392             "tensor_shape {\n"
    393             "  dim {\n"
    394             "    size: 2\n"
    395             "  }\n"
    396             "}\n"
    397             "float_val: 1.1\n"
    398             "float_val: 2.2\n");
    399 }
    400 
    401 TEST(TensorProtoUtil, CreatesDoubleTensorProto) {
    402   std::vector<double> values{1.1, 2.2};
    403   std::vector<size_t> shape{2};
    404 
    405   auto proto = tensor::CreateTensorProto(values, shape);
    406 
    407   EXPECT_EQ(proto.DebugString(),
    408             "dtype: DT_DOUBLE\n"
    409             "tensor_shape {\n"
    410             "  dim {\n"
    411             "    size: 2\n"
    412             "  }\n"
    413             "}\n"
    414             "double_val: 1.1\n"
    415             "double_val: 2.2\n");
    416 }
    417 
    418 TEST(TensorProtoUtil, CreatesBoolTensorProto) {
    419   std::vector<bool> values{true, false};
    420   std::vector<size_t> shape{2};
    421 
    422   auto proto = tensor::CreateTensorProto(values, shape);
    423 
    424   EXPECT_EQ(proto.DebugString(),
    425             "dtype: DT_BOOL\n"
    426             "tensor_shape {\n"
    427             "  dim {\n"
    428             "    size: 2\n"
    429             "  }\n"
    430             "}\n"
    431             "bool_val: true\n"
    432             "bool_val: false\n");
    433 }
    434 
    435 TEST(TensorProtoUtil, CompressTensorProtoInPlaceTooSmall) {
    436   const int kLength = 63;
    437   TensorProto tensor_proto =
    438       tensor::CreateTensorProto(std::vector<float>(kLength), {kLength});
    439   EXPECT_FALSE(tensor::CompressTensorProtoInPlace(&tensor_proto));
    440   tensor_proto =
    441       tensor::CreateTensorProto(std::vector<int>(kLength), {kLength});
    442   EXPECT_FALSE(tensor::CompressTensorProtoInPlace(&tensor_proto));
    443   tensor_proto =
    444       tensor::CreateTensorProto(std::vector<uint8>(kLength), {kLength});
    445   EXPECT_FALSE(tensor::CompressTensorProtoInPlace(&tensor_proto));
    446   tensor_proto =
    447       tensor::CreateTensorProto(std::vector<bool>(kLength), {kLength});
    448   EXPECT_FALSE(tensor::CompressTensorProtoInPlace(&tensor_proto));
    449   tensor_proto =
    450       tensor::CreateTensorProto(std::vector<Eigen::half>(kLength), {kLength});
    451   EXPECT_FALSE(tensor::CompressTensorProtoInPlace(&tensor_proto));
    452   tensor_proto = tensor::CreateTensorProto(
    453       std::vector<std::complex<float>>(kLength), {kLength});
    454   EXPECT_FALSE(tensor::CompressTensorProtoInPlace(&tensor_proto));
    455 }
    456 
    457 TEST(TensorProtoUtil, CompressTensorProtoInPlaceAllEqual) {
    458   const int kLength = 64;
    459   TensorProto tensor_proto =
    460       tensor::CreateTensorProto(std::vector<float>(kLength), {kLength});
    461   EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&tensor_proto));
    462   EXPECT_EQ(tensor::internal::TensorProtoHelper<float>::NumValues(tensor_proto),
    463             1);
    464 
    465   tensor_proto =
    466       tensor::CreateTensorProto(std::vector<int>(kLength), {kLength});
    467   EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&tensor_proto));
    468   EXPECT_EQ(tensor::internal::TensorProtoHelper<int>::NumValues(tensor_proto),
    469             1);
    470 
    471   tensor_proto =
    472       tensor::CreateTensorProto(std::vector<uint8>(kLength), {kLength});
    473   EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&tensor_proto));
    474   EXPECT_EQ(tensor::internal::TensorProtoHelper<uint8>::NumValues(tensor_proto),
    475             1);
    476   tensor_proto =
    477       tensor::CreateTensorProto(std::vector<bool>(kLength), {kLength});
    478   EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&tensor_proto));
    479   EXPECT_EQ(tensor::internal::TensorProtoHelper<bool>::NumValues(tensor_proto),
    480             1);
    481 
    482   tensor_proto =
    483       tensor::CreateTensorProto(std::vector<Eigen::half>(kLength), {kLength});
    484   EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&tensor_proto));
    485   EXPECT_EQ(
    486       tensor::internal::TensorProtoHelper<Eigen::half>::NumValues(tensor_proto),
    487       1);
    488 
    489   tensor_proto = tensor::CreateTensorProto(
    490       std::vector<std::complex<float>>(kLength), {kLength});
    491   EXPECT_TRUE(tensor::CompressTensorProtoInPlace(&tensor_proto));
    492   EXPECT_EQ(tensor::internal::TensorProtoHelper<std::complex<float>>::NumValues(
    493                 tensor_proto),
    494             1);
    495 }
    496 
    497 template <typename T>
    498 std::vector<T> VectorWithConstantTail(int size, int tail_length) {
    499   CHECK_LE(tail_length, size);
    500   std::vector<T> v(size, T(0));
    501   for (int i = 0; i < size - tail_length; ++i) {
    502     v[i] = T(i + 1);
    503   }
    504   return v;
    505 }
    506 
    507 template <typename T>
    508 TensorProto CreateAsProtoTensorContent(int size, int tail_length) {
    509   auto values = VectorWithConstantTail<T>(size, tail_length);
    510   Tensor tensor(DataTypeToEnum<T>::value, TensorShape({size}));
    511   std::copy(values.begin(), values.end(), tensor.flat<T>().data());
    512   TensorProto tensor_proto;
    513   tensor.AsProtoTensorContent(&tensor_proto);
    514   return tensor_proto;
    515 }
    516 
    517 template <typename T>
    518 TensorProto CreateAsProtoField(int size, int tail_length) {
    519   auto values = VectorWithConstantTail<T>(size, tail_length);
    520   Tensor tensor(DataTypeToEnum<T>::value, TensorShape({size}));
    521   std::copy(values.begin(), values.end(), tensor.flat<T>().data());
    522   TensorProto tensor_proto;
    523   tensor.AsProtoField(&tensor_proto);
    524   return tensor_proto;
    525 }
    526 
    527 template <typename T>
    528 void CompareTensorValues(const TensorProto& x, const TensorProto& y) {
    529   Tensor x_t;
    530   EXPECT_TRUE(x_t.FromProto(x));
    531   Tensor y_t;
    532   EXPECT_TRUE(y_t.FromProto(y));
    533   test::ExpectTensorEqual<T>(x_t, y_t);
    534 }
    535 
    536 template <typename T>
    537 void ConstantTailTest(int64 length, int64 tail_length, bool as_field) {
    538   using TensorProtoHelper = tensor::internal::TensorProtoHelper<T>;
    539   using FieldType = typename TensorProtoHelper::FieldType;
    540   const float kMinCompressionRatio = 2.0;
    541   const int64 kMinSize = 64;
    542   TensorProto tensor_proto =
    543       as_field ? CreateAsProtoField<T>(length, tail_length)
    544                : CreateAsProtoTensorContent<T>(length, tail_length);
    545   TensorProto original_tensor_proto = tensor_proto;
    546   int64 original_size =
    547       length * (as_field ? (is_complex<T>::value ? 2 : 1) * sizeof(FieldType)
    548                          : sizeof(T));
    549   int64 size_as_tensor_content = length * sizeof(T);
    550   int64 size_as_field = std::min(length, (length - tail_length + 1)) *
    551                         (is_complex<T>::value ? 2 : 1) * sizeof(FieldType);
    552   bool will_compress = std::min(size_as_tensor_content, size_as_field) <=
    553                        static_cast<int64>(original_size / kMinCompressionRatio);
    554 
    555   EXPECT_EQ(tensor::CompressTensorProtoInPlace(kMinSize, kMinCompressionRatio,
    556                                                &tensor_proto),
    557             will_compress);
    558   if (will_compress) {
    559     if (size_as_tensor_content < size_as_field) {
    560       EXPECT_EQ(TensorProtoHelper::NumValues(tensor_proto), 0);
    561       EXPECT_FALSE(tensor_proto.tensor_content().empty());
    562     } else {
    563       EXPECT_LE(TensorProtoHelper::NumValues(tensor_proto),
    564                 (length - tail_length + 1));
    565       EXPECT_TRUE(tensor_proto.tensor_content().empty());
    566     }
    567   }
    568   CompareTensorValues<T>(tensor_proto, original_tensor_proto);
    569 }
    570 
    571 TEST(TensorProtoUtil, CompressTensorProtoConstantTail) {
    572   const int kLength = 64;
    573   for (bool as_field : {true, false}) {
    574     for (int tail_length : {0, 1, 2, 32, 33, 63, 64}) {
    575       ConstantTailTest<float>(kLength, tail_length, as_field);
    576       ConstantTailTest<double>(kLength, tail_length, as_field);
    577       ConstantTailTest<complex64>(kLength, tail_length, as_field);
    578       ConstantTailTest<complex128>(kLength, tail_length, as_field);
    579       ConstantTailTest<int32>(kLength, tail_length, as_field);
    580       ConstantTailTest<uint32>(kLength, tail_length, as_field);
    581       ConstantTailTest<int64>(kLength, tail_length, as_field);
    582       ConstantTailTest<uint64>(kLength, tail_length, as_field);
    583       ConstantTailTest<int8>(kLength, tail_length, as_field);
    584       ConstantTailTest<uint8>(kLength, tail_length, as_field);
    585       ConstantTailTest<int16>(kLength, tail_length, as_field);
    586       ConstantTailTest<uint16>(kLength, tail_length, as_field);
    587       ConstantTailTest<Eigen::half>(kLength, tail_length, as_field);
    588       ConstantTailTest<bfloat16>(kLength, tail_length, as_field);
    589     }
    590   }
    591 }
    592 
    593 }  // namespace
    594 }  // namespace tensorflow
    595