Home | History | Annotate | Download | only in xla
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/reference_util.h"
     17 
     18 #include <array>
     19 #include <utility>
     20 
     21 #include "tensorflow/compiler/xla/client/computation_builder.h"
     22 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
     23 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
     24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     25 #include "tensorflow/compiler/xla/service/shape_inference.h"
     26 #include "tensorflow/compiler/xla/window_util.h"
     27 #include "tensorflow/compiler/xla/xla_data.pb.h"
     28 #include "tensorflow/core/lib/math/math_util.h"
     29 #include "tensorflow/core/platform/logging.h"
     30 
     31 namespace xla {
     32 
     33 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::TransposeArray2D(
     34     const Array2D<float>& operand) {
     35   auto result = MakeUnique<Array2D<float>>(operand.width(), operand.height());
     36   for (int64 w = 0; w < operand.width(); ++w) {
     37     for (int64 h = 0; h < operand.height(); ++h) {
     38       (*result)(w, h) = operand(h, w);
     39     }
     40   }
     41 
     42   return result;
     43 }
     44 
/* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MatmulArray2D(
    const Array2D<float>& lhs, const Array2D<float>& rhs) {
  // Dense matrix product lhs * rhs; the contracted dimension is
  // lhs.width() == rhs.height().
  CHECK_EQ(lhs.width(), rhs.height());
  int m = lhs.height();  // Rows of the result.
  int n = rhs.width();   // Columns of the result.
  int k = lhs.width();   // Contracted (inner) dimension.
  auto result = MakeUnique<Array2D<float>>(m, n);
  // Because Eigen is a header-oriented library, make sure that the Eigen code
  // is the same as the code used by the CPU backend (otherwise the linker will
  // randomly pick *some* definition).
  // NOTE: the operands are deliberately passed swapped (rhs before lhs) with
  // dimensions in (n, m, k) order — presumably to express the row-major
  // product via the runtime's column-major convention. Do not "fix" the
  // argument order without verifying against the runtime's contract.
  __xla_cpu_runtime_EigenSingleThreadedMatMulF32(
      /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m,
      k,
      /*transpose_lhs=*/0,
      /*transpose_rhs=*/0);
  return result;
}
     62 
/* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::MatmulArray2D(
    const Array2D<double>& lhs, const Array2D<double>& rhs) {
  // Double-precision twin of the float overload above; same contract.
  CHECK_EQ(lhs.width(), rhs.height());
  int m = lhs.height();  // Rows of the result.
  int n = rhs.width();   // Columns of the result.
  int k = lhs.width();   // Contracted (inner) dimension.
  auto result = MakeUnique<Array2D<double>>(m, n);
  // Because Eigen is a header-oriented library, make sure that the Eigen code
  // is the same as the code used by the CPU backend (otherwise the linker will
  // randomly pick *some* definition).
  // NOTE: operands swapped (rhs first) with dims (n, m, k) — presumably the
  // column-major-runtime trick; keep the order as-is.
  __xla_cpu_runtime_EigenSingleThreadedMatMulF64(
      /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m,
      k,
      /*transpose_lhs=*/0,
      /*transpose_rhs=*/0);
  return result;
}
     80 
     81 /* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::Array2DF32ToF64(
     82     const Array2D<float>& input) {
     83   auto result = MakeUnique<Array2D<double>>(input.height(), input.width());
     84   for (int64 rowno = 0; rowno < input.height(); ++rowno) {
     85     for (int64 colno = 0; colno < input.height(); ++colno) {
     86       (*result)(rowno, colno) = input(rowno, colno);
     87     }
     88   }
     89   return result;
     90 }
     91 
     92 /*  static */ std::unique_ptr<Array3D<float>> ReferenceUtil::ConvArray3D(
     93     const Array3D<float>& lhs, const Array3D<float>& rhs, int64 kernel_stride,
     94     Padding padding) {
     95   return ConvArray3DGeneralDimensionsDilated(
     96       lhs, rhs, kernel_stride, padding, 1, 1,
     97       ComputationBuilder::CreateDefaultConvDimensionNumbers(1));
     98 }
     99 
/*static*/ std::unique_ptr<Array3D<float>>
ReferenceUtil::ConvArray3DGeneralDimensionsDilated(
    const Array3D<float>& lhs, const Array3D<float>& rhs, int64 kernel_stride,
    Padding padding, int64 lhs_dilation, int64 rhs_dilation,
    const ConvolutionDimensionNumbers& dnums) {
  // 1D (single-spatial-dimension) dilated convolution. Implemented by
  // embedding the 3D operands into 4D and delegating to the 4D reference
  // convolution.
  CHECK_EQ(dnums.input_spatial_dimensions_size(), 1);
  CHECK_EQ(dnums.kernel_spatial_dimensions_size(), 1);
  CHECK_EQ(dnums.output_spatial_dimensions_size(), 1);
  // Reuse the code for Array4D-convolution by extending the 3D input into a 4D
  // array by adding a fourth dummy dimension of size 1 without stride, padding
  // and dilation.
  Array4D<float> a4dlhs(lhs.n1(), lhs.n2(), lhs.n3(), 1);
  a4dlhs.Each(
      [&](tensorflow::gtl::ArraySlice<int64> indices, float* value_ptr) {
        CHECK_EQ(indices[3], 0);  // Trailing dummy dimension has extent 1.
        *value_ptr = lhs.operator()(indices[0], indices[1], indices[2]);
      });
  Array4D<float> a4drhs(rhs.n1(), rhs.n2(), rhs.n3(), 1);
  a4drhs.Each(
      [&](tensorflow::gtl::ArraySlice<int64> indices, float* value_ptr) {
        CHECK_EQ(indices[3], 0);
        *value_ptr = rhs.operator()(indices[0], indices[1], indices[2]);
      });
  // Add a second dummy spatial dimensions.
  ConvolutionDimensionNumbers dnums2d = dnums;
  dnums2d.add_input_spatial_dimensions(3);
  dnums2d.add_kernel_spatial_dimensions(3);
  dnums2d.add_output_spatial_dimensions(3);
  // The dummy spatial dimension gets stride 1 and unit dilation, so it is a
  // no-op in the 4D convolution.
  std::unique_ptr<Array4D<float>> convr4 = ConvArray4DGeneralDimensionsDilated(
      a4dlhs, a4drhs, {kernel_stride, 1}, padding, {lhs_dilation, 1},
      {rhs_dilation, 1}, dnums2d);

  // Project the 4D result (trailing extent-1 dimension) back down to 3D.
  auto convr3 = MakeUnique<Array3D<float>>(convr4->planes(), convr4->depth(),
                                           convr4->height());
  convr4->Each(
      [&](tensorflow::gtl::ArraySlice<int64> indices, float* value_ptr) {
        CHECK_EQ(indices[3], 0);
        convr3->operator()(indices[0], indices[1], indices[2]) = *value_ptr;
      });
  return convr3;
}
    141 
    142 /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ConvArray4D(
    143     const Array4D<float>& lhs, const Array4D<float>& rhs,
    144     std::pair<int64, int64> kernel_stride, Padding padding) {
    145   return ConvArray4DGeneralDimensions(
    146       lhs, rhs, kernel_stride, padding,
    147       ComputationBuilder::CreateDefaultConvDimensionNumbers());
    148 }
    149 
/* static */ std::unique_ptr<Array4D<float>>
ReferenceUtil::SeparableConvArray4D(const Array4D<float>& input,
                                    const Array4D<float>& depthwise_weights,
                                    const Array4D<float>& pointwise_weights,
                                    std::pair<int64, int64> kernel_stride,
                                    Padding padding) {
  // Reference depthwise-separable convolution: folds the depthwise and
  // pointwise filters into one dense filter bank, then runs an ordinary
  // convolution. Indexing (from the expressions below): depthwise_weights is
  // (depth_multiplier, in_depth, ky, kx) and pointwise_weights is
  // (out_depth, in_depth * depth_multiplier, 1, 1).
  const int64 depth_multiplier = depthwise_weights.planes();
  CHECK_EQ(pointwise_weights.depth(), input.depth() * depth_multiplier);

  // Combine the two weights by reducing the depth_multiplier, so that we can
  // apply a single convolution on the combined weights.
  Array4D<float> weights(pointwise_weights.planes(), input.depth(),
                         depthwise_weights.height(), depthwise_weights.width());
  for (int64 kx = 0; kx < depthwise_weights.width(); ++kx) {
    for (int64 ky = 0; ky < depthwise_weights.height(); ++ky) {
      for (int64 kz = 0; kz < input.depth(); ++kz) {
        for (int64 out = 0; out < pointwise_weights.planes(); ++out) {
          // Contract away the depth_multiplier axis: the combined weight for
          // (out, kz) at kernel position (ky, kx) is the dot product over the
          // multiplier index.
          float weight = 0.0;
          for (int64 depth = 0; depth < depth_multiplier; ++depth) {
            weight +=
                depthwise_weights(depth, kz, ky, kx) *
                pointwise_weights(out, depth + kz * depth_multiplier, 0, 0);
          }
          weights(out, kz, ky, kx) = weight;
        }
      }
    }
  }

  return ConvArray4D(input, weights, kernel_stride, padding);
}
    181 
    182 /* static */ int64 ReferenceUtil::WindowCount(int64 unpadded_width,
    183                                               int64 window_len, int64 stride,
    184                                               Padding padding) {
    185   if (padding == Padding::kValid) {
    186     return window_util::StridedBound(unpadded_width, window_len, stride);
    187   }
    188   return tensorflow::MathUtil::CeilOfRatio(unpadded_width, stride);
    189 }
    190 
    191 /* static  */ std::unique_ptr<std::vector<float>>
    192 ReferenceUtil::ReduceWindow1DGeneric(
    193     const tensorflow::gtl::ArraySlice<float>& operand, float init,
    194     const std::function<float(float, float)>& reduce_func,
    195     const tensorflow::gtl::ArraySlice<int64>& window,
    196     const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
    197   std::vector<int64> dim_lengths{static_cast<int64>(operand.size())};
    198   return ReduceWindow1DGeneric(
    199       operand, init, reduce_func, window, stride,
    200       xla::MakePadding(dim_lengths, window, stride, padding));
    201 }
    202 
    203 /* static  */ std::unique_ptr<std::vector<float>>
    204 ReferenceUtil::ReduceWindow1DGeneric(
    205     const tensorflow::gtl::ArraySlice<float>& operand, float init,
    206     const std::function<float(float, float)>& reduce_func,
    207     const tensorflow::gtl::ArraySlice<int64>& window,
    208     const tensorflow::gtl::ArraySlice<int64>& stride,
    209     const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding) {
    210   std::vector<int64> dim_lengths{static_cast<int64>(operand.size())};
    211   std::vector<int64> window_counts(window.size(), 0);
    212   std::vector<int64> pad_low(window.size(), 0);
    213   for (int64 i = 0; i < window.size(); ++i) {
    214     int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second;
    215     window_counts[i] =
    216         window_util::StridedBound(padded_width, window[i], stride[i]);
    217     pad_low[i] = padding[i].first;
    218   }
    219   auto result = MakeUnique<std::vector<float>>(window_counts[0]);
    220 
    221   // Do a full 1D reduce window.
    222   for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
    223     int64 i0_base = i0 * stride[0] - pad_low[0];
    224 
    225     float val = init;
    226     for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
    227       if (i0_base + i0_win >= 0 && i0_base + i0_win < dim_lengths[0]) {
    228         val = reduce_func(val, operand[i0_base + i0_win]);
    229       }
    230     }
    231     (*result)[i0] = val;
    232   }
    233   return result;
    234 }
    235 
    236 /* static  */ std::unique_ptr<std::vector<float>>
    237 ReferenceUtil::ReduceWindow1DAdd(
    238     const tensorflow::gtl::ArraySlice<float>& operand, float init,
    239     const tensorflow::gtl::ArraySlice<int64>& window,
    240     const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
    241   const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
    242   return ReduceWindow1DGeneric(operand, init, add_reduce, window, stride,
    243                                padding);
    244 }
    245 
/* static  */ std::unique_ptr<Array2D<float>> ReferenceUtil::ReduceWindow2DAdd(
    const Array2D<float>& operand, float init,
    const tensorflow::gtl::ArraySlice<int64>& window,
    const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
  // 2D windowed sum: result(i0, i1) is `init` plus the sum of every
  // in-bounds operand element covered by window position (i0, i1).
  std::vector<int64> dim_lengths{operand.height(), operand.width()};
  auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);

  // Per-dimension number of window positions and low-side padding.
  std::vector<int64> window_counts(window.size(), 0);
  std::vector<int64> pad_low(window.size(), 0);
  for (int64 i = 0; i < window.size(); ++i) {
    window_counts[i] =
        WindowCount(dim_lengths[i], window[i], stride[i], padding);
    pad_low[i] = padding_both[i].first;
  }
  auto result = MakeUnique<Array2D<float>>(window_counts[0], window_counts[1]);

  // Do a full 2D reduce window.
  for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
    for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
      // Top-left corner of this window in operand coordinates; may be
      // negative when the window starts inside the low padding.
      int64 i0_base = i0 * stride[0] - pad_low[0];
      int64 i1_base = i1 * stride[1] - pad_low[1];

      float val = init;
      for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
        for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
          // Padding slots contribute nothing: only accumulate positions that
          // land inside the operand.
          if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
              i0_base + i0_win < operand.n1() &&
              i1_base + i1_win < operand.n2()) {
            val += operand(i0_base + i0_win, i1_base + i1_win);
          }
        }
      }
      (*result)(i0, i1) = val;
    }
  }
  return result;
}
    283 
/* static  */ std::unique_ptr<Array3D<float>> ReferenceUtil::ReduceWindow3DAdd(
    const Array3D<float>& operand, float init,
    const tensorflow::gtl::ArraySlice<int64>& window,
    const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
  // 3D windowed sum: same scheme as ReduceWindow2DAdd, one more dimension.
  std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3()};
  auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);

  // Per-dimension number of window positions and low-side padding.
  std::vector<int64> window_counts(window.size(), 0);
  std::vector<int64> pad_low(window.size(), 0);
  for (int64 i = 0; i < window.size(); ++i) {
    window_counts[i] =
        WindowCount(dim_lengths[i], window[i], stride[i], padding);
    pad_low[i] = padding_both[i].first;
  }
  auto result = MakeUnique<Array3D<float>>(window_counts[0], window_counts[1],
                                           window_counts[2]);

  for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
    for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
      for (int64 i2 = 0; i2 < window_counts[2]; ++i2) {
        // Window origin in operand coordinates; negative while inside the
        // low padding.
        int64 i0_base = i0 * stride[0] - pad_low[0];
        int64 i1_base = i1 * stride[1] - pad_low[1];
        int64 i2_base = i2 * stride[2] - pad_low[2];

        float val = init;
        for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
          for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
            for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) {
              // Skip window slots that fall into padding.
              if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
                  i2_base + i2_win >= 0 && i0_base + i0_win < operand.n1() &&
                  i1_base + i1_win < operand.n2() &&
                  i2_base + i2_win < operand.n3()) {
                val += operand(i0_base + i0_win, i1_base + i1_win,
                               i2_base + i2_win);
              }
            }
          }
        }
        (*result)(i0, i1, i2) = val;
      }
    }
  }
  return result;
}
    328 
    329 /* static */ std::unique_ptr<Array4D<float>>
    330 ReferenceUtil::ReduceWindow4DGeneric(
    331     const Array4D<float>& operand, float init,
    332     const std::function<float(float, float)>& reduce_func,
    333     const tensorflow::gtl::ArraySlice<int64>& window,
    334     const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
    335   std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
    336                                  operand.n4()};
    337   return ReduceWindow4DGeneric(
    338       operand, init, reduce_func, window, stride,
    339       xla::MakePadding(dim_lengths, window, stride, padding));
    340 }
    341 
/* static */ std::unique_ptr<Array4D<float>>
ReferenceUtil::ReduceWindow4DGeneric(
    const Array4D<float>& operand, float init,
    const std::function<float(float, float)>& reduce_func,
    const tensorflow::gtl::ArraySlice<int64>& window,
    const tensorflow::gtl::ArraySlice<int64>& stride,
    const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding) {
  // 4D reduce-window with explicit (low, high) padding: each output element
  // folds `reduce_func` over the in-bounds operand elements of one window,
  // starting from `init`. Padding slots are skipped, not fed to the reducer.
  std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
                                 operand.n4()};

  // Per-dimension number of window positions (over the padded extent) and
  // low-side padding.
  std::vector<int64> window_counts(window.size(), 0);
  std::vector<int64> pad_low(window.size(), 0);
  for (int64 i = 0; i < window.size(); ++i) {
    int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second;
    window_counts[i] =
        window_util::StridedBound(padded_width, window[i], stride[i]);
    pad_low[i] = padding[i].first;
  }
  auto result = MakeUnique<Array4D<float>>(window_counts[0], window_counts[1],
                                           window_counts[2], window_counts[3]);
  // Do a full 4D reduce window.
  for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
    for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
      for (int64 i2 = 0; i2 < window_counts[2]; ++i2) {
        for (int64 i3 = 0; i3 < window_counts[3]; ++i3) {
          // Window origin in operand coordinates; negative while the window
          // starts inside the low padding.
          int64 i0_base = i0 * stride[0] - pad_low[0];
          int64 i1_base = i1 * stride[1] - pad_low[1];
          int64 i2_base = i2 * stride[2] - pad_low[2];
          int64 i3_base = i3 * stride[3] - pad_low[3];

          float val = init;
          for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
            for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
              for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) {
                for (int64 i3_win = 0; i3_win < window[3]; ++i3_win) {
                  // Only in-bounds elements reach the reducer.
                  if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
                      i2_base + i2_win >= 0 && i3_base + i3_win >= 0 &&
                      i0_base + i0_win < operand.n1() &&
                      i1_base + i1_win < operand.n2() &&
                      i2_base + i2_win < operand.n3() &&
                      i3_base + i3_win < operand.n4()) {
                    val = reduce_func(
                        val, operand(i0_base + i0_win, i1_base + i1_win,
                                     i2_base + i2_win, i3_base + i3_win));
                  }
                }
              }
            }
          }
          (*result)(i0, i1, i2, i3) = val;
        }
      }
    }
  }
  return result;
}
    398 
    399 /* static  */ std::unique_ptr<Array4D<float>> ReferenceUtil::ReduceWindow4DAdd(
    400     const Array4D<float>& operand, float init,
    401     const tensorflow::gtl::ArraySlice<int64>& window,
    402     const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
    403   const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
    404   return ReduceWindow4DGeneric(operand, init, add_reduce, window, stride,
    405                                padding);
    406 }
    407 
    408 /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::BatchNorm4D(
    409     const Array4D<float>& input, const Array4D<float>& mean,
    410     const Array4D<float>& var, const Array4D<float>& scale,
    411     const Array4D<float>& offset, float epsilon) {
    412   auto normalized =
    413       *MapArray4D(input, mean, [](float a, float b) { return a - b; });
    414   normalized = *MapArray4D(normalized, var, [&](float a, float b) {
    415     return a / std::sqrt(b + epsilon);
    416   });
    417   normalized =
    418       *MapArray4D(normalized, scale, [](float a, float b) { return a * b; });
    419   return MapArray4D(normalized, offset, [](float a, float b) { return a + b; });
    420 }
    421 
/* static  */ std::unique_ptr<Array4D<float>>
ReferenceUtil::SelectAndScatter4DGePlus(
    const Array4D<float>& operand, const Array4D<float>& source, float init,
    const tensorflow::gtl::ArraySlice<int64>& window,
    const tensorflow::gtl::ArraySlice<int64>& stride, bool same_padding) {
  // SelectAndScatter with select = ">=" (pick the window's max, later equal
  // elements winning ties) and scatter = "+": each source(i0..i3) value is
  // added into the result at the argmax position of the corresponding
  // operand window. The result starts out filled with `init`.
  Padding padding = same_padding ? Padding::kSame : Padding::kValid;
  auto result = MakeUnique<Array4D<float>>(operand.n1(), operand.n2(),
                                           operand.n3(), operand.n4());
  std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
                                 operand.n4()};
  auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
  // Fill the output, with the initial value.
  result->Fill(init);

  // Per-dimension window-position counts and low-side padding.
  std::vector<int64> window_counts(window.size(), 0);
  std::vector<int64> pad_low(window.size(), 0);
  for (int64 i = 0; i < window.size(); ++i) {
    window_counts[i] =
        WindowCount(dim_lengths[i], window[i], stride[i], padding);
    pad_low[i] = padding_both[i].first;
  }
  // One window position per source element, by construction.
  CHECK_EQ(window_counts[0], source.n1());
  CHECK_EQ(window_counts[1], source.n2());
  CHECK_EQ(window_counts[2], source.n3());
  CHECK_EQ(window_counts[3], source.n4());

  // Do a full 4D select and Scatter.
  for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
    for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
      for (int64 i2 = 0; i2 < window_counts[2]; ++i2) {
        for (int64 i3 = 0; i3 < window_counts[3]; ++i3) {
          // Now we are inside a window and need to find the max and the argmax.
          int64 i0_base = i0 * stride[0] - pad_low[0];
          int64 i1_base = i1 * stride[1] - pad_low[1];
          int64 i2_base = i2 * stride[2] - pad_low[2];
          int64 i3_base = i3 * stride[3] - pad_low[3];
          // Seed the argmax with the window's origin clamped into bounds (the
          // origin can be negative when the window starts in the low padding).
          int64 scatter_0 = (i0_base >= 0) ? i0_base : 0;
          int64 scatter_1 = (i1_base >= 0) ? i1_base : 0;
          int64 scatter_2 = (i2_base >= 0) ? i2_base : 0;
          int64 scatter_3 = (i3_base >= 0) ? i3_base : 0;
          float val = operand(scatter_0, scatter_1, scatter_2, scatter_3);
          for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
            for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
              for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) {
                for (int64 i3_win = 0; i3_win < window[3]; ++i3_win) {
                  // Skip window slots that fall into padding.
                  if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
                      i2_base + i2_win >= 0 && i3_base + i3_win >= 0 &&
                      i0_base + i0_win < operand.n1() &&
                      i1_base + i1_win < operand.n2() &&
                      i2_base + i2_win < operand.n3() &&
                      i3_base + i3_win < operand.n4()) {
                    float tmp = operand(i0_base + i0_win, i1_base + i1_win,
                                        i2_base + i2_win, i3_base + i3_win);
                    // ">=" so the last element equal to the max is selected.
                    if (tmp >= val) {
                      val = tmp;
                      scatter_0 = i0_base + i0_win;
                      scatter_1 = i1_base + i1_win;
                      scatter_2 = i2_base + i2_win;
                      scatter_3 = i3_base + i3_win;
                    }
                  }
                }
              }
            }
          }
          // Scatter: accumulate this window's source value at the argmax.
          (*result)(scatter_0, scatter_1, scatter_2, scatter_3) +=
              source(i0, i1, i2, i3);
        }
      }
    }
  }
  return result;
}
    495 
    496 /* static */ std::unique_ptr<Array4D<float>>
    497 ReferenceUtil::ConvArray4DGeneralDimensions(
    498     const Array4D<float>& lhs, const Array4D<float>& rhs,
    499     std::pair<int64, int64> kernel_stride, Padding padding,
    500     ConvolutionDimensionNumbers dimension_numbers) {
    501   return ConvArray4DGeneralDimensionsDilated(lhs, rhs, kernel_stride, padding,
    502                                              {1, 1}, {1, 1},
    503                                              std::move(dimension_numbers));
    504 }
    505 
/* static */ std::unique_ptr<Array4D<float>>
ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
    const Array4D<float>& lhs, const Array4D<float>& rhs,
    std::pair<int64, int64> kernel_stride, Padding padding,
    std::pair<int64, int64> lhs_dilation, std::pair<int64, int64> rhs_dilation,
    ConvolutionDimensionNumbers dnums) {
  // Reference 2D convolution, implemented by building a tiny HLO graph
  // (two constants feeding a convolve instruction) and running it through
  // HloEvaluator, then copying the result literal into an Array4D.
  HloComputation::Builder b("ConvArray4DGeneralDimensionDilated");
  auto lhs_literal = Literal::CreateR4FromArray4D<float>(lhs);
  auto rhs_literal = Literal::CreateR4FromArray4D<float>(rhs);

  // MakePadding expects sizes/strides listed in ascending spatial-dimension
  // order, so swap the stride pair when dnums lists the kernel spatial
  // dimensions in descending order.
  std::array<int64, 2> ordered_kernel_strides;
  std::array<int64, 2> ordered_input_dimensions;
  std::array<int64, 2> ordered_kernel_dimensions;
  if (dnums.kernel_spatial_dimensions(0) > dnums.kernel_spatial_dimensions(1)) {
    ordered_kernel_strides[0] = kernel_stride.second;
    ordered_kernel_strides[1] = kernel_stride.first;
  } else {
    ordered_kernel_strides[0] = kernel_stride.first;
    ordered_kernel_strides[1] = kernel_stride.second;
  }

  ordered_input_dimensions[0] =
      lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(0));
  ordered_input_dimensions[1] =
      lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(1));
  ordered_kernel_dimensions[0] =
      rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0));
  ordered_kernel_dimensions[1] =
      rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1));

  std::vector<std::pair<int64, int64>> paddings =
      MakePadding(ordered_input_dimensions, ordered_kernel_dimensions,
                  ordered_kernel_strides, padding);
  CHECK_EQ(paddings.size(), 2);

  // Build the Window proto: one WindowDimension per spatial dimension.
  // NOTE(review): `dim`/`dim2` below read stride/dilation from
  // kernel_stride/*_dilation .first/.second directly while `paddings` was
  // computed in the re-ordered order above — presumably callers pass the
  // spatial dimensions in ascending order; confirm before relying on the
  // descending-order branch.
  Window window;

  WindowDimension dim;
  dim.set_size(
      rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0)));
  dim.set_stride(kernel_stride.first);
  dim.set_padding_low(paddings[0].first);
  dim.set_padding_high(paddings[0].second);
  dim.set_window_dilation(rhs_dilation.first);
  dim.set_base_dilation(lhs_dilation.first);
  *window.add_dimensions() = dim;

  WindowDimension dim2;
  dim2.set_size(
      rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1)));
  dim2.set_stride(kernel_stride.second);
  dim2.set_padding_low(paddings[1].first);
  dim2.set_padding_high(paddings[1].second);
  dim2.set_window_dilation(rhs_dilation.second);
  dim2.set_base_dilation(lhs_dilation.second);
  *window.add_dimensions() = dim2;

  // Infer the output shape, then assemble constants -> convolve.
  const Shape& shape =
      ShapeInference::InferConvolveShape(lhs_literal->shape(),
                                         rhs_literal->shape(), window, dnums)
          .ConsumeValueOrDie();

  HloInstruction* lhs_instruction =
      b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
  HloInstruction* rhs_instruction =
      b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));

  b.AddInstruction(HloInstruction::CreateConvolve(
      shape, lhs_instruction, rhs_instruction, window, dnums));
  HloModule module("ReferenceUtil");
  auto computation = module.AddEntryComputation(b.Build());

  // Evaluate the graph with an empty argument list (inputs are constants).
  HloEvaluator evaluator;
  std::unique_ptr<Literal> result_literal =
      evaluator.Evaluate<const Literal*>(*computation, {}).ConsumeValueOrDie();

  // Copy the rank-4 result literal back into an Array4D.
  CHECK_EQ(ShapeUtil::Rank(result_literal->shape()), 4);
  auto result =
      MakeUnique<Array4D<float>>(result_literal->shape().dimensions(0),
                                 result_literal->shape().dimensions(1),
                                 result_literal->shape().dimensions(2),
                                 result_literal->shape().dimensions(3));

  result->Each([&](tensorflow::gtl::ArraySlice<int64> indices, float* value) {
    *value = result_literal->Get<float>(indices);
  });

  return result;
}
    595 
    596 /* static */ std::unique_ptr<std::vector<float>>
    597 ReferenceUtil::ReduceToColArray2D(
    598     const Array2D<float>& matrix, float init,
    599     const std::function<float(float, float)>& reduce_function) {
    600   int64 rows = matrix.height();
    601   int64 cols = matrix.width();
    602   auto result = MakeUnique<std::vector<float>>();
    603   for (int64 i = 0; i < rows; ++i) {
    604     float acc = init;
    605     for (int64 j = 0; j < cols; ++j) {
    606       acc = reduce_function(acc, matrix(i, j));
    607     }
    608     result->push_back(acc);
    609   }
    610   return result;
    611 }
    612 
    613 /* static */ std::unique_ptr<std::vector<float>>
    614 ReferenceUtil::ReduceToRowArray2D(
    615     const Array2D<float>& matrix, float init,
    616     const std::function<float(float, float)>& reduce_function) {
    617   int64 rows = matrix.height();
    618   int64 cols = matrix.width();
    619   auto result = MakeUnique<std::vector<float>>();
    620   for (int64 i = 0; i < cols; ++i) {
    621     float acc = init;
    622     for (int64 j = 0; j < rows; ++j) {
    623       acc = reduce_function(acc, matrix(j, i));
    624     }
    625     result->push_back(acc);
    626   }
    627   return result;
    628 }
    629 
/*static*/ std::vector<float> ReferenceUtil::Reduce4DTo1D(
    const Array4D<float>& array, float init,
    tensorflow::gtl::ArraySlice<int64> dims,
    const std::function<float(float, float)>& reduce_function) {
  // Reduces exactly three of the four dimensions of `array` (those listed in
  // `dims`) with `reduce_function`, producing one output value per index of
  // the remaining dimension.
  std::vector<float> result;
  CHECK_EQ(dims.size(), 3);
  const std::set<int64> dim_set(dims.begin(), dims.end());
  CHECK_EQ(dim_set.size(), 3);  // The three reduced dimensions are distinct.
  // Loop trick: for each dimension d, exactly one of the outer aN-loop and
  // the inner iN-loop actually iterates. If d is reduced (in dim_set) the
  // outer loop runs once (aN stays 0) and the inner loop walks the whole
  // dimension; otherwise the outer loop walks it and the inner runs once.
  for (int64 a0 = 0; a0 == 0 || (!dim_set.count(0) && a0 < array.n1()); ++a0) {
    for (int64 a1 = 0; a1 == 0 || (!dim_set.count(1) && a1 < array.n2());
         ++a1) {
      for (int64 a2 = 0; a2 == 0 || (!dim_set.count(2) && a2 < array.n3());
           ++a2) {
        for (int64 a3 = 0; a3 == 0 || (!dim_set.count(3) && a3 < array.n4());
             ++a3) {
          float accumulator = init;
          for (int64 i0 = 0; i0 == 0 || (dim_set.count(0) && i0 < array.n1());
               ++i0) {
            for (int64 i1 = 0; i1 == 0 || (dim_set.count(1) && i1 < array.n2());
                 ++i1) {
              for (int64 i2 = 0;
                   i2 == 0 || (dim_set.count(2) && i2 < array.n3()); ++i2) {
                for (int64 i3 = 0;
                     i3 == 0 || (dim_set.count(3) && i3 < array.n4()); ++i3) {
                  // Handle zero-sized arrays.
                  if (array.n1() > 0 && array.n2() > 0 && array.n3() > 0 &&
                      array.n4() > 0) {
                    // Per dimension at most one of aN/iN is nonzero, so
                    // aN + iN is the actual index into that dimension.
                    accumulator = reduce_function(
                        accumulator, array(a0 + i0, a1 + i1, a2 + i2, a3 + i3));
                  }
                }
              }
            }
          }
          result.push_back(accumulator);
        }
      }
    }
  }
  return result;
}
    671 
    672 /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::Broadcast1DTo4D(
    673     const std::vector<float>& array, const std::vector<int64>& bounds,
    674     int64 broadcast_from_dim) {
    675   auto result =
    676       MakeUnique<Array4D<float>>(bounds[0], bounds[1], bounds[2], bounds[3]);
    677   for (int64 i = 0; i < result->n1(); ++i) {
    678     for (int64 j = 0; j < result->n2(); ++j) {
    679       for (int64 k = 0; k < result->n3(); ++k) {
    680         for (int64 l = 0; l < result->n4(); ++l) {
    681           switch (broadcast_from_dim) {
    682             case 0:
    683               (*result)(i, j, k, l) = array[i];
    684               break;
    685             case 1:
    686               (*result)(i, j, k, l) = array[j];
    687               break;
    688             case 2:
    689               (*result)(i, j, k, l) = array[k];
    690               break;
    691             case 3:
    692               (*result)(i, j, k, l) = array[l];
    693               break;
    694             default:
    695               break;
    696           }
    697         }
    698       }
    699     }
    700   }
    701   return result;
    702 }
    703 
    704 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::Reduce3DTo2D(
    705     const Array3D<float>& array, float init,
    706     tensorflow::gtl::ArraySlice<int64> dims,
    707     const std::function<float(float, float)>& reduce_function) {
    708   CHECK_EQ(dims.size(), 1);
    709   int64 rows = dims[0] == 0 ? array.n2() : array.n1();
    710   int64 cols = dims[0] == 2 ? array.n2() : array.n3();
    711   auto result = MakeUnique<Array2D<float>>(rows, cols);
    712   result->Fill(init);
    713   for (int i0 = 0; i0 < array.n1(); ++i0) {
    714     for (int i1 = 0; i1 < array.n2(); ++i1) {
    715       for (int i2 = 0; i2 < array.n3(); ++i2) {
    716         int64 row = dims[0] == 0 ? i1 : i0;
    717         int64 col = dims[0] == 2 ? i1 : i2;
    718         (*result)(row, col) =
    719             reduce_function((*result)(row, col), array(i0, i1, i2));
    720       }
    721     }
    722   }
    723   return result;
    724 }
    725 
    726 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D(
    727     const Array2D<float>& matrix,
    728     const std::function<float(float)>& map_function) {
    729   int64 rows = matrix.height();
    730   int64 cols = matrix.width();
    731   auto result = MakeUnique<Array2D<float>>(rows, cols);
    732   for (int64 i = 0; i < rows; ++i) {
    733     for (int64 j = 0; j < cols; ++j) {
    734       (*result)(i, j) = map_function(matrix(i, j));
    735     }
    736   }
    737   return result;
    738 }
    739 
    740 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D(
    741     const Array2D<float>& lhs, const Array2D<float>& rhs,
    742     const std::function<float(float, float)>& map_function) {
    743   CHECK_EQ(lhs.height(), rhs.height());
    744   CHECK_EQ(lhs.width(), rhs.width());
    745   int64 rows = lhs.height();
    746   int64 cols = rhs.width();
    747   auto result = MakeUnique<Array2D<float>>(rows, cols);
    748   for (int64 i = 0; i < rows; ++i) {
    749     for (int64 j = 0; j < cols; ++j) {
    750       (*result)(i, j) = map_function(lhs(i, j), rhs(i, j));
    751     }
    752   }
    753   return result;
    754 }
    755 
    756 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapWithIndexArray2D(
    757     const Array2D<float>& matrix,
    758     const std::function<float(float, int64, int64)>& map_function) {
    759   int64 rows = matrix.height();
    760   int64 cols = matrix.width();
    761   auto result = MakeUnique<Array2D<float>>(rows, cols);
    762   for (int64 i = 0; i < rows; ++i) {
    763     for (int64 j = 0; j < cols; ++j) {
    764       (*result)(i, j) = map_function(matrix(i, j), i, j);
    765     }
    766   }
    767   return result;
    768 }
    769 
    770 }  // namespace xla
    771