1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/hlo_proto_util.h" 17 #include "tensorflow/compiler/xla/service/hlo_verifier.h" 18 19 #include <string> 20 21 #include "tensorflow/compiler/xla/util.h" 22 23 namespace xla { 24 25 HloProto MakeHloProto(const HloModule& module, 26 const BufferAssignment& assignment) { 27 BufferAssignmentProto proto_assignment = assignment.ToProto(); 28 HloProto proto = MakeHloProto(module); 29 proto.mutable_buffer_assignment()->Swap(&proto_assignment); 30 return proto; 31 } 32 33 HloProto MakeHloProto(const HloModule& module) { 34 HloModuleProto proto_module = module.ToProto(); 35 HloProto proto; 36 proto.mutable_hlo_module()->Swap(&proto_module); 37 return proto; 38 } 39 40 StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto( 41 const HloModuleProto& proto, const HloModuleConfig& module_config) { 42 VLOG(4) << proto.ShortDebugString(); 43 TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module, 44 HloModule::CreateFromProto(proto, module_config)); 45 TF_RETURN_IF_ERROR( 46 HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false) 47 .Run(module.get()) 48 .status()); 49 return std::move(module); 50 } 51 52 StatusOr<std::vector<const ShapeProto*>> EntryComputationParameterShapes( 53 const HloProto& hlo_proto) { 54 if (!hlo_proto.has_hlo_module()) { 55 return NotFound("HloProto missing HloModuleProto."); 56 } 57 if (!hlo_proto.hlo_module().has_host_program_shape()) { 58 return NotFound("HloProto missing program shape."); 59 } 60 61 std::vector<const ShapeProto*> parameter_shapes; 62 const auto& program_shape = hlo_proto.hlo_module().host_program_shape(); 63 for (const ShapeProto& shape : program_shape.parameters()) { 64 parameter_shapes.push_back(&shape); 65 } 66 return parameter_shapes; 67 } 68 69 StatusOr<const ShapeProto*> EntryComputationOutputShape( 70 const HloProto& hlo_proto) { 71 if (!hlo_proto.has_hlo_module()) { 72 return NotFound("HloProto missing HloModuleProto."); 73 } 74 if (!hlo_proto.hlo_module().has_host_program_shape()) { 75 return NotFound("HloProto missing program shape."); 76 } 77 if (!hlo_proto.hlo_module().host_program_shape().has_result()) { 78 return NotFound("HloProto missing result in its program shape"); 79 } 80 81 return &hlo_proto.hlo_module().host_program_shape().result(); 82 } 83 84 } // namespace xla 85