Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/contrib/coder/kernels/range_coder_ops_util.h"
     17 
     18 #include <vector>
     19 
     20 #include "tensorflow/core/framework/tensor_shape.h"
     21 #include "tensorflow/core/lib/core/status.h"
     22 #include "tensorflow/core/platform/macros.h"
     23 #include "tensorflow/core/platform/types.h"
     24 
     25 using tensorflow::errors::InvalidArgument;
     26 
     27 namespace tensorflow {
     28 Status MergeAxes(const TensorShape& broadcast_shape,
     29                  const TensorShape& storage_shape,
     30                  std::vector<int64>* merged_broadcast_shape_pointer,
     31                  std::vector<int64>* merged_storage_shape_pointer) {
     32   CHECK_EQ(storage_shape.dims(), broadcast_shape.dims() + 1);
     33 
     34   std::vector<int64>& merged_broadcast_shape = *merged_broadcast_shape_pointer;
     35   std::vector<int64>& merged_storage_shape = *merged_storage_shape_pointer;
     36 
     37   // The shapes are simplified so that the conversions between linear index
     38   // and coordinates takes less CPU cycles. Two adjacent dimensions are
     39   // merged if they both are broadcasting dimensions or if they both are
     40   // non-broadcasting dimensions.
     41   merged_broadcast_shape.resize(1);
     42   merged_broadcast_shape[0] = 1;
     43   merged_storage_shape.resize(1);
     44   merged_storage_shape[0] = 1;
     45 
     46   for (int i = 0, j = 0; j < broadcast_shape.dims(); ++j) {
     47     if (TF_PREDICT_FALSE(
     48             (broadcast_shape.dim_size(j) != storage_shape.dim_size(j)) &&
     49             (storage_shape.dim_size(j) != 1))) {
     50       return InvalidArgument("Cannot broadcast shape ",
     51                              storage_shape.DebugString(), " to ",
     52                              broadcast_shape.DebugString());
     53     }
     54 
     55     const bool was_broadcasting = (merged_storage_shape[i] == 1);
     56     const bool is_broadcasting = (storage_shape.dim_size(j) == 1);
     57 
     58     // Merge two adjacent axes if they both are broadcasting or both are
     59     // non-broadcasting axes. The second and the third conditions in the if
     60     // clause below are when the previously merged axis or the next j-th axis
     61     // may be interpreted as either a broadcasting or a non-broadcasting axis.
     62     const bool merge = (was_broadcasting == is_broadcasting) ||
     63                        (broadcast_shape.dim_size(j) <= 1) ||
     64                        (merged_broadcast_shape[i] <= 1);
     65 
     66     if (merge) {
     67       merged_broadcast_shape[i] *= broadcast_shape.dim_size(j);
     68       merged_storage_shape[i] *= storage_shape.dim_size(j);
     69     } else {
     70       // Move to the next axis.
     71       merged_broadcast_shape.push_back(broadcast_shape.dim_size(j));
     72       merged_storage_shape.push_back(storage_shape.dim_size(j));
     73       ++i;
     74     }
     75   }
     76 
     77   int64 storage_stride = 1;
     78   for (int i = broadcast_shape.dims(); i < storage_shape.dims(); ++i) {
     79     storage_stride *= storage_shape.dim_size(i);
     80   }
     81   merged_storage_shape.push_back(storage_stride);
     82 
     83   return Status::OK();
     84 }
     85 }  // namespace tensorflow
     86