Home | History | Annotate | Download | only in util
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_UTIL_BCAST_H_
     17 #define TENSORFLOW_CORE_UTIL_BCAST_H_
     18 
     19 #include <algorithm>
     20 
     21 #include "tensorflow/core/framework/tensor_shape.h"
     22 #include "tensorflow/core/lib/gtl/inlined_vector.h"
     23 #include "tensorflow/core/platform/macros.h"
     24 #include "tensorflow/core/platform/types.h"
     25 
     26 namespace tensorflow {
     27 
     28 // BCast is a helper for broadcasting binary tensor operation.
     29 // TensorFlow's broadcasting rule follows that of numpy (See
     30 // http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
     31 //
     32 // The rule has the following properties:
     33 //
     34 //   1. suffix matching: the rule starts with the right-most
     35 //      dimension, and works towards the left-most dimension. Since
     36 //      TensorFlow is row-major, the right-most dimension (the last
     37 //      element in the shape of a tensor) is the inner-most, a.k.a.
     38 //      the fastest changing, dimension.
     39 //
     40 //   2. Two dimensions are compatible for broadcasting if both are the
     41 //      same or either is 1.
     42 //
     43 // BCast takes the shape of two tensors and computes a few vectors of
     44 // int32 that are useful for the caller to reshape the tensors, apply
     45 // the right broadcasts to them, compute the broadcasted operation,
     46 // and possibly the gradients. In a nutshell, the caller is expected
     47 // to compute the broadcasted operation as following:
     48 //
     49 //   BCast b(x.shape(), y.shape());
     50 //   output = x.reshape(b.x_reshape()).broadcast(b.x_bcast())
     51 //            _op_
     52 //            y.reshape(b.y_reshape()).broadcast(b.y_bcast())
     53 //
     54 // For the gradient computation,
     55 //   grad_x = sum(grad * backprop_x(x, y), grad_x_reduce_idx)
     56 //            .reshape(x.shape())
     57 //   grad_y = sum(grad * backprop_y(x, y), grad_y_reduce_idx)
     58 //            .reshape(y.shape())
     59 // backprop_x and backprop_y are functionals of the binary function "op",
     60 // e.g.,
     61 //   for +, backprop_x(x, y) = backprop_y(x, y) = 1;
     62 //   for *, backprop_x(x, y) =  y, backprop_y(x, y) = x;
     63 //   for /, backprop_x(x, y) = 1/y, backprop_y(x, y) = -x/y^2;
     64 //
     65 // The multiplication in the grad * backprop_x itself is also
     66 // broadcasting following the same rule.
     67 //
     68 // TODO(zhifengc): Adds support for n-ary (n >= 2).
     69 class BCast {
     70  public:
     71   // A vector of int64 representing the shape of tensor. The 0-th
     72   // element is the outer-most dimension and the last element is the
     73   // inner-most dimension. Note that we do not use TensorShape since
     74   // it's more convenient to manipulate Vec directly for this module.
     75   typedef gtl::InlinedVector<int64, 4> Vec;
     76 
     77   // Constructs all helper shapes, following the aforementioned rules.
     78   //
     79   // If "fewer_dims_optimization" is set to true (the default), the
     80   // implementation tries to reduce intermediate dimensions needed to be more
     81   // efficient.  This is transparent to the caller.
     82   //
     83   // If false, all intermediate shapes (except for grad_{x,y}_reduce_idx()) have
     84   // the same number of dimensions as the larger of the two inputs.
     85   BCast(const Vec& x, const Vec& y, const bool fewer_dims_optimization = true);
     86   ~BCast() {}
     87 
     88   // Returns true iff two operands are compatible according to the
     89   // broadcasting rule.
     90   bool IsValid() const { return valid_; }
     91 
     92   // If and only if IsValid(), the following fields can be used in
     93   // implementing a broadcasted binary tensor operation according to
     94   // the broadcasting rule.
     95   const Vec& x_reshape() const { return x_reshape_; }
     96   const Vec& x_bcast() const { return x_bcast_; }
     97   const Vec& y_reshape() const { return y_reshape_; }
     98   const Vec& y_bcast() const { return y_bcast_; }
     99   const Vec& result_shape() const { return result_; }
    100   const Vec& output_shape() const { return output_; }
    101   const Vec& grad_x_reduce_idx() const { return grad_x_reduce_idx_; }
    102   const Vec& grad_y_reduce_idx() const { return grad_y_reduce_idx_; }
    103 
    104   // Static helpers.
    105   static Vec FromShape(const TensorShape& shape);
    106   static TensorShape ToShape(const BCast::Vec& vec);
    107 
    108   template <typename IndexType, int NDIMS>
    109   static Eigen::array<IndexType, NDIMS> ToIndexArrayType(
    110       const BCast::Vec& vec) {
    111     CHECK_EQ(vec.size(), NDIMS);
    112     Eigen::array<IndexType, NDIMS> ret;
    113     for (int i = 0; i < NDIMS; ++i) ret[i] = vec[i];
    114     return ret;
    115   }
    116 
    117   template <int NDIMS>
    118   static Eigen::array<Eigen::DenseIndex, NDIMS> ToIndexArray(
    119       const BCast::Vec& vec) {
    120     return ToIndexArrayType<Eigen::DenseIndex, NDIMS>(vec);
    121   }
    122 
    123  private:
    124   bool valid_ = true;
    125   Vec x_reshape_;
    126   Vec x_bcast_;
    127   Vec y_reshape_;
    128   Vec y_bcast_;
    129   Vec result_;
    130   Vec output_;
    131   Vec grad_x_reduce_idx_;
    132   Vec grad_y_reduce_idx_;
    133 
    134   static void Reverse(Vec* shape);
    135 
    136   TF_DISALLOW_COPY_AND_ASSIGN(BCast);
    137 };
    138 
    139 }  // end namespace tensorflow
    140 
    141 #endif  // TENSORFLOW_CORE_UTIL_BCAST_H_
    142