Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/hlo_query.h"
     17 
     18 #include "tensorflow/compiler/xla/literal_util.h"
     19 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     20 #include "tensorflow/compiler/xla/shape_util.h"
     21 
     22 namespace xla {
     23 namespace hlo_query {
     24 
     25 bool IsConstantR0F32(HloInstruction* instruction, float* out) {
     26   if (instruction->opcode() == HloOpcode::kConstant &&
     27       ShapeUtil::IsScalarF32(instruction->shape())) {
     28     *out = instruction->literal().Get<float>({});
     29     return true;
     30   }
     31 
     32   return false;
     33 }
     34 
     35 bool AllOperandsAreParametersOrConstants(const HloInstruction& instruction) {
     36   for (const auto& operand : instruction.operands()) {
     37     if (operand->opcode() != HloOpcode::kParameter &&
     38         operand->opcode() != HloOpcode::kConstant) {
     39       return false;
     40     }
     41   }
     42   return true;
     43 }
     44 
     45 bool AllOperandsAreParameters(const HloInstruction& instruction) {
     46   for (const auto& operand : instruction.operands()) {
     47     if (operand->opcode() != HloOpcode::kParameter) {
     48       return false;
     49     }
     50   }
     51   return true;
     52 }
     53 
     54 bool AllOperandsAreConstants(const HloInstruction& instruction) {
     55   for (const auto& operand : instruction.operands()) {
     56     if (operand->opcode() != HloOpcode::kConstant) {
     57       return false;
     58     }
     59   }
     60   return true;
     61 }
     62 
     63 HloInstruction* GetMatchingOperand(
     64     std::function<bool(const HloInstruction*)> matcher,
     65     HloInstruction* instruction) {
     66   for (HloInstruction* op : instruction->operands()) {
     67     if (matcher(op)) {
     68       return op;
     69     }
     70   }
     71   return nullptr;
     72 }
     73 
     74 bool MatchBinaryInstructionOperand(
     75     std::function<bool(const HloInstruction*)> matcher,
     76     HloInstruction* instruction, HloInstruction** matching_operand,
     77     HloInstruction** other_operand) {
     78   CHECK_EQ(instruction->operand_count(), 2);
     79   if (matcher(instruction->operand(0))) {
     80     *matching_operand = instruction->mutable_operand(0);
     81     *other_operand = instruction->mutable_operand(1);
     82     return true;
     83   }
     84   if (matcher(instruction->operand(1))) {
     85     *matching_operand = instruction->mutable_operand(1);
     86     *other_operand = instruction->mutable_operand(0);
     87     return true;
     88   }
     89   return false;
     90 }
     91 
     92 bool MatchBinaryInstructionOperandOpcode(HloOpcode opcode,
     93                                          HloInstruction* instruction,
     94                                          HloInstruction** matching_operand,
     95                                          HloInstruction** other_operand) {
     96   return MatchBinaryInstructionOperand(
     97       [opcode](const HloInstruction* instruction) {
     98         return instruction->opcode() == opcode;
     99       },
    100       instruction, matching_operand, other_operand);
    101 }
    102 
    103 bool IsScalarConstant(const HloInstruction* instruction) {
    104   return instruction->IsConstant() && ShapeUtil::IsScalar(instruction->shape());
    105 }
    106 
    107 }  // namespace hlo_query
    108 }  // namespace xla
    109