HomeSort by relevance Sort by last modified time
    Searched defs:dim_numbers (Results 1 - 14 of 14) sorted by null

  /external/tensorflow/tensorflow/compiler/tf2xla/lib/
scatter.cc 138 xla::ScatterDimensionNumbers dim_numbers; local
139 dim_numbers.set_index_vector_dim(indices_are_vectors
165 dim_numbers.add_update_window_dims(i);
170 dim_numbers.add_inserted_window_dims(i);
171 dim_numbers.add_scatter_dims_to_operand_dims(i);
193 VLOG(3) << " index_vector_dim: " << dim_numbers.index_vector_dim();
195 << absl::StrJoin(dim_numbers.update_window_dims(), ",") << "]";
197 << absl::StrJoin(dim_numbers.inserted_window_dims(), ",") << "]";
199 << absl::StrJoin(dim_numbers.scatter_dims_to_operand_dims(), ",")
203 dim_numbers);
    [all...]
  /external/tensorflow/tensorflow/compiler/xla/service/
batch_dot_simplification.cc 26 const DotDimensionNumbers& dim_numbers = batch_dot->dot_dimension_numbers(); local
34 if (dim_numbers.lhs_contracting_dimensions_size() != 1) {
39 for (int64 batch_dim : dim_numbers.lhs_batch_dimensions()) {
54 DotDimensionNumbers new_dim_numbers = dim_numbers;
58 for (int64 i = 0, e = dim_numbers.lhs_batch_dimensions_size() -
gather_expander.cc 111 HloInstruction* index_vector, const GatherDimensionNumbers& dim_numbers,
132 int64 index_vector_dim_index = FindIndex(dim_numbers.start_index_map(), i);
133 if (index_vector_dim_index != dim_numbers.start_index_map_size()) {
153 const GatherDimensionNumbers& dim_numbers = gather.gather_dimension_numbers(); local
161 dim_numbers.index_vector_dim() ==
197 ExpandIndexVectorIntoOperandSpace(index_vector, dim_numbers,
207 AsInt64Slice(dim_numbers.collapsed_slice_dims())));
235 const GatherDimensionNumbers& dim_numbers) {
240 if (!absl::c_binary_search(dim_numbers.collapsed_slice_dims(), i)) {
318 const GatherDimensionNumbers& dim_numbers local
    [all...]
scatter_expander.cc 133 HloInstruction* index_vector, const ScatterDimensionNumbers& dim_numbers,
154 FindIndex(dim_numbers.scatter_dims_to_operand_dims(), i);
156 dim_numbers.scatter_dims_to_operand_dims_size()) {
222 const ScatterDimensionNumbers& dim_numbers = local
258 ExpandIndexVectorIntoOperandSpace(index_vector, dim_numbers,
279 AsInt64Slice(dim_numbers.inserted_window_dims())));
350 const ScatterDimensionNumbers& dim_numbers = local
364 if (i != dim_numbers.index_vector_dim()) {
379 scatter_indices, dim_numbers.index_vector_dim()));
388 updates, AsInt64Slice(dim_numbers.update_window_dims())))
    [all...]
hlo_cost_analysis_test.cc 650 GatherDimensionNumbers dim_numbers; local
651 dim_numbers.add_offset_dims(1);
652 dim_numbers.add_collapsed_slice_dims(0);
653 dim_numbers.add_start_index_map(0);
654 dim_numbers.set_index_vector_dim(1);
655 Gather(operand, indices, dim_numbers, {1, 3});
677 ScatterDimensionNumbers dim_numbers; local
678 dim_numbers.set_index_vector_dim(1);
679 dim_numbers.add_update_window_dims(1);
680 dim_numbers.add_inserted_window_dims(0)
    [all...]
triangular_solve_expander.cc 74 GatherDimensionNumbers dim_numbers; local
76 dim_numbers.add_offset_dims(i);
77 dim_numbers.add_start_index_map(i);
81 dim_numbers.add_offset_dims(ndims - 1);
82 dim_numbers.add_offset_dims(ndims);
83 dim_numbers.add_start_index_map(ndims - 2);
84 dim_numbers.add_start_index_map(ndims - 1);
85 dim_numbers.set_index_vector_dim(1);
86 diag_blocks = Gather(a, start_indices, dim_numbers, slice_sizes);
convolution_group_converter.cc 202 auto dim_numbers = convolution->convolution_dimension_numbers(); local
218 int64 input_batch_dimension = dim_numbers.input_batch_dimension();
219 int64 output_batch_dimension = dim_numbers.output_batch_dimension();
220 int64 output_feature_dimension = dim_numbers.output_feature_dimension();
243 convolution->window(), dim_numbers, convolution->precision_config()));
334 auto dim_numbers = convolution->convolution_dimension_numbers(); local
336 int64 kernel_input_feature_dim = dim_numbers.kernel_input_feature_dimension();
339 dim_numbers.kernel_output_feature_dimension();
383 convolution->window(), dim_numbers, convolution->precision_config());
387 int64 activation_input_feature_dim = dim_numbers.input_feature_dimension()
    [all...]
hlo_evaluator.cc 364 const DotDimensionNumbers& dim_numbers,
374 ShapeInference::InferDotOpShape(lhs.shape(), rhs.shape(), dim_numbers));
378 dim_numbers, precision_config);
782 const Shape& output_shape, const GatherDimensionNumbers& dim_numbers) {
789 !absl::c_binary_search(dim_numbers.offset_dims(), i);
801 const GatherDimensionNumbers& dim_numbers) {
807 absl::c_binary_search(dim_numbers.offset_dims(), i);
809 while (absl::c_binary_search(dim_numbers.collapsed_slice_dims(),
831 const GatherDimensionNumbers* dim_numbers, const Shape& input_shape,
833 : dim_numbers_(*dim_numbers), start_indices_(*start_indices)
1050 const GatherDimensionNumbers& dim_numbers = local
    [all...]
elemental_ir_emitter.cc 1867 const GatherDimensionNumbers& dim_numbers = hlo->gather_dimension_numbers(); local
2118 const DotDimensionNumbers& dim_numbers = hlo->dot_dimension_numbers(); local
    [all...]
hlo_parser.cc 1636 GatherDimensionNumbers dim_numbers = local
1672 ScatterDimensionNumbers dim_numbers = local
    [all...]
  /external/tensorflow/tensorflow/compiler/xla/tests/
gather_operation_test.cc 646 GatherDimensionNumbers dim_numbers; local
647 dim_numbers.add_offset_dims(1);
648 dim_numbers.add_collapsed_slice_dims(0);
649 dim_numbers.add_start_index_map(0);
650 dim_numbers.set_index_vector_dim(1);
651 Gather(operand, indices, dim_numbers, {1, 3});
  /external/tensorflow/tensorflow/compiler/tf2xla/kernels/
gather_op.cc 116 xla::GatherDimensionNumbers dim_numbers; local
122 dim_numbers.add_collapsed_slice_dims(i);
131 dim_numbers.add_offset_dims(i);
135 dim_numbers.add_offset_dims(i + indices_rank - num_index_dims);
139 dim_numbers.set_index_vector_dim(indices_are_nd ? (indices_shape.dims() - 1)
142 dim_numbers.add_start_index_map(i);
145 *gather_output = xla::Gather(input, indices, dim_numbers, slice_sizes);
  /external/tensorflow/tensorflow/compiler/xla/service/gpu/
ir_emission_utils.cc 69 const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers(); local
74 dim_numbers.lhs_batch_dimensions_size())) {
78 CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)),
79 rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0)));
ir_emitter_unnested.cc 1081 const ScatterDimensionNumbers& dim_numbers = local
    [all...]

Completed in 1690 milliseconds