1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/reference_util.h" 17 18 #include <array> 19 #include <utility> 20 21 #include "tensorflow/compiler/xla/client/computation_builder.h" 22 #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" 23 #include "tensorflow/compiler/xla/service/hlo_evaluator.h" 24 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 25 #include "tensorflow/compiler/xla/service/shape_inference.h" 26 #include "tensorflow/compiler/xla/window_util.h" 27 #include "tensorflow/compiler/xla/xla_data.pb.h" 28 #include "tensorflow/core/lib/math/math_util.h" 29 #include "tensorflow/core/platform/logging.h" 30 31 namespace xla { 32 33 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::TransposeArray2D( 34 const Array2D<float>& operand) { 35 auto result = MakeUnique<Array2D<float>>(operand.width(), operand.height()); 36 for (int64 w = 0; w < operand.width(); ++w) { 37 for (int64 h = 0; h < operand.height(); ++h) { 38 (*result)(w, h) = operand(h, w); 39 } 40 } 41 42 return result; 43 } 44 45 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MatmulArray2D( 46 const Array2D<float>& lhs, const Array2D<float>& rhs) { 47 CHECK_EQ(lhs.width(), rhs.height()); 48 int m = lhs.height(); 49 int n = rhs.width(); 50 int k = lhs.width(); 51 auto result = MakeUnique<Array2D<float>>(m, n); 52 // Because Eigen is a header-oriented library, make sure that the Eigen code 53 // is the same as the code used by the CPU backend (otherwise the linker will 54 // randomly pick *some* definition). 55 __xla_cpu_runtime_EigenSingleThreadedMatMulF32( 56 /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m, 57 k, 58 /*transpose_lhs=*/0, 59 /*transpose_rhs=*/0); 60 return result; 61 } 62 63 /* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::MatmulArray2D( 64 const Array2D<double>& lhs, const Array2D<double>& rhs) { 65 CHECK_EQ(lhs.width(), rhs.height()); 66 int m = lhs.height(); 67 int n = rhs.width(); 68 int k = lhs.width(); 69 auto result = MakeUnique<Array2D<double>>(m, n); 70 // Because Eigen is a header-oriented library, make sure that the Eigen code 71 // is the same as the code used by the CPU backend (otherwise the linker will 72 // randomly pick *some* definition). 73 __xla_cpu_runtime_EigenSingleThreadedMatMulF64( 74 /*run_options_ptr=*/nullptr, result->data(), rhs.data(), lhs.data(), n, m, 75 k, 76 /*transpose_lhs=*/0, 77 /*transpose_rhs=*/0); 78 return result; 79 } 80 81 /* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::Array2DF32ToF64( 82 const Array2D<float>& input) { 83 auto result = MakeUnique<Array2D<double>>(input.height(), input.width()); 84 for (int64 rowno = 0; rowno < input.height(); ++rowno) { 85 for (int64 colno = 0; colno < input.height(); ++colno) { 86 (*result)(rowno, colno) = input(rowno, colno); 87 } 88 } 89 return result; 90 } 91 92 /* static */ std::unique_ptr<Array3D<float>> ReferenceUtil::ConvArray3D( 93 const Array3D<float>& lhs, const Array3D<float>& rhs, int64 kernel_stride, 94 Padding padding) { 95 return ConvArray3DGeneralDimensionsDilated( 96 lhs, rhs, kernel_stride, padding, 1, 1, 97 ComputationBuilder::CreateDefaultConvDimensionNumbers(1)); 98 } 99 100 /*static*/ std::unique_ptr<Array3D<float>> 101 ReferenceUtil::ConvArray3DGeneralDimensionsDilated( 102 const Array3D<float>& lhs, const Array3D<float>& rhs, int64 kernel_stride, 103 Padding padding, int64 lhs_dilation, int64 rhs_dilation, 104 const ConvolutionDimensionNumbers& dnums) { 105 CHECK_EQ(dnums.input_spatial_dimensions_size(), 1); 106 CHECK_EQ(dnums.kernel_spatial_dimensions_size(), 1); 107 CHECK_EQ(dnums.output_spatial_dimensions_size(), 1); 108 // Reuse the code for Array4D-convolution by extending the 3D input into a 4D 109 // array by adding a fourth dummy dimension of size 1 without stride, padding 110 // and dilation. 111 Array4D<float> a4dlhs(lhs.n1(), lhs.n2(), lhs.n3(), 1); 112 a4dlhs.Each( 113 [&](tensorflow::gtl::ArraySlice<int64> indices, float* value_ptr) { 114 CHECK_EQ(indices[3], 0); 115 *value_ptr = lhs.operator()(indices[0], indices[1], indices[2]); 116 }); 117 Array4D<float> a4drhs(rhs.n1(), rhs.n2(), rhs.n3(), 1); 118 a4drhs.Each( 119 [&](tensorflow::gtl::ArraySlice<int64> indices, float* value_ptr) { 120 CHECK_EQ(indices[3], 0); 121 *value_ptr = rhs.operator()(indices[0], indices[1], indices[2]); 122 }); 123 // Add a second dummy spatial dimensions. 124 ConvolutionDimensionNumbers dnums2d = dnums; 125 dnums2d.add_input_spatial_dimensions(3); 126 dnums2d.add_kernel_spatial_dimensions(3); 127 dnums2d.add_output_spatial_dimensions(3); 128 std::unique_ptr<Array4D<float>> convr4 = ConvArray4DGeneralDimensionsDilated( 129 a4dlhs, a4drhs, {kernel_stride, 1}, padding, {lhs_dilation, 1}, 130 {rhs_dilation, 1}, dnums2d); 131 132 auto convr3 = MakeUnique<Array3D<float>>(convr4->planes(), convr4->depth(), 133 convr4->height()); 134 convr4->Each( 135 [&](tensorflow::gtl::ArraySlice<int64> indices, float* value_ptr) { 136 CHECK_EQ(indices[3], 0); 137 convr3->operator()(indices[0], indices[1], indices[2]) = *value_ptr; 138 }); 139 return convr3; 140 } 141 142 /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ConvArray4D( 143 const Array4D<float>& lhs, const Array4D<float>& rhs, 144 std::pair<int64, int64> kernel_stride, Padding padding) { 145 return ConvArray4DGeneralDimensions( 146 lhs, rhs, kernel_stride, padding, 147 ComputationBuilder::CreateDefaultConvDimensionNumbers()); 148 } 149 150 /* static */ std::unique_ptr<Array4D<float>> 151 ReferenceUtil::SeparableConvArray4D(const Array4D<float>& input, 152 const Array4D<float>& depthwise_weights, 153 const Array4D<float>& pointwise_weights, 154 std::pair<int64, int64> kernel_stride, 155 Padding padding) { 156 const int64 depth_multiplier = depthwise_weights.planes(); 157 CHECK_EQ(pointwise_weights.depth(), input.depth() * depth_multiplier); 158 159 // Combine the two weights by reducing the depth_multiplier, so that we can 160 // apply a single convolution on the combined weights. 161 Array4D<float> weights(pointwise_weights.planes(), input.depth(), 162 depthwise_weights.height(), depthwise_weights.width()); 163 for (int64 kx = 0; kx < depthwise_weights.width(); ++kx) { 164 for (int64 ky = 0; ky < depthwise_weights.height(); ++ky) { 165 for (int64 kz = 0; kz < input.depth(); ++kz) { 166 for (int64 out = 0; out < pointwise_weights.planes(); ++out) { 167 float weight = 0.0; 168 for (int64 depth = 0; depth < depth_multiplier; ++depth) { 169 weight += 170 depthwise_weights(depth, kz, ky, kx) * 171 pointwise_weights(out, depth + kz * depth_multiplier, 0, 0); 172 } 173 weights(out, kz, ky, kx) = weight; 174 } 175 } 176 } 177 } 178 179 return ConvArray4D(input, weights, kernel_stride, padding); 180 } 181 182 /* static */ int64 ReferenceUtil::WindowCount(int64 unpadded_width, 183 int64 window_len, int64 stride, 184 Padding padding) { 185 if (padding == Padding::kValid) { 186 return window_util::StridedBound(unpadded_width, window_len, stride); 187 } 188 return tensorflow::MathUtil::CeilOfRatio(unpadded_width, stride); 189 } 190 191 /* static */ std::unique_ptr<std::vector<float>> 192 ReferenceUtil::ReduceWindow1DGeneric( 193 const tensorflow::gtl::ArraySlice<float>& operand, float init, 194 const std::function<float(float, float)>& reduce_func, 195 const tensorflow::gtl::ArraySlice<int64>& window, 196 const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) { 197 std::vector<int64> dim_lengths{static_cast<int64>(operand.size())}; 198 return ReduceWindow1DGeneric( 199 operand, init, reduce_func, window, stride, 200 xla::MakePadding(dim_lengths, window, stride, padding)); 201 } 202 203 /* static */ std::unique_ptr<std::vector<float>> 204 ReferenceUtil::ReduceWindow1DGeneric( 205 const tensorflow::gtl::ArraySlice<float>& operand, float init, 206 const std::function<float(float, float)>& reduce_func, 207 const tensorflow::gtl::ArraySlice<int64>& window, 208 const tensorflow::gtl::ArraySlice<int64>& stride, 209 const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding) { 210 std::vector<int64> dim_lengths{static_cast<int64>(operand.size())}; 211 std::vector<int64> window_counts(window.size(), 0); 212 std::vector<int64> pad_low(window.size(), 0); 213 for (int64 i = 0; i < window.size(); ++i) { 214 int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second; 215 window_counts[i] = 216 window_util::StridedBound(padded_width, window[i], stride[i]); 217 pad_low[i] = padding[i].first; 218 } 219 auto result = MakeUnique<std::vector<float>>(window_counts[0]); 220 221 // Do a full 1D reduce window. 222 for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { 223 int64 i0_base = i0 * stride[0] - pad_low[0]; 224 225 float val = init; 226 for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) { 227 if (i0_base + i0_win >= 0 && i0_base + i0_win < dim_lengths[0]) { 228 val = reduce_func(val, operand[i0_base + i0_win]); 229 } 230 } 231 (*result)[i0] = val; 232 } 233 return result; 234 } 235 236 /* static */ std::unique_ptr<std::vector<float>> 237 ReferenceUtil::ReduceWindow1DAdd( 238 const tensorflow::gtl::ArraySlice<float>& operand, float init, 239 const tensorflow::gtl::ArraySlice<int64>& window, 240 const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) { 241 const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; 242 return ReduceWindow1DGeneric(operand, init, add_reduce, window, stride, 243 padding); 244 } 245 246 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::ReduceWindow2DAdd( 247 const Array2D<float>& operand, float init, 248 const tensorflow::gtl::ArraySlice<int64>& window, 249 const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) { 250 std::vector<int64> dim_lengths{operand.height(), operand.width()}; 251 auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); 252 253 std::vector<int64> window_counts(window.size(), 0); 254 std::vector<int64> pad_low(window.size(), 0); 255 for (int64 i = 0; i < window.size(); ++i) { 256 window_counts[i] = 257 WindowCount(dim_lengths[i], window[i], stride[i], padding); 258 pad_low[i] = padding_both[i].first; 259 } 260 auto result = MakeUnique<Array2D<float>>(window_counts[0], window_counts[1]); 261 262 // Do a full 2D reduce window. 263 for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { 264 for (int64 i1 = 0; i1 < window_counts[1]; ++i1) { 265 int64 i0_base = i0 * stride[0] - pad_low[0]; 266 int64 i1_base = i1 * stride[1] - pad_low[1]; 267 268 float val = init; 269 for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) { 270 for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) { 271 if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 && 272 i0_base + i0_win < operand.n1() && 273 i1_base + i1_win < operand.n2()) { 274 val += operand(i0_base + i0_win, i1_base + i1_win); 275 } 276 } 277 } 278 (*result)(i0, i1) = val; 279 } 280 } 281 return result; 282 } 283 284 /* static */ std::unique_ptr<Array3D<float>> ReferenceUtil::ReduceWindow3DAdd( 285 const Array3D<float>& operand, float init, 286 const tensorflow::gtl::ArraySlice<int64>& window, 287 const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) { 288 std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3()}; 289 auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); 290 291 std::vector<int64> window_counts(window.size(), 0); 292 std::vector<int64> pad_low(window.size(), 0); 293 for (int64 i = 0; i < window.size(); ++i) { 294 window_counts[i] = 295 WindowCount(dim_lengths[i], window[i], stride[i], padding); 296 pad_low[i] = padding_both[i].first; 297 } 298 auto result = MakeUnique<Array3D<float>>(window_counts[0], window_counts[1], 299 window_counts[2]); 300 301 for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { 302 for (int64 i1 = 0; i1 < window_counts[1]; ++i1) { 303 for (int64 i2 = 0; i2 < window_counts[2]; ++i2) { 304 int64 i0_base = i0 * stride[0] - pad_low[0]; 305 int64 i1_base = i1 * stride[1] - pad_low[1]; 306 int64 i2_base = i2 * stride[2] - pad_low[2]; 307 308 float val = init; 309 for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) { 310 for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) { 311 for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) { 312 if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 && 313 i2_base + i2_win >= 0 && i0_base + i0_win < operand.n1() && 314 i1_base + i1_win < operand.n2() && 315 i2_base + i2_win < operand.n3()) { 316 val += operand(i0_base + i0_win, i1_base + i1_win, 317 i2_base + i2_win); 318 } 319 } 320 } 321 } 322 (*result)(i0, i1, i2) = val; 323 } 324 } 325 } 326 return result; 327 } 328 329 /* static */ std::unique_ptr<Array4D<float>> 330 ReferenceUtil::ReduceWindow4DGeneric( 331 const Array4D<float>& operand, float init, 332 const std::function<float(float, float)>& reduce_func, 333 const tensorflow::gtl::ArraySlice<int64>& window, 334 const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) { 335 std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(), 336 operand.n4()}; 337 return ReduceWindow4DGeneric( 338 operand, init, reduce_func, window, stride, 339 xla::MakePadding(dim_lengths, window, stride, padding)); 340 } 341 342 /* static */ std::unique_ptr<Array4D<float>> 343 ReferenceUtil::ReduceWindow4DGeneric( 344 const Array4D<float>& operand, float init, 345 const std::function<float(float, float)>& reduce_func, 346 const tensorflow::gtl::ArraySlice<int64>& window, 347 const tensorflow::gtl::ArraySlice<int64>& stride, 348 const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding) { 349 std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(), 350 operand.n4()}; 351 352 std::vector<int64> window_counts(window.size(), 0); 353 std::vector<int64> pad_low(window.size(), 0); 354 for (int64 i = 0; i < window.size(); ++i) { 355 int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second; 356 window_counts[i] = 357 window_util::StridedBound(padded_width, window[i], stride[i]); 358 pad_low[i] = padding[i].first; 359 } 360 auto result = MakeUnique<Array4D<float>>(window_counts[0], window_counts[1], 361 window_counts[2], window_counts[3]); 362 // Do a full 4D reduce window. 363 for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { 364 for (int64 i1 = 0; i1 < window_counts[1]; ++i1) { 365 for (int64 i2 = 0; i2 < window_counts[2]; ++i2) { 366 for (int64 i3 = 0; i3 < window_counts[3]; ++i3) { 367 int64 i0_base = i0 * stride[0] - pad_low[0]; 368 int64 i1_base = i1 * stride[1] - pad_low[1]; 369 int64 i2_base = i2 * stride[2] - pad_low[2]; 370 int64 i3_base = i3 * stride[3] - pad_low[3]; 371 372 float val = init; 373 for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) { 374 for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) { 375 for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) { 376 for (int64 i3_win = 0; i3_win < window[3]; ++i3_win) { 377 if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 && 378 i2_base + i2_win >= 0 && i3_base + i3_win >= 0 && 379 i0_base + i0_win < operand.n1() && 380 i1_base + i1_win < operand.n2() && 381 i2_base + i2_win < operand.n3() && 382 i3_base + i3_win < operand.n4()) { 383 val = reduce_func( 384 val, operand(i0_base + i0_win, i1_base + i1_win, 385 i2_base + i2_win, i3_base + i3_win)); 386 } 387 } 388 } 389 } 390 } 391 (*result)(i0, i1, i2, i3) = val; 392 } 393 } 394 } 395 } 396 return result; 397 } 398 399 /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ReduceWindow4DAdd( 400 const Array4D<float>& operand, float init, 401 const tensorflow::gtl::ArraySlice<int64>& window, 402 const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) { 403 const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; }; 404 return ReduceWindow4DGeneric(operand, init, add_reduce, window, stride, 405 padding); 406 } 407 408 /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::BatchNorm4D( 409 const Array4D<float>& input, const Array4D<float>& mean, 410 const Array4D<float>& var, const Array4D<float>& scale, 411 const Array4D<float>& offset, float epsilon) { 412 auto normalized = 413 *MapArray4D(input, mean, [](float a, float b) { return a - b; }); 414 normalized = *MapArray4D(normalized, var, [&](float a, float b) { 415 return a / std::sqrt(b + epsilon); 416 }); 417 normalized = 418 *MapArray4D(normalized, scale, [](float a, float b) { return a * b; }); 419 return MapArray4D(normalized, offset, [](float a, float b) { return a + b; }); 420 } 421 422 /* static */ std::unique_ptr<Array4D<float>> 423 ReferenceUtil::SelectAndScatter4DGePlus( 424 const Array4D<float>& operand, const Array4D<float>& source, float init, 425 const tensorflow::gtl::ArraySlice<int64>& window, 426 const tensorflow::gtl::ArraySlice<int64>& stride, bool same_padding) { 427 Padding padding = same_padding ? Padding::kSame : Padding::kValid; 428 auto result = MakeUnique<Array4D<float>>(operand.n1(), operand.n2(), 429 operand.n3(), operand.n4()); 430 std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(), 431 operand.n4()}; 432 auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); 433 // Fill the output, with the initial value. 434 result->Fill(init); 435 436 std::vector<int64> window_counts(window.size(), 0); 437 std::vector<int64> pad_low(window.size(), 0); 438 for (int64 i = 0; i < window.size(); ++i) { 439 window_counts[i] = 440 WindowCount(dim_lengths[i], window[i], stride[i], padding); 441 pad_low[i] = padding_both[i].first; 442 } 443 CHECK_EQ(window_counts[0], source.n1()); 444 CHECK_EQ(window_counts[1], source.n2()); 445 CHECK_EQ(window_counts[2], source.n3()); 446 CHECK_EQ(window_counts[3], source.n4()); 447 448 // Do a full 4D select and Scatter. 449 for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { 450 for (int64 i1 = 0; i1 < window_counts[1]; ++i1) { 451 for (int64 i2 = 0; i2 < window_counts[2]; ++i2) { 452 for (int64 i3 = 0; i3 < window_counts[3]; ++i3) { 453 // Now we are inside a window and need to find the max and the argmax. 454 int64 i0_base = i0 * stride[0] - pad_low[0]; 455 int64 i1_base = i1 * stride[1] - pad_low[1]; 456 int64 i2_base = i2 * stride[2] - pad_low[2]; 457 int64 i3_base = i3 * stride[3] - pad_low[3]; 458 int64 scatter_0 = (i0_base >= 0) ? i0_base : 0; 459 int64 scatter_1 = (i1_base >= 0) ? i1_base : 0; 460 int64 scatter_2 = (i2_base >= 0) ? i2_base : 0; 461 int64 scatter_3 = (i3_base >= 0) ? i3_base : 0; 462 float val = operand(scatter_0, scatter_1, scatter_2, scatter_3); 463 for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) { 464 for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) { 465 for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) { 466 for (int64 i3_win = 0; i3_win < window[3]; ++i3_win) { 467 if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 && 468 i2_base + i2_win >= 0 && i3_base + i3_win >= 0 && 469 i0_base + i0_win < operand.n1() && 470 i1_base + i1_win < operand.n2() && 471 i2_base + i2_win < operand.n3() && 472 i3_base + i3_win < operand.n4()) { 473 float tmp = operand(i0_base + i0_win, i1_base + i1_win, 474 i2_base + i2_win, i3_base + i3_win); 475 if (tmp >= val) { 476 val = tmp; 477 scatter_0 = i0_base + i0_win; 478 scatter_1 = i1_base + i1_win; 479 scatter_2 = i2_base + i2_win; 480 scatter_3 = i3_base + i3_win; 481 } 482 } 483 } 484 } 485 } 486 } 487 (*result)(scatter_0, scatter_1, scatter_2, scatter_3) += 488 source(i0, i1, i2, i3); 489 } 490 } 491 } 492 } 493 return result; 494 } 495 496 /* static */ std::unique_ptr<Array4D<float>> 497 ReferenceUtil::ConvArray4DGeneralDimensions( 498 const Array4D<float>& lhs, const Array4D<float>& rhs, 499 std::pair<int64, int64> kernel_stride, Padding padding, 500 ConvolutionDimensionNumbers dimension_numbers) { 501 return ConvArray4DGeneralDimensionsDilated(lhs, rhs, kernel_stride, padding, 502 {1, 1}, {1, 1}, 503 std::move(dimension_numbers)); 504 } 505 506 /* static */ std::unique_ptr<Array4D<float>> 507 ReferenceUtil::ConvArray4DGeneralDimensionsDilated( 508 const Array4D<float>& lhs, const Array4D<float>& rhs, 509 std::pair<int64, int64> kernel_stride, Padding padding, 510 std::pair<int64, int64> lhs_dilation, std::pair<int64, int64> rhs_dilation, 511 ConvolutionDimensionNumbers dnums) { 512 HloComputation::Builder b("ConvArray4DGeneralDimensionDilated"); 513 auto lhs_literal = Literal::CreateR4FromArray4D<float>(lhs); 514 auto rhs_literal = Literal::CreateR4FromArray4D<float>(rhs); 515 516 std::array<int64, 2> ordered_kernel_strides; 517 std::array<int64, 2> ordered_input_dimensions; 518 std::array<int64, 2> ordered_kernel_dimensions; 519 if (dnums.kernel_spatial_dimensions(0) > dnums.kernel_spatial_dimensions(1)) { 520 ordered_kernel_strides[0] = kernel_stride.second; 521 ordered_kernel_strides[1] = kernel_stride.first; 522 } else { 523 ordered_kernel_strides[0] = kernel_stride.first; 524 ordered_kernel_strides[1] = kernel_stride.second; 525 } 526 527 ordered_input_dimensions[0] = 528 lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(0)); 529 ordered_input_dimensions[1] = 530 lhs_literal->shape().dimensions(dnums.input_spatial_dimensions(1)); 531 ordered_kernel_dimensions[0] = 532 rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0)); 533 ordered_kernel_dimensions[1] = 534 rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1)); 535 536 std::vector<std::pair<int64, int64>> paddings = 537 MakePadding(ordered_input_dimensions, ordered_kernel_dimensions, 538 ordered_kernel_strides, padding); 539 CHECK_EQ(paddings.size(), 2); 540 541 Window window; 542 543 WindowDimension dim; 544 dim.set_size( 545 rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(0))); 546 dim.set_stride(kernel_stride.first); 547 dim.set_padding_low(paddings[0].first); 548 dim.set_padding_high(paddings[0].second); 549 dim.set_window_dilation(rhs_dilation.first); 550 dim.set_base_dilation(lhs_dilation.first); 551 *window.add_dimensions() = dim; 552 553 WindowDimension dim2; 554 dim2.set_size( 555 rhs_literal->shape().dimensions(dnums.kernel_spatial_dimensions(1))); 556 dim2.set_stride(kernel_stride.second); 557 dim2.set_padding_low(paddings[1].first); 558 dim2.set_padding_high(paddings[1].second); 559 dim2.set_window_dilation(rhs_dilation.second); 560 dim2.set_base_dilation(lhs_dilation.second); 561 *window.add_dimensions() = dim2; 562 563 const Shape& shape = 564 ShapeInference::InferConvolveShape(lhs_literal->shape(), 565 rhs_literal->shape(), window, dnums) 566 .ConsumeValueOrDie(); 567 568 HloInstruction* lhs_instruction = 569 b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal))); 570 HloInstruction* rhs_instruction = 571 b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); 572 573 b.AddInstruction(HloInstruction::CreateConvolve( 574 shape, lhs_instruction, rhs_instruction, window, dnums)); 575 HloModule module("ReferenceUtil"); 576 auto computation = module.AddEntryComputation(b.Build()); 577 578 HloEvaluator evaluator; 579 std::unique_ptr<Literal> result_literal = 580 evaluator.Evaluate<const Literal*>(*computation, {}).ConsumeValueOrDie(); 581 582 CHECK_EQ(ShapeUtil::Rank(result_literal->shape()), 4); 583 auto result = 584 MakeUnique<Array4D<float>>(result_literal->shape().dimensions(0), 585 result_literal->shape().dimensions(1), 586 result_literal->shape().dimensions(2), 587 result_literal->shape().dimensions(3)); 588 589 result->Each([&](tensorflow::gtl::ArraySlice<int64> indices, float* value) { 590 *value = result_literal->Get<float>(indices); 591 }); 592 593 return result; 594 } 595 596 /* static */ std::unique_ptr<std::vector<float>> 597 ReferenceUtil::ReduceToColArray2D( 598 const Array2D<float>& matrix, float init, 599 const std::function<float(float, float)>& reduce_function) { 600 int64 rows = matrix.height(); 601 int64 cols = matrix.width(); 602 auto result = MakeUnique<std::vector<float>>(); 603 for (int64 i = 0; i < rows; ++i) { 604 float acc = init; 605 for (int64 j = 0; j < cols; ++j) { 606 acc = reduce_function(acc, matrix(i, j)); 607 } 608 result->push_back(acc); 609 } 610 return result; 611 } 612 613 /* static */ std::unique_ptr<std::vector<float>> 614 ReferenceUtil::ReduceToRowArray2D( 615 const Array2D<float>& matrix, float init, 616 const std::function<float(float, float)>& reduce_function) { 617 int64 rows = matrix.height(); 618 int64 cols = matrix.width(); 619 auto result = MakeUnique<std::vector<float>>(); 620 for (int64 i = 0; i < cols; ++i) { 621 float acc = init; 622 for (int64 j = 0; j < rows; ++j) { 623 acc = reduce_function(acc, matrix(j, i)); 624 } 625 result->push_back(acc); 626 } 627 return result; 628 } 629 630 /*static*/ std::vector<float> ReferenceUtil::Reduce4DTo1D( 631 const Array4D<float>& array, float init, 632 tensorflow::gtl::ArraySlice<int64> dims, 633 const std::function<float(float, float)>& reduce_function) { 634 std::vector<float> result; 635 CHECK_EQ(dims.size(), 3); 636 const std::set<int64> dim_set(dims.begin(), dims.end()); 637 CHECK_EQ(dim_set.size(), 3); 638 for (int64 a0 = 0; a0 == 0 || (!dim_set.count(0) && a0 < array.n1()); ++a0) { 639 for (int64 a1 = 0; a1 == 0 || (!dim_set.count(1) && a1 < array.n2()); 640 ++a1) { 641 for (int64 a2 = 0; a2 == 0 || (!dim_set.count(2) && a2 < array.n3()); 642 ++a2) { 643 for (int64 a3 = 0; a3 == 0 || (!dim_set.count(3) && a3 < array.n4()); 644 ++a3) { 645 float accumulator = init; 646 for (int64 i0 = 0; i0 == 0 || (dim_set.count(0) && i0 < array.n1()); 647 ++i0) { 648 for (int64 i1 = 0; i1 == 0 || (dim_set.count(1) && i1 < array.n2()); 649 ++i1) { 650 for (int64 i2 = 0; 651 i2 == 0 || (dim_set.count(2) && i2 < array.n3()); ++i2) { 652 for (int64 i3 = 0; 653 i3 == 0 || (dim_set.count(3) && i3 < array.n4()); ++i3) { 654 // Handle zero-sized arrays. 655 if (array.n1() > 0 && array.n2() > 0 && array.n3() > 0 && 656 array.n4() > 0) { 657 accumulator = reduce_function( 658 accumulator, array(a0 + i0, a1 + i1, a2 + i2, a3 + i3)); 659 } 660 } 661 } 662 } 663 } 664 result.push_back(accumulator); 665 } 666 } 667 } 668 } 669 return result; 670 } 671 672 /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::Broadcast1DTo4D( 673 const std::vector<float>& array, const std::vector<int64>& bounds, 674 int64 broadcast_from_dim) { 675 auto result = 676 MakeUnique<Array4D<float>>(bounds[0], bounds[1], bounds[2], bounds[3]); 677 for (int64 i = 0; i < result->n1(); ++i) { 678 for (int64 j = 0; j < result->n2(); ++j) { 679 for (int64 k = 0; k < result->n3(); ++k) { 680 for (int64 l = 0; l < result->n4(); ++l) { 681 switch (broadcast_from_dim) { 682 case 0: 683 (*result)(i, j, k, l) = array[i]; 684 break; 685 case 1: 686 (*result)(i, j, k, l) = array[j]; 687 break; 688 case 2: 689 (*result)(i, j, k, l) = array[k]; 690 break; 691 case 3: 692 (*result)(i, j, k, l) = array[l]; 693 break; 694 default: 695 break; 696 } 697 } 698 } 699 } 700 } 701 return result; 702 } 703 704 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::Reduce3DTo2D( 705 const Array3D<float>& array, float init, 706 tensorflow::gtl::ArraySlice<int64> dims, 707 const std::function<float(float, float)>& reduce_function) { 708 CHECK_EQ(dims.size(), 1); 709 int64 rows = dims[0] == 0 ? array.n2() : array.n1(); 710 int64 cols = dims[0] == 2 ? array.n2() : array.n3(); 711 auto result = MakeUnique<Array2D<float>>(rows, cols); 712 result->Fill(init); 713 for (int i0 = 0; i0 < array.n1(); ++i0) { 714 for (int i1 = 0; i1 < array.n2(); ++i1) { 715 for (int i2 = 0; i2 < array.n3(); ++i2) { 716 int64 row = dims[0] == 0 ? i1 : i0; 717 int64 col = dims[0] == 2 ? i1 : i2; 718 (*result)(row, col) = 719 reduce_function((*result)(row, col), array(i0, i1, i2)); 720 } 721 } 722 } 723 return result; 724 } 725 726 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D( 727 const Array2D<float>& matrix, 728 const std::function<float(float)>& map_function) { 729 int64 rows = matrix.height(); 730 int64 cols = matrix.width(); 731 auto result = MakeUnique<Array2D<float>>(rows, cols); 732 for (int64 i = 0; i < rows; ++i) { 733 for (int64 j = 0; j < cols; ++j) { 734 (*result)(i, j) = map_function(matrix(i, j)); 735 } 736 } 737 return result; 738 } 739 740 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D( 741 const Array2D<float>& lhs, const Array2D<float>& rhs, 742 const std::function<float(float, float)>& map_function) { 743 CHECK_EQ(lhs.height(), rhs.height()); 744 CHECK_EQ(lhs.width(), rhs.width()); 745 int64 rows = lhs.height(); 746 int64 cols = rhs.width(); 747 auto result = MakeUnique<Array2D<float>>(rows, cols); 748 for (int64 i = 0; i < rows; ++i) { 749 for (int64 j = 0; j < cols; ++j) { 750 (*result)(i, j) = map_function(lhs(i, j), rhs(i, j)); 751 } 752 } 753 return result; 754 } 755 756 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapWithIndexArray2D( 757 const Array2D<float>& matrix, 758 const std::function<float(float, int64, int64)>& map_function) { 759 int64 rows = matrix.height(); 760 int64 cols = matrix.width(); 761 auto result = MakeUnique<Array2D<float>>(rows, cols); 762 for (int64 i = 0; i < rows; ++i) { 763 for (int64 j = 0; j < cols; ++j) { 764 (*result)(i, j) = map_function(matrix(i, j), i, j); 765 } 766 } 767 return result; 768 } 769 770 } // namespace xla 771