1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/util/tensor_slice_set.h" 17 18 #include <vector> 19 #include "tensorflow/core/lib/core/status.h" 20 #include "tensorflow/core/platform/logging.h" 21 #include "tensorflow/core/platform/test.h" 22 #include "tensorflow/core/platform/test_benchmark.h" 23 24 namespace tensorflow { 25 26 namespace checkpoint { 27 28 namespace { 29 30 // A simple test: we have a 2-d tensor of shape 4 X 5 that looks like this: 31 // 32 // 0 1 2 3 4 33 // 5 6 7 8 9 34 // 10 11 12 13 14 35 // 15 16 17 18 19 36 // 37 // We assume this is a row-major matrix. 38 // 39 // We store the tensor in a couple of slices and verify that we can recover all 40 // of them. 41 TEST(TensorSliceSetTest, QueryTwoD) { 42 TensorShape shape({4, 5}); 43 44 TensorSliceSet tss(shape, DT_FLOAT); 45 // We store a few slices. 46 47 // Slice #1 is the top two rows: 48 // 0 1 2 3 4 49 // 5 6 7 8 9 50 // . . . . . 51 // . . . . . 52 const float src_1[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; 53 TensorSlice slice_1 = TensorSlice::ParseOrDie("0,2:-"); 54 TF_CHECK_OK(tss.Register(slice_1, "", src_1)); 55 56 // Slice #2 is the bottom left corner 57 // . . . . . 58 // . . . . . 59 // 10 11 12 . . 60 // 15 16 17 . . 61 const float src_2[] = {10, 11, 12, 15, 16, 17}; 62 TensorSlice slice_2 = TensorSlice::ParseOrDie("2,2:0,3"); 63 TF_CHECK_OK(tss.Register(slice_2, "", src_2)); 64 65 // Slice #3 is the bottom right corner 66 // . . . . . 67 // . . . . . 68 // . . . . . 69 // . . . 18 19 70 const float src_3[] = {18, 19}; 71 TensorSlice slice_3 = TensorSlice::ParseOrDie("3,1:3,2"); 72 TF_CHECK_OK(tss.Register(slice_3, "", src_3)); 73 74 // Notice that we leave a hole in the tensor 75 // . . . . . 76 // . . . . . 77 // . . . (13) (14) 78 // . . . . . 79 80 // Now we query some of the slices 81 82 // Slice #1 is an exact match 83 // 0 1 2 3 4 84 // 5 6 7 8 9 85 // . . . . . 86 // . . . . . 87 { 88 TensorSlice s = TensorSlice::ParseOrDie("0,2:-"); 89 float expected[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; 90 float results[10]; 91 EXPECT_TRUE(tss.Query(s, results)); 92 for (int i = 0; i < 10; ++i) { 93 EXPECT_EQ(expected[i], results[i]); 94 } 95 } 96 97 // Slice #2 is a subset match 98 // . . . . . 99 // 5 6 7 8 9 100 // . . . . . 101 // . . . . . 102 { 103 TensorSlice s = TensorSlice::ParseOrDie("1,1:-"); 104 float expected[] = {5, 6, 7, 8, 9}; 105 float results[5]; 106 EXPECT_TRUE(tss.Query(s, results)); 107 for (int i = 0; i < 5; ++i) { 108 EXPECT_EQ(expected[i], results[i]); 109 } 110 } 111 112 // Slice #3 is a more complicated match: it needs the combination of a couple 113 // of slices 114 // . . . . . 115 // 5 6 7 . . 116 // 10 11 12 . . 117 // . . . . . 118 { 119 TensorSlice s = TensorSlice::ParseOrDie("1,2:0,3"); 120 float expected[] = {5, 6, 7, 10, 11, 12}; 121 float results[6]; 122 EXPECT_TRUE(tss.Query(s, results)); 123 for (int i = 0; i < 6; ++i) { 124 EXPECT_EQ(expected[i], results[i]); 125 } 126 } 127 128 // Slice #4 includes the hole and so there is no match 129 // . . . . . 130 // . . 7 8 9 131 // . . 12 13 14 132 // . . . . . 133 { 134 TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3"); 135 float results[6]; 136 EXPECT_FALSE(tss.Query(s, results)); 137 } 138 } 139 140 // Testing the meta version of the tensor slice set. 141 TEST(TensorSliceSetTest, QueryMetaTwoD) { 142 TensorShape shape({4, 5}); 143 144 TensorSliceSet tss(shape, DT_INT32); 145 // We store a few slices. 146 147 // Slice #1 is the top two rows: 148 // 0 1 2 3 4 149 // 5 6 7 8 9 150 // . . . . . 151 // . . . . . 152 TensorSlice slice_1 = TensorSlice::ParseOrDie("0,2:-"); 153 TF_CHECK_OK(tss.Register(slice_1, "slice_1", nullptr)); 154 155 // Slice #2 is the bottom left corner 156 // . . . . . 157 // . . . . . 158 // 10 11 12 . . 159 // 15 16 17 . . 160 TensorSlice slice_2 = TensorSlice::ParseOrDie("2,2:0,3"); 161 TF_CHECK_OK(tss.Register(slice_2, "slice_2", nullptr)); 162 163 // Slice #3 is the bottom right corner 164 // . . . . . 165 // . . . . . 166 // . . . . . 167 // . . . 18 19 168 TensorSlice slice_3 = TensorSlice::ParseOrDie("3,1:3,2"); 169 TF_CHECK_OK(tss.Register(slice_3, "slice_3", nullptr)); 170 171 // Notice that we leave a hole in the tensor 172 // . . . . . 173 // . . . . . 174 // . . . (13) (14) 175 // . . . . . 176 177 // Now we query some of the slices 178 179 // Slice #1 is an exact match 180 // 0 1 2 3 4 181 // 5 6 7 8 9 182 // . . . . . 183 // . . . . . 184 // We just need slice_1 for this 185 { 186 TensorSlice s = TensorSlice::ParseOrDie("0,2:-"); 187 std::vector<std::pair<TensorSlice, string>> results; 188 EXPECT_TRUE(tss.QueryMeta(s, &results)); 189 EXPECT_EQ(1, results.size()); 190 EXPECT_EQ("0,2:-", results[0].first.DebugString()); 191 EXPECT_EQ("slice_1", results[0].second); 192 } 193 194 // Slice #2 is a subset match 195 // . . . . . 196 // 5 6 7 8 9 197 // . . . . . 198 // . . . . . 199 // We just need slice_1 for this 200 { 201 TensorSlice s = TensorSlice::ParseOrDie("1,1:-"); 202 std::vector<std::pair<TensorSlice, string>> results; 203 EXPECT_TRUE(tss.QueryMeta(s, &results)); 204 EXPECT_EQ(1, results.size()); 205 EXPECT_EQ("0,2:-", results[0].first.DebugString()); 206 EXPECT_EQ("slice_1", results[0].second); 207 } 208 209 // Slice #3 is a more complicated match: it needs the combination of a couple 210 // of slices 211 // . . . . . 212 // 5 6 7 . . 213 // 10 11 12 . . 214 // . . . . . 215 // We need both slice_1 and slice_2 for this. 216 { 217 TensorSlice s = TensorSlice::ParseOrDie("1,2:0,3"); 218 std::vector<std::pair<TensorSlice, string>> results; 219 EXPECT_TRUE(tss.QueryMeta(s, &results)); 220 EXPECT_EQ(2, results.size()); 221 EXPECT_EQ("2,2:0,3", results[0].first.DebugString()); 222 EXPECT_EQ("slice_2", results[0].second); 223 EXPECT_EQ("0,2:-", results[1].first.DebugString()); 224 EXPECT_EQ("slice_1", results[1].second); 225 } 226 227 // Slice #4 includes the hole and so there is no match 228 // . . . . . 229 // . . 7 8 9 230 // . . 12 13 14 231 // . . . . . 232 { 233 TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3"); 234 std::vector<std::pair<TensorSlice, string>> results; 235 EXPECT_FALSE(tss.QueryMeta(s, &results)); 236 EXPECT_EQ(0, results.size()); 237 } 238 } 239 240 static void BM_RegisterOneByOne(int parts) { 241 TensorShape shape({parts, 41}); 242 TensorSliceSet slice_set(shape, DT_INT32); 243 for (int i = 0; i < parts; ++i) { 244 TensorSlice part({{i, 1}, {0, -1}}); 245 TF_CHECK_OK(slice_set.Register(part, part.DebugString(), nullptr)); 246 } 247 } 248 249 BENCHMARK(BM_RegisterOneByOne); 250 251 } // namespace 252 253 } // namespace checkpoint 254 255 } // namespace tensorflow 256