Home | History | Annotate | Download | only in service
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      7     http://www.apache.org/licenses/LICENSE-2.0
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
#include "tensorflow/compiler/xla/service/convolution_group_converter.h"

#include <functional>
#include <memory>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
     37 namespace xla {
     39 namespace {
     41 // ConvolutionVisitor traverses the HLO computation and rewrites Convolution
     42 // operations with feature_group_count > 1 into convolutions with
     43 // feature_group_count = 1.
     44 class ConvolutionVisitor : public DfsHloVisitorWithDefault {
     45  public:
     46   // Default visitor action is to do nothing and return OK.
     47   Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
     48     return Status::OK();
     49   }
     51   Status HandleConvolution(HloInstruction* convolution) override;
     53   Status HandleBatchGroupCount(HloInstruction* convolution);
     55   // Runs the visitor on a computation.
     56   static bool Run(HloComputation* computation,
     57                   std::function<bool(HloInstruction*)> is_cost_viable,
     58                   bool convert_batch_groups_only,
     59                   bool canonicalize_depthwise_filter);
     61   // Returns whether any convolution ops were rewritten.
     62   const bool changed() const { return changed_; }
     64   ~ConvolutionVisitor() override = default;
     66  private:
     67   explicit ConvolutionVisitor(
     68       HloComputation* computation,
     69       std::function<bool(HloInstruction*)> is_cost_viable,
     70       bool convert_batch_groups_only,
     71       bool canonicalize_depthwise_filter = false)
     72       : computation_(computation),
     73         filter_expansion_(!canonicalize_depthwise_filter),
     74         convert_batch_groups_only_(convert_batch_groups_only),
     75         is_cost_viable_(is_cost_viable) {}
     77   // Current HloComputation instance the ConvolutionVisitor is traversing.
     78   HloComputation* computation_;
     80   // Whether rewrite has occurred.
     81   bool changed_ = false;
     83   // Whether filter expansion is required.
     84   bool filter_expansion_;
     86   // Decides whether to convert batch groups or feature groups.
     87   bool convert_batch_groups_only_;
     89   // std::function<std::vector<LloValue*>(int64, int64)> chunk_fetcher
     90   std::function<bool(HloInstruction*)> is_cost_viable_;
     91 };
     93 bool ConvolutionVisitor::Run(
     94     HloComputation* computation,
     95     std::function<bool(HloInstruction*)> is_cost_viable,
     96     bool convert_batch_groups_only, bool canonicalize_depthwise_filter) {
     97   ConvolutionVisitor visitor(computation, is_cost_viable,
     98                              convert_batch_groups_only,
     99                              canonicalize_depthwise_filter);
    100   TF_CHECK_OK(computation->Accept(&visitor));
    101   return visitor.changed_;
    102 }
    104 Shape ExpandedFilterShape(const Shape& shape, int64 group_count,
    105                           int64 input_feature_dim) {
    106   int64 num_dims = shape.dimensions_size();
    107   CHECK_GE(num_dims, 2);
    108   Shape expanded_shape = shape;
    109   expanded_shape.set_dimensions(
    110       input_feature_dim, shape.dimensions(input_feature_dim) * group_count);
    111   return expanded_shape;
    112 }
    114 // Returns a vector with 'group_count' many groups, where the i-th group
    115 // consists of 'group_size' times the value i.
    116 std::vector<int32> GetMaskIds(int64 group_size, int64 group_count) {
    117   std::vector<int32> values;
    118   for (int i = 0; i < group_count; ++i) {
    119     for (int j = 0; j < group_size; ++j) {
    120       values.push_back(i);
    121     }
    122   }
    123   return values;
    124 }
    126 // Create a mask for grouped convolution that will make a normal convolution
    127 // produce the same results as a grouped convolution. For a [2, 1, 6]
    128 // filter this returns a [2, 3, 6] mask
    129 //   1 1 0 0 0 0
    130 //   0 0 1 1 0 0
    131 //   0 0 0 0 1 1
    132 //
    133 //   1 1 0 0 0 0
    134 //   0 0 1 1 0 0
    135 //   0 0 0 0 1 1
    136 //
    137 // The first step is to create a rank 1 constant:
    138 //   0 1 2
    139 //
    140 // This is broadcasted to
    141 //   0 0 0 0 0 0
    142 //   1 1 1 1 1 1
    143 //   2 2 2 2 2 2
    144 //
    145 //   0 0 0 0 0 0
    146 //   1 1 1 1 1 1
    147 //   2 2 2 2 2 2
    148 //
    149 // Then we create another rank 1 constant
    150 //   0 0 1 1 2 2
    151 //
    152 // This is broadcasted to
    153 //   0 0 1 1 2 2
    154 //   0 0 1 1 2 2
    155 //   0 0 1 1 2 2
    156 //
    157 //   0 0 1 1 2 2
    158 //   0 0 1 1 2 2
    159 //   0 0 1 1 2 2
    160 //
    161 // Finally we use the Eq op of these two broadcasted constants and get the
    162 // desired mask.
    163 HloInstruction* GetExpandedFilterMask(
    164     const Shape& filter_shape, int64 kernel_input_feature_dim,
    165     int64 kernel_output_feature_dim, int64 group_count,
    166     const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
    167         add_instruction) {
    168   Shape expanded_filter_shape =
    169       ExpandedFilterShape(filter_shape, group_count, kernel_input_feature_dim);
    170   Shape mask_shape = ShapeUtil::MakeShape(
    171       S32, AsInt64Slice(expanded_filter_shape.dimensions()));
    172   int64 output_feature = filter_shape.dimensions(kernel_output_feature_dim);
    173   int64 group_size = filter_shape.dimensions(kernel_input_feature_dim);
    175   // Create a 'input_feature' sized linspace and 'output_feature' sized linspace
    176   // that will be broadcasted into perpendicular dimensions and compared.
    177   const std::vector<int32> input_feature_filter_mask =
    178       GetMaskIds(group_size, group_count);
    179   const std::vector<int32> output_feature_filter_mask =
    180       GetMaskIds(output_feature / group_count, group_count);
    181   auto mask1 = add_instruction(HloInstruction::CreateConstant(
    182       LiteralUtil::CreateR1<int32>(input_feature_filter_mask)));
    183   auto broadcasted_mask1 = add_instruction(HloInstruction::CreateBroadcast(
    184       mask_shape, mask1, {kernel_input_feature_dim}));
    185   auto mask2 = add_instruction(HloInstruction::CreateConstant(
    186       LiteralUtil::CreateR1<int32>(output_feature_filter_mask)));
    187   auto broadcasted_mask2 = add_instruction(HloInstruction::CreateBroadcast(
    188       mask_shape, mask2, {kernel_output_feature_dim}));
    190   // Compare the broadcasted output feature linspace to the input feature
    191   // linspace to create a diagonal predicate.
    192   Shape predicate_shape = ShapeUtil::MakeShape(
    193       PRED, AsInt64Slice(expanded_filter_shape.dimensions()));
    194   return add_instruction(HloInstruction::CreateCompare(
    195       predicate_shape, broadcasted_mask1, broadcasted_mask2,
    196       ComparisonDirection::kEq));
    197 }
    199 // This function handles batch_group_counts which are relevant only for
    200 // depthwise backprop filter convolutions.
    201 Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) {
    202   auto dim_numbers = convolution->convolution_dimension_numbers();
    203   auto activation = convolution->mutable_operand(0);
    204   auto filter = convolution->mutable_operand(1);
    205   int64 batch_group_count = convolution->batch_group_count();
    207   if (batch_group_count == 1) {
    208     return Status::OK();
    209   }
    211   VLOG(2) << "Dealing with batch_group_count " << batch_group_count
    212           << " for convolution " << convolution->ToString() << "\n";
    214   auto add = [&](std::unique_ptr<HloInstruction> inst) {
    215     return computation_->AddInstruction(std::move(inst));
    216   };
    218   int64 input_batch_dimension = dim_numbers.input_batch_dimension();
    219   int64 output_batch_dimension = dim_numbers.output_batch_dimension();
    220   int64 output_feature_dimension = dim_numbers.output_feature_dimension();
    222   int64 input_batch = activation->shape().dimensions(input_batch_dimension);
    224   // We are not yet supporting batch_group of sizes greater than 1.
    225   TF_RET_CHECK(input_batch == batch_group_count);
    227   if (!is_cost_viable_(convolution) || filter_expansion_) {
    228     // We first obtain the expanded the filter (which is the convolution
    229     // output). The batch dimension is the expanded one (which originally
    230     // represents kernel input feature dimension). We mask the filter to zero
    231     // out the expanded regions. Next we reduce the filter in the batch
    232     // dimension to obtain the original filter size.
    234     HloInstruction* filter_mask =
    235         GetExpandedFilterMask(convolution->shape(), output_batch_dimension,
    236                               output_feature_dimension, batch_group_count, add);
    237     auto expanded_filter_shape = ExpandedFilterShape(
    238         convolution->shape(), batch_group_count, output_batch_dimension);
    240     auto new_convolution = add(HloInstruction::CreateConvolve(
    241         expanded_filter_shape, activation, filter,
    242         /*feature_group_count=*/1, /*batch_group_count=*/1,
    243         convolution->window(), dim_numbers, convolution->precision_config()));
    245     auto zero = add(HloInstruction::CreateConstant(
    246         LiteralUtil::Zero(expanded_filter_shape.element_type())));
    247     auto zero_filter =
    248         add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {}));
    250     auto new_filter = add(HloInstruction::CreateTernary(
    251         expanded_filter_shape, HloOpcode::kSelect, filter_mask, new_convolution,
    252         zero_filter));
    254     PrimitiveType reduce_type = new_filter->shape().element_type();
    255     auto reduce_window_shape = new_convolution->shape();
    256     reduce_window_shape.set_dimensions(output_batch_dimension, 1);
    258     // Ensure that data input to reduce window uses at least 32 bits.
    259     if (primitive_util::BitWidth(reduce_type) < primitive_util::BitWidth(F32)) {
    260       reduce_type = F32;
    261       reduce_window_shape.set_element_type(F32);
    262       Shape convert_shape = new_filter->shape();
    263       convert_shape.set_element_type(F32);
    264       new_filter =
    265           add(HloInstruction::CreateConvert(convert_shape, new_filter));
    266     }
    268     auto zero_literal = LiteralUtil::Zero(reduce_type);
    269     auto zero_scalar =
    270         add(HloInstruction::CreateConstant(std::move(zero_literal)));
    272     auto reduce_function = [&]() -> HloComputation* {
    273       HloComputation::Builder b("add_computation");
    274       Shape shape = ShapeUtil::MakeShape(reduce_type, {});
    275       auto lhs =
    276           b.AddInstruction(HloInstruction::CreateParameter(0, shape, "lhs"));
    277       auto rhs =
    278           b.AddInstruction(HloInstruction::CreateParameter(1, shape, "rhs"));
    279       auto scalar_op = b.AddInstruction(
    280           HloInstruction::CreateBinary(shape, HloOpcode::kAdd, lhs, rhs));
    281       return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
    282     };
    284     // Create the reduce window.
    285     Window window;
    286     for (int64 i = 0; i < new_convolution->shape().dimensions_size(); ++i) {
    287       auto* dim = window.add_dimensions();
    288       dim->set_padding_low(0);
    289       dim->set_padding_high(0);
    290       dim->set_window_dilation(1);
    291       dim->set_base_dilation(1);
    292       if (i == output_batch_dimension) {
    293         dim->set_stride(batch_group_count);
    294         dim->set_size(batch_group_count);
    295       } else {
    296         dim->set_stride(1);
    297         dim->set_size(1);
    298       }
    299     }
    300     auto reduce_window = add(HloInstruction::CreateReduceWindow(
    301         reduce_window_shape, new_filter, zero_scalar, window,
    302         reduce_function()));
    304     Shape convert_back_shape = reduce_window->shape();
    305     convert_back_shape.set_element_type(activation->shape().element_type());
    307     // Convert reduced data back to the original data type.
    308     auto reduce_window_converted =
    309         HloInstruction::CreateConvert(convert_back_shape, reduce_window);
    311     TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
    312         convolution, std::move(reduce_window_converted)));
    313     changed_ = true;
    314   }
    316   return Status::OK();
    317 }
    319 Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
    320   if (convert_batch_groups_only_) {
    321     return HandleBatchGroupCount(convolution);
    322   }
    324   auto add = [&](std::unique_ptr<HloInstruction> inst) {
    325     return computation_->AddInstruction(std::move(inst));
    326   };
    328   int64 group_count = convolution->feature_group_count();
    329   if (group_count == 1) {
    330     return Status::OK();
    331   }
    333   changed_ = true;
    334   auto dim_numbers = convolution->convolution_dimension_numbers();
    335   auto filter = convolution->mutable_operand(1);
    336   int64 kernel_input_feature_dim = dim_numbers.kernel_input_feature_dimension();
    337   int64 group_size = filter->shape().dimensions(kernel_input_feature_dim);
    338   int64 kernel_output_feature_dim =
    339       dim_numbers.kernel_output_feature_dimension();
    340   auto expanded_filter_shape = ExpandedFilterShape(filter->shape(), group_count,
    341                                                    kernel_input_feature_dim);
    342   HloInstruction* filter_mask =
    343       GetExpandedFilterMask(filter->shape(), kernel_input_feature_dim,
    344                             kernel_output_feature_dim, group_count, add);
    345   HloInstruction* expanded_filter;
    347   if (group_size == 1) {
    348     bool depthwise_separable =
    349         (group_count == filter->shape().dimensions(kernel_output_feature_dim));
    350     // If the code generator handles depthwise separable convolutions
    351     // inherently, then no filter expansion is needed.
    352     if (!filter_expansion_ && depthwise_separable) {
    353       changed_ = false;
    354       return Status::OK();
    355     }
    356     // We want to repeat 'filter' in the 'input_feature_dim' dimension
    357     // 'group_count' times.
    358     Shape reshaped_filter_shape =
    359         ShapeUtil::DeleteDimension(kernel_input_feature_dim, filter->shape());
    360     auto reshaped_filter =
    361         add(HloInstruction::CreateReshape(reshaped_filter_shape, filter));
    362     std::vector<int64> broadcast_dims;
    363     for (int64 i = 0; i < filter->shape().dimensions_size(); ++i) {
    364       if (i == kernel_input_feature_dim) {
    365         continue;
    366       }
    367       broadcast_dims.push_back(i);
    368     }
    369     expanded_filter = add(HloInstruction::CreateBroadcast(
    370         expanded_filter_shape, reshaped_filter, broadcast_dims));
    372     auto zero = add(HloInstruction::CreateConstant(
    373         LiteralUtil::Zero(expanded_filter_shape.element_type())));
    374     auto zero_filter =
    375         add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {}));
    376     auto new_filter = add(HloInstruction::CreateTernary(
    377         expanded_filter_shape, HloOpcode::kSelect, filter_mask, expanded_filter,
    378         zero_filter));
    380     auto new_convolution = HloInstruction::CreateConvolve(
    381         convolution->shape(), convolution->mutable_operand(0), new_filter,
    382         /*feature_group_count=*/1, /*batch_group_count=*/1,
    383         convolution->window(), dim_numbers, convolution->precision_config());
    384     TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
    385         convolution, std::move(new_convolution)));
    386   } else {
    387     int64 activation_input_feature_dim = dim_numbers.input_feature_dimension();
    389     int64 output_feature =
    390         filter->shape().dimensions(kernel_output_feature_dim);
    392     // If group_count == output_feature, then we map those grouped convolutions
    393     // onto depthwise convolution. This is done by adding an additional spatial
    394     // dimension to the activations, kernel, and the output.
    395     // E.g., we would turn
    396     // [2, 12]{B, IF} conv [3, 4]{IF, OF} into
    397     // [3, 2, 4]{S, B, IF} depth conv [3, 1, 4]{S, IF, OF}, where S is the
    398     // additional spatial dimension. The generated convolution output will be
    399     // [1, 2, 4]{S, B, OF} and then reshape the output back to [2, 4] {B, OF}.
    401     if (group_count == output_feature && !filter_expansion_) {
    402       auto filter = convolution->mutable_operand(1);
    403       auto activation = convolution->mutable_operand(0);
    405       // Add spatial dimension to the activation, and reshape.
    406       Shape reshaped_activation_shape = activation->shape();
    407       ShapeUtil::AppendMajorDimension(group_size, &reshaped_activation_shape);
    409       int64 new_spatial_dim = reshaped_activation_shape.dimensions().size() - 1;
    411       reshaped_activation_shape.set_dimensions(activation_input_feature_dim,
    412                                                group_count);
    413       activation = add(
    414           HloInstruction::CreateReshape(reshaped_activation_shape, activation));
    416       // Add spatial dimension to the filter, and reshape.
    417       Shape reshaped_filter_shape = filter->shape();
    418       ShapeUtil::AppendMajorDimension(1, &reshaped_filter_shape);
    420       filter =
    421           add(HloInstruction::CreateReshape(reshaped_filter_shape, filter));
    423       Shape new_output_shape = convolution->shape();
    424       ShapeUtil::AppendMajorDimension(1, &new_output_shape);
    426       // Edit convolution dimension numbers. Note that kernel_input_feature_dim
    427       // now becomes a spatial dimension, and the newly added dimension of size
    428       // 1 is the new kernel_input_feature_dim.
    429       dim_numbers.add_input_spatial_dimensions(new_spatial_dim);
    430       dim_numbers.add_kernel_spatial_dimensions(kernel_input_feature_dim);
    431       dim_numbers.set_kernel_input_feature_dimension(new_spatial_dim);
    432       dim_numbers.add_output_spatial_dimensions(new_spatial_dim);
    434       // Add window for the new spatial dimension.
    435       Window new_window = convolution->window();
    436       auto* dim = new_window.add_dimensions();
    437       dim->set_window_dilation(1);
    438       dim->set_base_dilation(1);
    439       dim->set_stride(1);
    440       dim->set_size(group_size);
    442       auto new_convolution = add(HloInstruction::CreateConvolve(
    443           new_output_shape, activation, filter, group_count,
    444           /*batch_group_count=*/1, new_window, dim_numbers,
    445           convolution->precision_config()));
    447       // Delete the extra spatial dimension, and reshape.
    448       Shape reshaped_convolution_shape =
    449           ShapeUtil::DeleteDimension(new_spatial_dim, new_convolution->shape());
    450       auto reshaped_convolution = HloInstruction::CreateReshape(
    451           reshaped_convolution_shape, new_convolution);
    453       TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
    454           convolution, std::move(reshaped_convolution)));
    456     } else {
    457       // The filter expansion mechanism adds zeroes in the kernel.
    458       // For an OF = 12, IF = 6, and kernel IF = 2, the expanded filter mask
    459       // would look like (IF on the Y-axis, OF on the X-axis)
    460       // 1 1 1 1 0 0 0 0 0 0 0 0
    461       // 1 1 1 1 0 0 0 0 0 0 0 0
    462       // 0 0 0 0 1 1 1 1 0 0 0 0
    463       // 0 0 0 0 1 1 1 1 0 0 0 0
    464       // 0 0 0 0 0 0 0 0 1 1 1 1
    465       // 0 0 0 0 0 0 0 0 1 1 1 1
    466       //
    467       // Instead of convolving the above with the input, we instead slice the
    468       // kernel into three kernels, each containing islands of 1s from the
    469       // filter above. We also slice the activations in the IF dimension with
    470       // each slice of size = group_size. For each slice, we perform
    471       // convolutions, and concatenate the generated outputs in the output OF
    472       // dimension.
    474       std::vector<HloInstruction*> sliced_convolutions;
    475       auto activation = convolution->mutable_operand(0);
    476       std::vector<int64> slice_strides(filter->shape().dimensions_size(), 1);
    477       std::vector<int64> filter_slice_starts(filter->shape().dimensions_size(),
    478                                              0);
    479       std::vector<int64> filter_slice_limits(
    480           filter->shape().dimensions().begin(),
    481           filter->shape().dimensions().end());
    482       std::vector<int64> activation_slice_starts(
    483           activation->shape().dimensions_size(), 0);
    484       std::vector<int64> activation_slice_limits(
    485           activation->shape().dimensions().begin(),
    486           activation->shape().dimensions().end());
    488       int64 output_feature =
    489           filter->shape().dimensions(kernel_output_feature_dim);
    490       auto output_feature_dim = dim_numbers.output_feature_dimension();
    491       int64 filter_slice_width = output_feature / group_count;
    493       int64 activation_input_feature_dim =
    494           dim_numbers.input_feature_dimension();
    496       for (int64 i = 0; i < group_count; i++) {
    497         filter_slice_starts[kernel_output_feature_dim] = i * filter_slice_width;
    498         filter_slice_limits[kernel_output_feature_dim] =
    499             (i + 1) * filter_slice_width;
    500         auto filter_sliced_shape = filter->shape();
    501         filter_sliced_shape.set_dimensions(kernel_output_feature_dim,
    502                                            filter_slice_width);
    503         auto filter_slice = add(HloInstruction::CreateSlice(
    504             filter_sliced_shape, filter, filter_slice_starts,
    505             filter_slice_limits, slice_strides));
    507         activation_slice_starts[activation_input_feature_dim] = i * group_size;
    508         activation_slice_limits[activation_input_feature_dim] =
    509             (i + 1) * group_size;
    510         auto activation_sliced_shape = activation->shape();
    511         activation_sliced_shape.set_dimensions(activation_input_feature_dim,
    512                                                group_size);
    513         auto activation_slice = add(HloInstruction::CreateSlice(
    514             activation_sliced_shape, activation, activation_slice_starts,
    515             activation_slice_limits, slice_strides));
    517         auto conv_slice_shape = convolution->shape();
    518         conv_slice_shape.set_dimensions(output_feature_dim, filter_slice_width);
    520         auto new_convolution = add(HloInstruction::CreateConvolve(
    521             conv_slice_shape, activation_slice, filter_slice,
    522             /*feature_group_count=*/1, /*batch_group_count=*/1,
    523             convolution->window(), dim_numbers,
    524             convolution->precision_config()));
    526         sliced_convolutions.push_back(new_convolution);
    527       }
    529       auto new_conv = HloInstruction::CreateConcatenate(
    530           convolution->shape(), sliced_convolutions, output_feature_dim);
    531       TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction(
    532           convolution, std::move(new_conv)));
    533     }
    534   }
    536   return Status::OK();
    537 }
    539 }  // namespace
    541 StatusOr<bool> ConvolutionGroupConverter::Run(HloModule* module) {
    542   XLA_VLOG_LINES(
    543       2, "ConvolutionGroupConverter::Run(), before:\n" + module->ToString());
    544   bool changed = false;
    545   for (auto* comp : module->MakeNonfusionComputations()) {
    546     if (ConvolutionVisitor::Run(comp, is_cost_viable_,
    547                                 convert_batch_groups_only_,
    548                                 filter_expansion_)) {
    549       changed = true;
    550     }
    551   }
    552   XLA_VLOG_LINES(
    553       2, "ConvolutionGroupConverter::Run(), after:\n" + module->ToString());
    554   return changed;
    555 }
    557 }  // namespace xla