HomeSort by relevance Sort by last modified time
    Searched refs:dim_nums (Results 1 - 6 of 6) sorted by null

  /external/tensorflow/tensorflow/compiler/xla/tests/
convolution_dimension_numbers_test.cc 105 ConvolutionDimensionNumbers dim_nums = local
108 int64 old_input_batch_dim = dim_nums.input_batch_dimension();
109 int64 old_output_batch_dim = dim_nums.output_batch_dimension();
110 dim_nums.set_input_batch_dimension(dim_nums.input_feature_dimension());
111 dim_nums.set_output_batch_dimension(dim_nums.output_feature_dimension());
112 dim_nums.set_input_feature_dimension(old_input_batch_dim);
113 dim_nums.set_output_feature_dimension(old_output_batch_dim);
116 dim_nums.kernel_input_feature_dimension()
    [all...]
  /external/tensorflow/tensorflow/compiler/xla/service/cpu/
dot_op_emitter_internal.h 41 DotDimensionNumbers dim_nums; member in struct:xla::cpu::internal::DotInfo
48 dim_nums = instr.dot_dimension_numbers();
dot_op_emitter.cc 60 DotDimensionNumbers dim_nums; member in struct:xla::cpu::__anon44343::DotInfo
69 dim_nums = instr.dot_dimension_numbers();
411 const DotDimensionNumbers& dim_nums = dot_info_.dim_nums; local
416 int64 lhs_reduction_dimension = dim_nums.lhs_contracting_dimensions(0);
417 int64 rhs_reduction_dimension = dim_nums.rhs_contracting_dimensions(0);
689 const DotDimensionNumbers& dim_nums = dot_info_.dim_nums; local
692 /*m=*/lhs_shape.dimensions(1 - dim_nums.lhs_contracting_dimensions(0)),
693 /*k=*/lhs_shape.dimensions(dim_nums.lhs_contracting_dimensions(0))
    [all...]
  /external/tensorflow/tensorflow/compiler/xla/service/
hlo_matchers.cc 222 const DotDimensionNumbers& dim_nums = instruction->dot_dimension_numbers(); local
223 if (dim_nums.lhs_contracting_dimensions_size() != 1 ||
224 dim_nums.lhs_contracting_dimensions(0) != lhs_contracting_dim_) {
227 << absl::StrJoin(dim_nums.lhs_contracting_dimensions(), ",")
232 if (dim_nums.rhs_contracting_dimensions_size() != 1 ||
233 dim_nums.rhs_contracting_dimensions(0) != rhs_contracting_dim_) {
236 << absl::StrJoin(dim_nums.rhs_contracting_dimensions(), ",")
  /external/tensorflow/tensorflow/compiler/xla/service/gpu/
gemm_thunk.cc 323 DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction()); local
324 CHECK_EQ(dim_nums.lhs_batch_dimensions_size(),
325 dim_nums.rhs_batch_dimensions_size());
326 CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2, output_shape_.rank());
328 int64 row_dim = dim_nums.lhs_batch_dimensions_size();
329 int64 col_dim = dim_nums.lhs_batch_dimensions_size() + 1;
335 for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) {
388 lhs_data, lhs_shape_, dim_nums.lhs_contracting_dimensions(0) == row_dim);
390 rhs_data, rhs_shape_, dim_nums.rhs_contracting_dimensions(0) == col_dim);
gpu_layout_assignment.cc 194 const DotDimensionNumbers& dim_nums = local
196 CHECK_EQ(dim_nums.lhs_batch_dimensions_size(),
197 dim_nums.rhs_batch_dimensions_size());
198 CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2,
200 for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) {

Completed in 96 milliseconds