1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/contrib/coder/kernels/range_coder_ops_util.h" 17 18 #include <vector> 19 20 #include "tensorflow/core/framework/tensor_shape.h" 21 #include "tensorflow/core/lib/core/status.h" 22 #include "tensorflow/core/platform/macros.h" 23 #include "tensorflow/core/platform/types.h" 24 25 using tensorflow::errors::InvalidArgument; 26 27 namespace tensorflow { 28 Status MergeAxes(const TensorShape& broadcast_shape, 29 const TensorShape& storage_shape, 30 std::vector<int64>* merged_broadcast_shape_pointer, 31 std::vector<int64>* merged_storage_shape_pointer) { 32 CHECK_EQ(storage_shape.dims(), broadcast_shape.dims() + 1); 33 34 std::vector<int64>& merged_broadcast_shape = *merged_broadcast_shape_pointer; 35 std::vector<int64>& merged_storage_shape = *merged_storage_shape_pointer; 36 37 // The shapes are simplified so that the conversions between linear index 38 // and coordinates takes less CPU cycles. Two adjacent dimensions are 39 // merged if they both are broadcasting dimensions or if they both are 40 // non-broadcasting dimensions. 41 merged_broadcast_shape.resize(1); 42 merged_broadcast_shape[0] = 1; 43 merged_storage_shape.resize(1); 44 merged_storage_shape[0] = 1; 45 46 for (int i = 0, j = 0; j < broadcast_shape.dims(); ++j) { 47 if (TF_PREDICT_FALSE( 48 (broadcast_shape.dim_size(j) != storage_shape.dim_size(j)) && 49 (storage_shape.dim_size(j) != 1))) { 50 return InvalidArgument("Cannot broadcast shape ", 51 storage_shape.DebugString(), " to ", 52 broadcast_shape.DebugString()); 53 } 54 55 const bool was_broadcasting = (merged_storage_shape[i] == 1); 56 const bool is_broadcasting = (storage_shape.dim_size(j) == 1); 57 58 // Merge two adjacent axes if they both are broadcasting or both are 59 // non-broadcasting axes. The second and the third conditions in the if 60 // clause below are when the previously merged axis or the next j-th axis 61 // may be interpreted as either a broadcasting or a non-broadcasting axis. 62 const bool merge = (was_broadcasting == is_broadcasting) || 63 (broadcast_shape.dim_size(j) <= 1) || 64 (merged_broadcast_shape[i] <= 1); 65 66 if (merge) { 67 merged_broadcast_shape[i] *= broadcast_shape.dim_size(j); 68 merged_storage_shape[i] *= storage_shape.dim_size(j); 69 } else { 70 // Move to the next axis. 71 merged_broadcast_shape.push_back(broadcast_shape.dim_size(j)); 72 merged_storage_shape.push_back(storage_shape.dim_size(j)); 73 ++i; 74 } 75 } 76 77 int64 storage_stride = 1; 78 for (int i = broadcast_shape.dims(); i < storage_shape.dims(); ++i) { 79 storage_stride *= storage_shape.dim_size(i); 80 } 81 merged_storage_shape.push_back(storage_stride); 82 83 return Status::OK(); 84 } 85 } // namespace tensorflow 86