Home | History | Annotate | Download | only in util
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/util/tensor_slice_set.h"
     17 
     18 #include <vector>
     19 #include "tensorflow/core/lib/core/status.h"
     20 #include "tensorflow/core/platform/logging.h"
     21 #include "tensorflow/core/platform/test.h"
     22 #include "tensorflow/core/platform/test_benchmark.h"
     23 
     24 namespace tensorflow {
     25 
     26 namespace checkpoint {
     27 
     28 namespace {
     29 
     30 // A simple test: we have a 2-d tensor of shape 4 X 5 that looks like this:
     31 //
     32 //   0   1   2   3   4
     33 //   5   6   7   8   9
     34 //  10  11  12  13  14
     35 //  15  16  17  18  19
     36 //
     37 // We assume this is a row-major matrix.
     38 //
     39 // We store the tensor in a couple of slices and verify that we can recover all
     40 // of them.
     41 TEST(TensorSliceSetTest, QueryTwoD) {
     42   TensorShape shape({4, 5});
     43 
     44   TensorSliceSet tss(shape, DT_FLOAT);
     45   // We store a few slices.
     46 
     47   // Slice #1 is the top two rows:
     48   //   0   1   2   3   4
     49   //   5   6   7   8   9
     50   //   .   .   .   .   .
     51   //   .   .   .   .   .
     52   const float src_1[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
     53   TensorSlice slice_1 = TensorSlice::ParseOrDie("0,2:-");
     54   TF_CHECK_OK(tss.Register(slice_1, "", src_1));
     55 
     56   // Slice #2 is the bottom left corner
     57   //   .   .   .   .   .
     58   //   .   .   .   .   .
     59   //  10  11  12   .   .
     60   //  15  16  17   .   .
     61   const float src_2[] = {10, 11, 12, 15, 16, 17};
     62   TensorSlice slice_2 = TensorSlice::ParseOrDie("2,2:0,3");
     63   TF_CHECK_OK(tss.Register(slice_2, "", src_2));
     64 
     65   // Slice #3 is the bottom right corner
     66   //   .   .   .   .   .
     67   //   .   .   .   .   .
     68   //   .   .   .   .   .
     69   //   .   .   .  18  19
     70   const float src_3[] = {18, 19};
     71   TensorSlice slice_3 = TensorSlice::ParseOrDie("3,1:3,2");
     72   TF_CHECK_OK(tss.Register(slice_3, "", src_3));
     73 
     74   // Notice that we leave a hole in the tensor
     75   //   .   .   .   .   .
     76   //   .   .   .   .   .
     77   //   .   .   . (13) (14)
     78   //   .   .   .   .   .
     79 
     80   // Now we query some of the slices
     81 
     82   // Slice #1 is an exact match
     83   //   0   1   2   3   4
     84   //   5   6   7   8   9
     85   //   .   .   .   .   .
     86   //   .   .   .   .   .
     87   {
     88     TensorSlice s = TensorSlice::ParseOrDie("0,2:-");
     89     float expected[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
     90     float results[10];
     91     EXPECT_TRUE(tss.Query(s, results));
     92     for (int i = 0; i < 10; ++i) {
     93       EXPECT_EQ(expected[i], results[i]);
     94     }
     95   }
     96 
     97   // Slice #2 is a subset match
     98   //   .   .   .   .   .
     99   //   5   6   7   8   9
    100   //   .   .   .   .   .
    101   //   .   .   .   .   .
    102   {
    103     TensorSlice s = TensorSlice::ParseOrDie("1,1:-");
    104     float expected[] = {5, 6, 7, 8, 9};
    105     float results[5];
    106     EXPECT_TRUE(tss.Query(s, results));
    107     for (int i = 0; i < 5; ++i) {
    108       EXPECT_EQ(expected[i], results[i]);
    109     }
    110   }
    111 
    112   // Slice #3 is a more complicated match: it needs the combination of a couple
    113   // of slices
    114   //   .   .   .   .   .
    115   //   5   6   7   .   .
    116   //  10  11  12   .   .
    117   //   .   .   .   .   .
    118   {
    119     TensorSlice s = TensorSlice::ParseOrDie("1,2:0,3");
    120     float expected[] = {5, 6, 7, 10, 11, 12};
    121     float results[6];
    122     EXPECT_TRUE(tss.Query(s, results));
    123     for (int i = 0; i < 6; ++i) {
    124       EXPECT_EQ(expected[i], results[i]);
    125     }
    126   }
    127 
    128   // Slice #4 includes the hole and so there is no match
    129   //   .   .   .   .   .
    130   //   .   .   7   8   9
    131   //   .   .  12  13  14
    132   //   .   .   .   .   .
    133   {
    134     TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3");
    135     float results[6];
    136     EXPECT_FALSE(tss.Query(s, results));
    137   }
    138 }
    139 
    140 // Testing the meta version of the tensor slice set.
    141 TEST(TensorSliceSetTest, QueryMetaTwoD) {
    142   TensorShape shape({4, 5});
    143 
    144   TensorSliceSet tss(shape, DT_INT32);
    145   // We store a few slices.
    146 
    147   // Slice #1 is the top two rows:
    148   //   0   1   2   3   4
    149   //   5   6   7   8   9
    150   //   .   .   .   .   .
    151   //   .   .   .   .   .
    152   TensorSlice slice_1 = TensorSlice::ParseOrDie("0,2:-");
    153   TF_CHECK_OK(tss.Register(slice_1, "slice_1", nullptr));
    154 
    155   // Slice #2 is the bottom left corner
    156   //   .   .   .   .   .
    157   //   .   .   .   .   .
    158   //  10  11  12   .   .
    159   //  15  16  17   .   .
    160   TensorSlice slice_2 = TensorSlice::ParseOrDie("2,2:0,3");
    161   TF_CHECK_OK(tss.Register(slice_2, "slice_2", nullptr));
    162 
    163   // Slice #3 is the bottom right corner
    164   //   .   .   .   .   .
    165   //   .   .   .   .   .
    166   //   .   .   .   .   .
    167   //   .   .   .  18  19
    168   TensorSlice slice_3 = TensorSlice::ParseOrDie("3,1:3,2");
    169   TF_CHECK_OK(tss.Register(slice_3, "slice_3", nullptr));
    170 
    171   // Notice that we leave a hole in the tensor
    172   //   .   .   .   .   .
    173   //   .   .   .   .   .
    174   //   .   .   . (13) (14)
    175   //   .   .   .   .   .
    176 
    177   // Now we query some of the slices
    178 
    179   // Slice #1 is an exact match
    180   //   0   1   2   3   4
    181   //   5   6   7   8   9
    182   //   .   .   .   .   .
    183   //   .   .   .   .   .
    184   // We just need slice_1 for this
    185   {
    186     TensorSlice s = TensorSlice::ParseOrDie("0,2:-");
    187     std::vector<std::pair<TensorSlice, string>> results;
    188     EXPECT_TRUE(tss.QueryMeta(s, &results));
    189     EXPECT_EQ(1, results.size());
    190     EXPECT_EQ("0,2:-", results[0].first.DebugString());
    191     EXPECT_EQ("slice_1", results[0].second);
    192   }
    193 
    194   // Slice #2 is a subset match
    195   //   .   .   .   .   .
    196   //   5   6   7   8   9
    197   //   .   .   .   .   .
    198   //   .   .   .   .   .
    199   // We just need slice_1 for this
    200   {
    201     TensorSlice s = TensorSlice::ParseOrDie("1,1:-");
    202     std::vector<std::pair<TensorSlice, string>> results;
    203     EXPECT_TRUE(tss.QueryMeta(s, &results));
    204     EXPECT_EQ(1, results.size());
    205     EXPECT_EQ("0,2:-", results[0].first.DebugString());
    206     EXPECT_EQ("slice_1", results[0].second);
    207   }
    208 
    209   // Slice #3 is a more complicated match: it needs the combination of a couple
    210   // of slices
    211   //   .   .   .   .   .
    212   //   5   6   7   .   .
    213   //  10  11  12   .   .
    214   //   .   .   .   .   .
    215   // We need both slice_1 and slice_2 for this.
    216   {
    217     TensorSlice s = TensorSlice::ParseOrDie("1,2:0,3");
    218     std::vector<std::pair<TensorSlice, string>> results;
    219     EXPECT_TRUE(tss.QueryMeta(s, &results));
    220     EXPECT_EQ(2, results.size());
    221     EXPECT_EQ("2,2:0,3", results[0].first.DebugString());
    222     EXPECT_EQ("slice_2", results[0].second);
    223     EXPECT_EQ("0,2:-", results[1].first.DebugString());
    224     EXPECT_EQ("slice_1", results[1].second);
    225   }
    226 
    227   // Slice #4 includes the hole and so there is no match
    228   //   .   .   .   .   .
    229   //   .   .   7   8   9
    230   //   .   .  12  13  14
    231   //   .   .   .   .   .
    232   {
    233     TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3");
    234     std::vector<std::pair<TensorSlice, string>> results;
    235     EXPECT_FALSE(tss.QueryMeta(s, &results));
    236     EXPECT_EQ(0, results.size());
    237   }
    238 }
    239 
    240 static void BM_RegisterOneByOne(int parts) {
    241   TensorShape shape({parts, 41});
    242   TensorSliceSet slice_set(shape, DT_INT32);
    243   for (int i = 0; i < parts; ++i) {
    244     TensorSlice part({{i, 1}, {0, -1}});
    245     TF_CHECK_OK(slice_set.Register(part, part.DebugString(), nullptr));
    246   }
    247 }
    248 
    249 BENCHMARK(BM_RegisterOneByOne);
    250 
    251 }  // namespace
    252 
    253 }  // namespace checkpoint
    254 
    255 }  // namespace tensorflow
    256