Home | History | Annotate | Download | only in xla
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_REFERENCE_UTIL_H_
     17 #define TENSORFLOW_COMPILER_XLA_REFERENCE_UTIL_H_
     18 
     19 #include <array>
     20 #include <functional>
     21 #include <memory>
     22 #include <utility>
     23 #include <vector>
     24 
     25 #include "tensorflow/compiler/xla/array2d.h"
     26 #include "tensorflow/compiler/xla/array3d.h"
     27 #include "tensorflow/compiler/xla/array4d.h"
     28 #include "tensorflow/compiler/xla/client/padding.h"
     29 #include "tensorflow/compiler/xla/ptr_util.h"
     30 #include "tensorflow/compiler/xla/util.h"
     31 #include "tensorflow/compiler/xla/xla_data.pb.h"
     32 #include "tensorflow/core/lib/gtl/array_slice.h"
     33 #include "tensorflow/core/platform/macros.h"
     34 #include "tensorflow/core/platform/types.h"
     35 
     36 namespace xla {
     37 
     38 // Utility class for reference implementations of linear algebra routines.
     39 class ReferenceUtil {
     40  public:
     41   // Returns the result of a transpose operation on the input matrix.
     42   static std::unique_ptr<Array2D<float>> TransposeArray2D(
     43       const Array2D<float>& operand);
     44 
     45   // Returns the result of a matrix multiply `lhs x rhs`.
     46   static std::unique_ptr<Array2D<float>> MatmulArray2D(
     47       const Array2D<float>& lhs, const Array2D<float>& rhs);
     48   static std::unique_ptr<Array2D<double>> MatmulArray2D(
     49       const Array2D<double>& lhs, const Array2D<double>& rhs);
     50 
     51   // Converts the input operand to use f64 values instead of f32 values.
     52   static std::unique_ptr<Array2D<double>> Array2DF32ToF64(
     53       const Array2D<float>& input);
     54 
     55   // Returns the result of a convolution `lhs <conv> rhs`, with the default
     56   // convolution dimension numbers returned from
     57   // ComputationBuilder::CreateDefaultConvDimensionNumbers().
     58   static std::unique_ptr<Array4D<float>> ConvArray4D(
     59       const Array4D<float>& lhs, const Array4D<float>& rhs,
     60       std::pair<int64, int64> kernel_stride, Padding padding);
     61 
     62   // Returns the result of a convolution `lhs <conv> rhs`, with the given
     63   // convolution dimension numbers.
     64   static std::unique_ptr<Array4D<float>> ConvArray4DGeneralDimensions(
     65       const Array4D<float>& lhs, const Array4D<float>& rhs,
     66       std::pair<int64, int64> kernel_stride, Padding padding,
     67       ConvolutionDimensionNumbers dimension_numbers);
     68 
     69   // Returns the result of a convolution `lhs <conv> rhs`, with the given
     70   // dilation factors.
     71   static std::unique_ptr<Array4D<float>> ConvArray4DGeneralDimensionsDilated(
     72       const Array4D<float>& lhs, const Array4D<float>& rhs,
     73       std::pair<int64, int64> kernel_stride, Padding padding,
     74       std::pair<int64, int64> lhs_dilation,
     75       std::pair<int64, int64> rhs_dilation, ConvolutionDimensionNumbers dnums);
     76 
     77   // Returns the result of a convolution `lhs <conv> rhs`, with the default
     78   // convolution dimension numbers returned from
     79   // ComputationBuilder::CreateDefaultConvDimensionNumbers().
     80   static std::unique_ptr<Array3D<float>> ConvArray3D(const Array3D<float>& lhs,
     81                                                      const Array3D<float>& rhs,
     82                                                      int64 kernel_stride,
     83                                                      Padding padding);
     84 
     85   // Returns the result of a convolution `lhs <conv> rhs`.
     86   static std::unique_ptr<Array3D<float>> ConvArray3DGeneralDimensionsDilated(
     87       const Array3D<float>& lhs, const Array3D<float>& rhs, int64 kernel_stride,
     88       Padding padding, int64 lhs_dilation, int64 rhs_dilation,
     89       const ConvolutionDimensionNumbers& dnums);
     90 
  // Returns the result of a separable convolution with the given parameters.
  // kernel_stride and padding apply to the depthwise convolution during
  // the separable convolution. pointwise_weights.depth() must be equal to
  // input.depth() * depthwise_weights.planes().
     95   static std::unique_ptr<Array4D<float>> SeparableConvArray4D(
     96       const Array4D<float>& input, const Array4D<float>& depthwise_weights,
     97       const Array4D<float>& pointwise_weights,
     98       std::pair<int64, int64> kernel_stride, Padding padding);
     99 
    100   // Returns the result of reducing a matrix to a column vector. init is the
    101   // initial value for the reduce operation, and reduce_function is the function
    102   // to apply for each reduction step.
    103   static std::unique_ptr<std::vector<float>> ReduceToColArray2D(
    104       const Array2D<float>& matrix, float init,
    105       const std::function<float(float, float)>& reduce_function);
    106 
    107   // Returns the result of reducing a matrix to a row vector. init is the
    108   // initial value for the reduce operation, and reduce_function is the function
    109   // to apply for each reduction step.
    110   static std::unique_ptr<std::vector<float>> ReduceToRowArray2D(
    111       const Array2D<float>& matrix, float init,
    112       const std::function<float(float, float)>& reduce_function);
    113 
    114   // Performs a R2=>R1 reduction by reducing away the dimension specified in
    115   // 'dimension_to_reduce'.
    116   template <typename T>
    117   static std::vector<T> ReduceR2ToR1(const Array2D<T>& input,
    118                                      int dimension_to_reduce, T init,
    119                                      const std::function<T(T, T)>& freduce) {
    120     std::vector<T> result(dimension_to_reduce == 0 ? input.n2() : input.n1(),
    121                           init);
    122     for (int i0 = 0; i0 < input.n1(); ++i0) {
    123       for (int i1 = 0; i1 < input.n2(); ++i1) {
    124         int output = dimension_to_reduce == 0 ? i1 : i0;
    125         result[output] = freduce(result[output], input(i0, i1));
    126       }
    127     }
    128     return result;
    129   }
    130 
    131   // Returns the result of reducing the 4D array to a vector, reducing away
    132   // the dimensions specified in dims.
    133   static std::vector<float> Reduce4DTo1D(
    134       const Array4D<float>& array, float init,
    135       tensorflow::gtl::ArraySlice<int64> dims,
    136       const std::function<float(float, float)>& reduce_function);
    137 
    138   // Broadcast 1D dimension to 4D, from the dimension `broadcast_from_dim`.
    139   static std::unique_ptr<Array4D<float>> Broadcast1DTo4D(
    140       const std::vector<float>& array, const std::vector<int64>& bounds,
    141       int64 broadcast_from_dim);
    142 
    143   // Returns the result of reducing the 3D array to a 2D array, reducing away
    144   // the dimensions specified in dims.
    145   static std::unique_ptr<Array2D<float>> Reduce3DTo2D(
    146       const Array3D<float>& array, float init,
    147       tensorflow::gtl::ArraySlice<int64> dims,
    148       const std::function<float(float, float)>& reduce_function);
    149 
    150   // Applies map_function to each element in the input (2D array) and returns
    151   // the result.
    152   static std::unique_ptr<Array2D<float>> MapArray2D(
    153       const Array2D<float>& matrix,
    154       const std::function<float(float)>& map_function);
    155 
  // Applies map_function to each pair of corresponding elements in the two
  // input arrays and returns the result.
    158   static std::unique_ptr<Array2D<float>> MapArray2D(
    159       const Array2D<float>& lhs, const Array2D<float>& rhs,
    160       const std::function<float(float, float)>& map_function);
    161 
    162   // Number of windows in a given dimension. Calculation taken from
    163   // xla::MakePadding().
    164   static int64 WindowCount(int64 unpadded_width, int64 window_len, int64 stride,
    165                            Padding padding);
    166 
    167   // Windowed reductions with Add as the function to apply.
    168   static std::unique_ptr<std::vector<float>> ReduceWindow1DAdd(
    169       const tensorflow::gtl::ArraySlice<float>& operand, float init,
    170       const tensorflow::gtl::ArraySlice<int64>& window,
    171       const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding);
    172   static std::unique_ptr<Array2D<float>> ReduceWindow2DAdd(
    173       const Array2D<float>& operand, float init,
    174       const tensorflow::gtl::ArraySlice<int64>& window,
    175       const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding);
    176   static std::unique_ptr<Array3D<float>> ReduceWindow3DAdd(
    177       const Array3D<float>& operand, float init,
    178       const tensorflow::gtl::ArraySlice<int64>& window,
    179       const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding);
    180   static std::unique_ptr<Array4D<float>> ReduceWindow4DAdd(
    181       const Array4D<float>& operand, float init,
    182       const tensorflow::gtl::ArraySlice<int64>& window,
    183       const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding);
    184 
    185   // Windowed reductions with a generic reduce function.
    186   static std::unique_ptr<std::vector<float>> ReduceWindow1DGeneric(
    187       const tensorflow::gtl::ArraySlice<float>& operand, float init,
    188       const std::function<float(float, float)>& reduce_func,
    189       const tensorflow::gtl::ArraySlice<int64>& window,
    190       const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding);
    191   static std::unique_ptr<std::vector<float>> ReduceWindow1DGeneric(
    192       const tensorflow::gtl::ArraySlice<float>& operand, float init,
    193       const std::function<float(float, float)>& reduce_func,
    194       const tensorflow::gtl::ArraySlice<int64>& window,
    195       const tensorflow::gtl::ArraySlice<int64>& stride,
    196       const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding);
    197   static std::unique_ptr<Array4D<float>> ReduceWindow4DGeneric(
    198       const Array4D<float>& operand, float init,
    199       const std::function<float(float, float)>& reduce_func,
    200       const tensorflow::gtl::ArraySlice<int64>& window,
    201       const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding);
    202   // With arbitrary padding.
    203   static std::unique_ptr<Array4D<float>> ReduceWindow4DGeneric(
    204       const Array4D<float>& operand, float init,
    205       const std::function<float(float, float)>& reduce_func,
    206       const tensorflow::gtl::ArraySlice<int64>& window,
    207       const tensorflow::gtl::ArraySlice<int64>& stride,
    208       const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding);
    209 
    210   // Batch normalize data.
    211   static std::unique_ptr<Array4D<float>> BatchNorm4D(
    212       const Array4D<float>& input, const Array4D<float>& mean,
    213       const Array4D<float>& var, const Array4D<float>& scale,
    214       const Array4D<float>& offset, float epsilon);
    215 
  // Performs select-and-scatter with greater-than-or-equal as the select
  // function and plus as the scatter function. same_padding selects whether
  // Same padding is applied.
    218   static std::unique_ptr<Array4D<float>> SelectAndScatter4DGePlus(
    219       const Array4D<float>& operand, const Array4D<float>& source, float init,
    220       const tensorflow::gtl::ArraySlice<int64>& window,
    221       const tensorflow::gtl::ArraySlice<int64>& stride, bool same_padding);
    222 
    223   // Concatenates the lhs and rhs arrays along the concatenate_dimension.
    224   // E.g. if concatenate_dimension is 0, the "n1"/height dimension is
    225   // concatenated, so the arrays are stacked on top of each other.
    226   template <typename T>
    227   static std::unique_ptr<Array2D<T>> Concat2D(const Array2D<T>& lhs,
    228                                               const Array2D<T>& rhs,
    229                                               int concatenate_dimension) {
    230     CHECK(0 <= concatenate_dimension && concatenate_dimension < 2);
    231     auto result = MakeUnique<Array2D<T>>(
    232         concatenate_dimension == 0 ? lhs.n1() + rhs.n1() : lhs.n1(),
    233         concatenate_dimension == 1 ? lhs.n2() + rhs.n2() : lhs.n2());
    234     for (int64 i0 = 0; i0 < result->n1(); ++i0) {
    235       for (int64 i1 = 0; i1 < result->n2(); ++i1) {
    236         // If we exceed the bounds of the LHS, draw from the RHS, where the
    237         // result index is adjusted by the number of values present in the LHS.
    238         (*result)(i0, i1) = i0 < lhs.n1() && i1 < lhs.n2()
    239                                 ? lhs(i0, i1)
    240                                 : rhs(i0 >= lhs.n1() ? i0 - lhs.n1() : i0,
    241                                       i1 >= lhs.n2() ? i1 - lhs.n2() : i1);
    242       }
    243     }
    244     return result;
    245   }
    246 
    247   // Concatenates the lhs and rhs 3D arrays along the concatenate_dimension. lhs
    248   // and rhs must have the same dimensions except for the concatenate dimension.
    249   template <typename T>
    250   static std::unique_ptr<Array3D<T>> Concat3D(const Array3D<T>& lhs,
    251                                               const Array3D<T>& rhs,
    252                                               int concatenate_dimension) {
    253     CHECK(0 <= concatenate_dimension && concatenate_dimension < 3);
    254     std::vector<int64> lhs_dims = {lhs.n1(), lhs.n2(), lhs.n3()};
    255     std::vector<int64> rhs_dims = {rhs.n1(), rhs.n2(), rhs.n3()};
    256     std::vector<int64> out_dims = {rhs.n1(), rhs.n2(), rhs.n3()};
    257     for (int i = 0; i < 3; ++i) {
    258       if (i != concatenate_dimension) {
    259         out_dims[i] = lhs_dims[i];
    260         CHECK_EQ(lhs_dims[i], rhs_dims[i]);
    261       } else {
    262         out_dims[i] = lhs_dims[i] + rhs_dims[i];
    263       }
    264     }
    265     auto result = MakeUnique<Array3D<T>>(out_dims[0], out_dims[1], out_dims[2]);
    266     for (int64 i0 = 0; i0 < result->n1(); ++i0) {
    267       for (int64 i1 = 0; i1 < result->n2(); ++i1) {
    268         for (int64 i2 = 0; i2 < result->n3(); ++i2) {
    269           (*result)(i0, i1, i2) =
    270               i0 < lhs.n1() && i1 < lhs.n2() && i2 < lhs.n3()
    271                   ? lhs(i0, i1, i2)
    272                   : rhs(i0 >= lhs.n1() ? i0 - lhs.n1() : i0,
    273                         i1 >= lhs.n2() ? i1 - lhs.n2() : i1,
    274                         i2 >= lhs.n3() ? i2 - lhs.n3() : i2);
    275         }
    276       }
    277     }
    278     return result;
    279   }
    280 
    281   // Concatenates the lhs and rhs 4D arrays along the concatenate_dimension. lhs
    282   // and rhs must have the same dimensions except for the concatenate dimension.
    283   template <typename T>
    284   static std::unique_ptr<Array4D<T>> Concat4D(const Array4D<T>& lhs,
    285                                               const Array4D<T>& rhs,
    286                                               int concatenate_dimension) {
    287     CHECK(0 <= concatenate_dimension && concatenate_dimension < 4);
    288     std::vector<int64> lhs_dims = {lhs.n1(), lhs.n2(), lhs.n3(), lhs.n4()};
    289     std::vector<int64> rhs_dims = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()};
    290     std::vector<int64> out_dims = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()};
    291     for (int i = 0; i < 4; ++i) {
    292       if (i != concatenate_dimension) {
    293         out_dims[i] = lhs_dims[i];
    294         CHECK_EQ(lhs_dims[i], rhs_dims[i]);
    295       } else {
    296         out_dims[i] = lhs_dims[i] + rhs_dims[i];
    297       }
    298     }
    299     auto result = MakeUnique<Array4D<T>>(out_dims[0], out_dims[1], out_dims[2],
    300                                          out_dims[3]);
    301     for (int64 i0 = 0; i0 < result->n1(); ++i0) {
    302       for (int64 i1 = 0; i1 < result->n2(); ++i1) {
    303         for (int64 i2 = 0; i2 < result->n3(); ++i2) {
    304           for (int64 i3 = 0; i3 < result->n4(); ++i3) {
    305             (*result)(i0, i1, i2, i3) =
    306                 i0 < lhs.n1() && i1 < lhs.n2() && i2 < lhs.n3() && i3 < lhs.n4()
    307                     ? lhs(i0, i1, i2, i3)
    308                     : rhs(i0 >= lhs.n1() ? i0 - lhs.n1() : i0,
    309                           i1 >= lhs.n2() ? i1 - lhs.n2() : i1,
    310                           i2 >= lhs.n3() ? i2 - lhs.n3() : i2,
    311                           i3 >= lhs.n4() ? i3 - lhs.n4() : i3);
    312           }
    313         }
    314       }
    315     }
    316     return result;
    317   }
    318 
    319   // Slices with modulo-wrapping.
    320   template <typename T>
    321   static std::vector<T> ModSlice1D(const tensorflow::gtl::ArraySlice<T>& input,
    322                                    int64 start, int64 size) {
    323     std::vector<T> result;
    324     for (int64 i = 0; i < size; ++i) {
    325       result.push_back(input[(start + i) % input.size()]);
    326     }
    327     return result;
    328   }
    329 
    330   // Slices the input array given starting indices, limit indices, and strides
    331   // in each dimension.
    332   template <typename T>
    333   static std::unique_ptr<Array2D<T>> Slice2D(const Array2D<T>& input,
    334                                              std::array<int64, 2> starts,
    335                                              std::array<int64, 2> limits,
    336                                              std::array<int64, 2> strides) {
    337     CHECK_LE(starts[0], input.n1());
    338     CHECK_LE(starts[1], input.n2());
    339     CHECK_LE(limits[0], input.n1());
    340     CHECK_LE(limits[1], input.n2());
    341     CHECK_GE(strides[0], 1);
    342     CHECK_GE(strides[1], 1);
    343     auto result =
    344         MakeUnique<Array2D<T>>(CeilOfRatio(limits[0] - starts[0], strides[0]),
    345                                CeilOfRatio(limits[1] - starts[1], strides[1]));
    346     for (int64 i0 = 0; i0 < result->n1(); ++i0) {
    347       for (int64 i1 = 0; i1 < result->n2(); ++i1) {
    348         (*result)(i0, i1) =
    349             input(starts[0] + i0 * strides[0], starts[1] + i1 * strides[1]);
    350       }
    351     }
    352     return result;
    353   }
    354 
    355   template <typename T>
    356   static std::unique_ptr<Array3D<T>> Slice3D(const Array3D<T>& input,
    357                                              std::array<int64, 3> starts,
    358                                              std::array<int64, 3> limits,
    359                                              std::array<int64, 3> strides) {
    360     CHECK_LE(starts[0], input.n1());
    361     CHECK_LE(starts[1], input.n2());
    362     CHECK_LE(starts[2], input.n3());
    363     CHECK_LE(limits[0], input.n1());
    364     CHECK_LE(limits[1], input.n2());
    365     CHECK_LE(limits[2], input.n3());
    366     CHECK_GE(strides[0], 1);
    367     CHECK_GE(strides[1], 1);
    368     CHECK_GE(strides[2], 1);
    369     auto result =
    370         MakeUnique<Array3D<T>>(CeilOfRatio(limits[0] - starts[0], strides[0]),
    371                                CeilOfRatio(limits[1] - starts[1], strides[1]),
    372                                CeilOfRatio(limits[2] - starts[2], strides[2]));
    373 
    374     for (int64 i0 = 0; i0 < result->n1(); ++i0) {
    375       for (int64 i1 = 0; i1 < result->n2(); ++i1) {
    376         for (int64 i2 = 0; i2 < result->n3(); ++i2) {
    377           (*result)(i0, i1, i2) =
    378               input(starts[0] + i0 * strides[0], starts[1] + i1 * strides[1],
    379                     starts[2] + i2 * strides[2]);
    380         }
    381       }
    382     }
    383     return result;
    384   }
    385 
    386   template <typename T>
    387   static std::unique_ptr<Array4D<T>> Slice4D(const Array4D<T>& input,
    388                                              std::array<int64, 4> starts,
    389                                              std::array<int64, 4> limits,
    390                                              std::array<int64, 4> strides) {
    391     CHECK_LE(starts[0], input.n1());
    392     CHECK_LE(starts[1], input.n2());
    393     CHECK_LE(starts[2], input.n3());
    394     CHECK_LE(starts[3], input.n4());
    395     CHECK_LE(limits[0], input.n1());
    396     CHECK_LE(limits[1], input.n2());
    397     CHECK_LE(limits[2], input.n3());
    398     CHECK_LE(limits[3], input.n4());
    399     CHECK_GE(strides[0], 1);
    400     CHECK_GE(strides[1], 1);
    401     CHECK_GE(strides[2], 1);
    402     CHECK_GE(strides[3], 1);
    403     auto result =
    404         MakeUnique<Array4D<T>>(CeilOfRatio(limits[0] - starts[0], strides[0]),
    405                                CeilOfRatio(limits[1] - starts[1], strides[1]),
    406                                CeilOfRatio(limits[2] - starts[2], strides[2]),
    407                                CeilOfRatio(limits[3] - starts[3], strides[3]));
    408     for (int64 i0 = 0; i0 < result->n1(); ++i0) {
    409       for (int64 i1 = 0; i1 < result->n2(); ++i1) {
    410         for (int64 i2 = 0; i2 < result->n3(); ++i2) {
    411           for (int64 i3 = 0; i3 < result->n4(); ++i3) {
    412             (*result)(i0, i1, i2, i3) =
    413                 input(starts[0] + i0 * strides[0], starts[1] + i1 * strides[1],
    414                       starts[2] + i2 * strides[2], starts[3] + i3 * strides[3]);
    415           }
    416         }
    417       }
    418     }
    419     return result;
    420   }
    421 
    422   // Applies map_function to each element in the input (2D array) and returns
    423   // the result.
    424   // (row, column) index of each element is also provided as arguments to
    425   // map_function.
    426   static std::unique_ptr<Array2D<float>> MapWithIndexArray2D(
    427       const Array2D<float>& matrix,
    428       const std::function<float(float, int64, int64)>& map_function);
    429 
    430   // Applies map_function to each element in the input (4D array) and returns
    431   // the result.
    432   template <typename F>
    433   static std::unique_ptr<Array4D<float>> MapArray4D(const Array4D<float>& input,
    434                                                     F&& map_function) {
    435     return MapWithIndexArray4D(input,
    436                                [&](float value, int64, int64, int64, int64) {
    437                                  return map_function(value);
    438                                });
    439   }
    440 
    441   // Applies map_function to each element in the input (4D array) and returns
    442   // the result.
    443   // (plane, depth, height, width) index of each element is also provided as
    444   // arguments to map_function.
    445   template <typename F>
    446   static std::unique_ptr<Array4D<float>> MapWithIndexArray4D(
    447       const Array4D<float>& input, F&& map_function) {
    448     auto result = MakeUnique<Array4D<float>>(input.planes(), input.depth(),
    449                                              input.height(), input.width());
    450     for (int64 plane = 0; plane < input.planes(); ++plane) {
    451       for (int64 depth = 0; depth < input.depth(); ++depth) {
    452         for (int64 height = 0; height < input.height(); ++height) {
    453           for (int64 width = 0; width < input.width(); ++width) {
    454             (*result)(plane, depth, height, width) =
    455                 map_function(input(plane, depth, height, width), plane, depth,
    456                              height, width);
    457           }
    458         }
    459       }
    460     }
    461     return result;
    462   }
    463 
    464   // Applies map_function to each pair of elements in the input lhs and rhs
    465   // (4D array) and returns the result.
    466   template <typename F>
    467   static std::unique_ptr<Array4D<float>> MapArray4D(const Array4D<float>& lhs,
    468                                                     const Array4D<float>& rhs,
    469                                                     F&& map_function) {
    470     return MapWithIndexArray4D(
    471         lhs, rhs, [&](float lhs, float rhs, int64, int64, int64, int64) {
    472           return map_function(lhs, rhs);
    473         });
    474   }
    475 
    476   // Applies map_function to each pair of element in lhs and rhs (4D array) and
    477   // returns the result.
    478   // (plane, depth, height, width) index of each element is also provided as
    479   // arguments to map_function.
    480   template <typename F>
    481   static std::unique_ptr<Array4D<float>> MapWithIndexArray4D(
    482       const Array4D<float>& lhs, const Array4D<float>& rhs, F&& map_function) {
    483     auto result = MakeUnique<Array4D<float>>(lhs.planes(), lhs.depth(),
    484                                              lhs.height(), lhs.width());
    485     for (int64 plane = 0; plane < lhs.planes(); ++plane) {
    486       for (int64 depth = 0; depth < lhs.depth(); ++depth) {
    487         for (int64 height = 0; height < lhs.height(); ++height) {
    488           for (int64 width = 0; width < lhs.width(); ++width) {
    489             (*result)(plane, depth, height, width) = map_function(
    490                 lhs(plane, depth, height, width),
    491                 rhs(plane, depth, height, width), plane, depth, height, width);
    492           }
    493         }
    494       }
    495     }
    496     return result;
    497   }
    498 
    499   // Returns the result of a 2D pad on an input matrix.
    500   template <typename NativeT>
    501   static std::unique_ptr<Array2D<NativeT>> PadArray2D(
    502       const Array2D<NativeT>& operand, const PaddingConfig& padding,
    503       const NativeT pad) {
    504     int64 in0 = operand.n1();
    505     int64 high_padding0 = padding.dimensions(0).edge_padding_high();
    506     int64 low_padding0 = padding.dimensions(0).edge_padding_low();
    507     int64 interior_padding0 = padding.dimensions(0).interior_padding();
    508     int64 out0 =
    509         in0 + low_padding0 + high_padding0 + (in0 - 1) * interior_padding0;
    510 
    511     int64 in1 = operand.n2();
    512     int64 high_padding1 = padding.dimensions(1).edge_padding_high();
    513     int64 low_padding1 = padding.dimensions(1).edge_padding_low();
    514     int64 interior_padding1 = padding.dimensions(1).interior_padding();
    515     int64 out1 =
    516         in1 + low_padding1 + high_padding1 + (in1 - 1) * interior_padding1;
    517 
    518     auto result = MakeUnique<Array2D<NativeT>>(out0, out1);
    519     result->Fill(pad);
    520     int64 o0 = low_padding0;
    521     for (int64 i0 = 0; i0 < in0; ++i0) {
    522       int64 o1 = low_padding1;
    523       for (int64 i1 = 0; i1 < in1; ++i1) {
    524         if (o0 >= 0 && o1 >= 0 && o0 < out0 && o1 < out1) {
    525           (*result)(o0, o1) = operand(i0, i1);
    526         }
    527         o1 += interior_padding1 + 1;
    528       }
    529       o0 += interior_padding0 + 1;
    530     }
    531     return result;
    532   }
    533 
    534   // Returns the result of a 3D pad on an input matrix.
    535   template <typename NativeT>
    536   static Array3D<NativeT> PadArray3D(const Array3D<NativeT>& operand,
    537                                      const PaddingConfig& padding,
    538                                      const NativeT pad) {
    539     CHECK_EQ(padding.dimensions_size(), 3);
    540 
    541     const std::vector<int64> input_bounds = {operand.n1(), operand.n2(),
    542                                              operand.n3()};
    543     std::vector<int64> pad_low(3);
    544     std::vector<int64> pad_high(3);
    545     std::vector<int64> pad_interior(3);
    546     std::vector<int64> output_bounds(3);
    547     for (int64 i = 0; i < 3; ++i) {
    548       pad_low[i] = padding.dimensions(i).edge_padding_low();
    549       pad_high[i] = padding.dimensions(i).edge_padding_high();
    550       CHECK_LE(0, pad_low[i]);
    551       CHECK_LE(0, pad_high[i]);
    552       CHECK_LE(0, padding.dimensions(i).interior_padding())
    553           << "not implemented";
    554       pad_interior[i] = padding.dimensions(i).interior_padding();
    555 
    556       output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] +
    557                          (input_bounds[i] - 1) * pad_interior[i];
    558     }
    559 
    560     Array3D<NativeT> result(output_bounds[0], output_bounds[1],
    561                             output_bounds[2]);
    562     std::vector<int> indices = {0, 0, 0};
    563     for (indices[0] = 0; indices[0] < output_bounds[0]; ++indices[0]) {
    564       for (indices[1] = 0; indices[1] < output_bounds[1]; ++indices[1]) {
    565         for (indices[2] = 0; indices[2] < output_bounds[2]; ++indices[2]) {
    566           NativeT* value = &result(indices[0], indices[1], indices[2]);
    567           bool value_padded = false;
    568           for (int i = 0; i < 3; ++i) {
    569             bool in_low_padding = indices[i] < pad_low[i];
    570             bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i];
    571             if (in_low_padding || in_high_padding) {
    572               *value = pad;
    573               value_padded = true;
    574             }
    575             if (pad_interior[i] &&
    576                 (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) {
    577               *value = pad;
    578               value_padded = true;
    579             }
    580           }
    581           if (value_padded) {
    582             continue;
    583           }
    584           *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1),
    585                            (indices[1] - pad_low[1]) / (pad_interior[1] + 1),
    586                            (indices[2] - pad_low[2]) / (pad_interior[2] + 1));
    587         }
    588       }
    589     }
    590     return result;
    591   }
    592 
    593   // Returns the result of a 4D pad on an input array.
    594   template <typename NativeT>
    595   static Array4D<NativeT> PadArray4D(const Array4D<NativeT>& operand,
    596                                      const PaddingConfig& padding,
    597                                      const NativeT pad) {
    598     CHECK_EQ(padding.dimensions_size(), 4);
    599 
    600     const std::vector<int64> input_bounds = {operand.n1(), operand.n2(),
    601                                              operand.n3(), operand.n4()};
    602     std::vector<int64> pad_low(4);
    603     std::vector<int64> pad_high(4);
    604     std::vector<int64> pad_interior(4);
    605     std::vector<int64> output_bounds(4);
    606     for (int64 i = 0; i < 4; ++i) {
    607       pad_low[i] = padding.dimensions(i).edge_padding_low();
    608       pad_high[i] = padding.dimensions(i).edge_padding_high();
    609       CHECK_LE(0, padding.dimensions(i).interior_padding())
    610           << "not implemented";
    611       pad_interior[i] = padding.dimensions(i).interior_padding();
    612 
    613       output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] +
    614                          (input_bounds[i] - 1) * pad_interior[i];
    615     }
    616 
    617     Array4D<NativeT> result(output_bounds[0], output_bounds[1],
    618                             output_bounds[2], output_bounds[3]);
    619     result.Each(
    620         [&](tensorflow::gtl::ArraySlice<int64> indices, NativeT* value) {
    621           for (int i = 0; i < 4; ++i) {
    622             bool in_low_padding = indices[i] < pad_low[i];
    623             bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i];
    624             if (in_low_padding || in_high_padding) {
    625               *value = pad;
    626               return;
    627             }
    628             if (pad_interior[i] &&
    629                 (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) {
    630               *value = pad;
    631               return;
    632             }
    633           }
    634           *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1),
    635                            (indices[1] - pad_low[1]) / (pad_interior[1] + 1),
    636                            (indices[2] - pad_low[2]) / (pad_interior[2] + 1),
    637                            (indices[3] - pad_low[3]) / (pad_interior[3] + 1));
    638         });
    639     return result;
    640   }
    641 
    642   // ApplyElementwise2D(f, x, y, ...) returns the Array2D formed by running
    643   // f(x[i], y[i], ...) for each array element in the Array2Ds x, y, ....
    644   //
    645   // The given arrays must have the same size and element type, and the return
    646   // type of f must be implicitly convertible to the arrays' element type.
    647   //
    648   // Example usage:
    649   //
    650   //   Array2D<float> x, y, z = ...;
    651   //   std::unique_ptr<Array2D> result = ReferenceUtil::ApplyElementwise2D(
    652   //     [](float a, float b, float c) { return a * b + c; }, x, y, z);
    653   //
    654   template <typename F, typename T1, typename... Ts>
    655   static std::unique_ptr<Array2D<T1>> ApplyElementwise2D(
    656       F&& f, const Array2D<T1>& array1, const Array2D<Ts>&... arrays) {
    657     AssertSameSize2D(array1, arrays...);
    658     auto result = MakeUnique<Array2D<T1>>(array1.n1(), array1.n2());
    659     for (int64 i = 0; i < array1.n1(); ++i) {
    660       for (int64 j = 0; j < array1.n2(); ++j) {
    661         (*result)(i, j) = f(array1(i, j), arrays(i, j)...);
    662       }
    663     }
    664     return result;
    665   }
    666 
    667  private:
    668   template <typename T1, typename T2, typename... Ts>
    669   static void AssertSameSize2D(const Array2D<T1>& array1,
    670                                const Array2D<T2>& array2,
    671                                const Array2D<Ts>&... arrays) {
    672     static_assert(std::is_same<T1, T2>::value, "Args must be same type.");
    673     CHECK_EQ(array1.n1(), array2.n1());
    674     CHECK_EQ(array1.n2(), array2.n2());
    675     AssertSameSize2D(array2, arrays...);
    676   }
    677 
    // Terminates the AssertSameSize2D recursion: with a single array left
    // there is nothing to compare it against, so no check is performed.
    template <typename LastArray>
    static void AssertSameSize2D(const LastArray& array1) {
      // Intentionally empty.
    }
    681 
    682   TF_DISALLOW_COPY_AND_ASSIGN(ReferenceUtil);
    683 };
    684 
    685 }  // namespace xla
    686 
    687 #endif  // TENSORFLOW_COMPILER_XLA_REFERENCE_UTIL_H_
    688