Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
     17 
     18 #include <memory>
     19 #include <utility>
     20 
     21 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     22 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
     23 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
     24 #include "tensorflow/compiler/xla/shape_util.h"
     25 #include "tensorflow/compiler/xla/test.h"
     26 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
     27 #include "tensorflow/compiler/xla/types.h"
     28 #include "tensorflow/compiler/xla/xla_data.pb.h"
     29 #include "tensorflow/core/lib/core/status_test_util.h"
     30 
     31 namespace xla {
     32 namespace {
     33 
     34 using ::testing::HasSubstr;
     35 
     36 using HloVerifierTest = HloTestBase;
     37 
     38 TEST_F(HloVerifierTest, NullInstructionParent) {
     39   HloComputation::Builder builder(TestName());
     40   const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
     41   HloInstruction* param = builder.AddInstruction(
     42       HloInstruction::CreateParameter(0, scalar_shape, "param"));
     43   HloInstruction* negate = builder.AddInstruction(
     44       HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param));
     45   auto module = CreateNewModule();
     46   module->AddEntryComputation(builder.Build());
     47 
     48   TF_ASSERT_OK(verifier().Run(module.get()).status());
     49 
     50   negate->set_parent(nullptr);
     51 
     52   auto status = verifier().Run(module.get()).status();
     53   ASSERT_FALSE(status.ok());
     54   EXPECT_THAT(status.error_message(), HasSubstr("has a null parent pointer"));
     55 }
     56 
     57 TEST_F(HloVerifierTest, NullComputationParent) {
     58   HloComputation::Builder builder(TestName());
     59   const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
     60   HloInstruction* param = builder.AddInstruction(
     61       HloInstruction::CreateParameter(0, scalar_shape, "param"));
     62   builder.AddInstruction(
     63       HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param));
     64   auto module = CreateNewModule();
     65   HloComputation* computation = module->AddEntryComputation(builder.Build());
     66 
     67   TF_ASSERT_OK(verifier().Run(module.get()).status());
     68 
     69   computation->set_parent(nullptr);
     70 
     71   auto status = verifier().Run(module.get()).status();
     72   ASSERT_FALSE(status.ok());
     73   EXPECT_THAT(status.error_message(), HasSubstr("has a null parent pointer"));
     74 }
     75 
     76 TEST_F(HloVerifierTest, DifferentOperandParents) {
     77   HloComputation::Builder builder(TestName());
     78   const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
     79   HloInstruction* param = builder.AddInstruction(
     80       HloInstruction::CreateParameter(0, scalar_shape, "param"));
     81   HloInstruction* negate = builder.AddInstruction(
     82       HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param));
     83   auto module = CreateNewModule();
     84   module->AddEntryComputation(builder.Build());
     85 
     86   HloComputation::Builder emb_builder(TestName());
     87   HloInstruction* emb_param = emb_builder.AddInstruction(
     88       HloInstruction::CreateParameter(0, scalar_shape, "param"));
     89   module->AddEmbeddedComputation(emb_builder.Build());
     90 
     91   TF_ASSERT_OK(verifier().Run(module.get()).status());
     92   TF_ASSERT_OK(negate->ReplaceOperandWith(0, emb_param));
     93 
     94   auto status = verifier().Run(module.get()).status();
     95   ASSERT_FALSE(status.ok());
     96   EXPECT_THAT(status.error_message(),
     97               HasSubstr("is in a different computation"));
     98 }
     99 
    100 TEST_F(HloVerifierTest, ResetsShapeVerifierState) {
    101   HloComputation::Builder builder(TestName());
    102   Shape s1 = ShapeUtil::MakeShape(F32, {1});
    103   Shape s2 = ShapeUtil::MakeShape(F32, {2});
    104 
    105   HloInstruction* param =
    106       builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "param"));
    107 
    108   // Create an add instruction with the incorrect shape.
    109   HloInstruction* add = builder.AddInstruction(
    110       HloInstruction::CreateBinary(s2, HloOpcode::kAdd, param, param));
    111 
    112   // In order to trigger the bug we're checking for, the instruction with the
    113   // bad shape can't be the root of the computation.
    114   builder.AddInstruction(
    115       HloInstruction::CreateBinary(s2, HloOpcode::kMultiply, add, add));
    116 
    117   auto module = CreateNewModule();
    118   module->AddEntryComputation(builder.Build());
    119 
    120   // Run the verifier twice.  It should fail both times, because it shouldn't
    121   // carry state in its DFS visitor between runs.
    122   EXPECT_FALSE(verifier().Run(module.get()).status().ok());
    123   EXPECT_FALSE(verifier().Run(module.get()).status().ok());
    124 }
    125 
    126 }  // namespace
    127 }  // namespace xla
    128