// tensorflow/core/util/bcast.h
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_UTIL_BCAST_H_
     17 #define TENSORFLOW_UTIL_BCAST_H_
     18 
     19 #include <algorithm>
     20 
     21 #include "tensorflow/core/framework/tensor_shape.h"
     22 #include "tensorflow/core/lib/gtl/inlined_vector.h"
     23 #include "tensorflow/core/platform/macros.h"
     24 #include "tensorflow/core/platform/types.h"
     25 
     26 namespace tensorflow {
     27 
     28 // BCast is a helper for broadcasting binary tensor operation.
     29 // TensorFlow's broadcasting rule follows that of numpy (See
     30 // http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
     31 //
     32 // The rule has the following properties:
     33 //
     34 //   1. suffix matching: the rule starts with the right-most
     35 //      dimension, and works towards the left-most dimension. Since
     36 //      TensorFlow is row-major, the right-most dimension (the last
     37 //      element in the shape of a tensor) is the inner-most, a.k.a.
     38 //      the fastest changing, dimension.
     39 //
     40 //   2. Two dimensions are compatible for broadcasting if both are the
     41 //      same or either is 1.
     42 //
     43 // BCast takes the shape of two tensors and computes a few vectors of
     44 // int32 that are useful for the caller to reshape the tensors, apply
     45 // the right broadcasts to them, compute the broadcasted operation,
     46 // and possibly the gradients. In a nutshell, the caller is expected
     47 // to compute the broadcasted operation as following:
     48 //
     49 //   BCast b(x.shape(), y.shape());
     50 //   output = x.reshape(b.x_reshape()).broadcast(b.x_bcast())
     51 //            _op_
     52 //            y.reshape(b.y_reshape()).broadcast(b.y_bcast())
     53 //
     54 // For the gradient computation,
     55 //   grad_x = sum(grad * backprop_x(x, y), grad_x_reduce_idx)
     56 //            .reshape(x.shape())
     57 //   grad_y = sum(grad * backprop_y(x, y), grad_y_reduce_idx)
     58 //            .reshape(y.shape())
     59 // backprop_x and backprop_y are functionals of the binary function "op",
     60 // e.g.,
     61 //   for +, backprop_x(x, y) = backprop_y(x, y) = 1;
     62 //   for *, backprop_x(x, y) =  y, backprop_y(x, y) = x;
     63 //   for /, backprop_x(x, y) = 1/y, backprop_y(x, y) = -x/y^2;
     64 //
     65 // The multiplication in the grad * backprop_x itself is also
     66 // broadcasting following the same rule.
     67 //
     68 // TODO(zhifengc): Adds support for n-ary (n >= 2).
     69 class BCast {
     70  public:
     71   // A vector of int64 representing the shape of tensor. The 0-th
     72   // element is the outer-most dimension and the last element is the
     73   // inner-most dimension. Note that we do not use TensorShape since
     74   // it's more convenient to manipulate Vec directly for this module.
     75   typedef gtl::InlinedVector<int64, 4> Vec;
     76 
     77   // Constructs all helper shapes, following the aforementioned rules.
     78   //
     79   // If "fewer_dims_optimization" is set to true (the default), the
     80   // implementation tries to reduce intermediate dimensions needed to be more
     81   // efficient.  This is transparent to the caller.
     82   //
     83   // If false, all intermediate shapes (except for grad_{x,y}_reduce_idx()) have
     84   // the same number of dimensions as the larger of the two inputs.
     85   BCast(const Vec& x, const Vec& y, const bool fewer_dims_optimization = true);
     86   ~BCast() {}
     87 
     88   // Returns true iff two operands are compatible according to the
     89   // broadcasting rule.
     90   bool IsValid() const { return valid_; }
     91 
     92   // If and only if IsValid(), the following fields can be used in
     93   // implementing a broadcasted binary tensor operation according to
     94   // the broadcasting rule.
     95   const Vec& x_reshape() const { return x_reshape_; }
     96   const Vec& x_bcast() const { return x_bcast_; }
     97   const Vec& y_reshape() const { return y_reshape_; }
     98   const Vec& y_bcast() const { return y_bcast_; }
     99   const Vec& result_shape() const { return result_; }
    100   const Vec& output_shape() const { return output_; }
    101   const Vec& grad_x_reduce_idx() const { return grad_x_reduce_idx_; }
    102   const Vec& grad_y_reduce_idx() const { return grad_y_reduce_idx_; }
    103 
    104   // Static helpers.
    105   static Vec FromShape(const TensorShape& shape);
    106   static TensorShape ToShape(const BCast::Vec& vec);
    107 
    108   template <int NDIMS>
    109   static Eigen::array<Eigen::DenseIndex, NDIMS> ToIndexArray(
    110       const BCast::Vec& vec) {
    111     CHECK_EQ(vec.size(), NDIMS);
    112     Eigen::array<Eigen::DenseIndex, NDIMS> ret;
    113     for (int i = 0; i < NDIMS; ++i) ret[i] = vec[i];
    114     return ret;
    115   }
    116 
    117  private:
    118   bool valid_ = true;
    119   Vec x_reshape_;
    120   Vec x_bcast_;
    121   Vec y_reshape_;
    122   Vec y_bcast_;
    123   Vec result_;
    124   Vec output_;
    125   Vec grad_x_reduce_idx_;
    126   Vec grad_y_reduce_idx_;
    127 
    128   static void Reverse(Vec* shape);
    129 
    130   TF_DISALLOW_COPY_AND_ASSIGN(BCast);
    131 };
    132 
    133 }  // end namespace tensorflow
    134 
    135 #endif  // TENSORFLOW_UTIL_BCAST_H_
    136