1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/framework/common_shape_fns.h" 17 #include "tensorflow/core/framework/op.h" 18 #include "tensorflow/core/framework/shape_inference.h" 19 20 namespace tensorflow { 21 22 using shape_inference::InferenceContext; 23 using shape_inference::ShapeHandle; 24 25 REGISTER_OP("TPUReplicateMetadata") 26 .Attr("num_replicas: int >= 0") 27 .Attr("topology: string = \"\"") 28 .Attr("device_assignment: list(int) = []") 29 .Attr("computation_shape: list(int) = []") 30 .SetShapeFn(shape_inference::UnknownShape); 31 32 REGISTER_OP("TPUReplicatedInput") 33 .Input("inputs: N * T") 34 .Output("output: T") 35 .Attr("N: int >= 1") 36 .Attr("T: type") 37 .SetShapeFn([](InferenceContext* c) { 38 ShapeHandle cur = c->input(c->num_inputs() - 1); 39 for (int i = c->num_inputs() - 2; i >= 0; --i) { 40 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur), 41 "From merging shape ", i, 42 " with other shapes."); 43 } 44 c->set_output(0, cur); 45 return Status::OK(); 46 }) 47 .Doc( 48 "Operator that connects N unreplicated inputs to an N-way " 49 "replicated TPU computation."); 50 51 REGISTER_OP("TPUReplicatedOutput") 52 .Input("input: T") 53 .Output("outputs: num_replicas * T") 54 .Attr("num_replicas: int >= 1") 55 .Attr("T: type") 56 .SetShapeFn([](InferenceContext* c) { 57 for (int i = 0; i < c->num_outputs(); ++i) { 58 c->set_output(i, c->input(0)); 59 } 60 return Status::OK(); 61 }) 62 .Doc( 63 "Operator that connects the output of an N-way replicated TPU " 64 "computation to N separate outputs."); 65 66 REGISTER_OP("TPUReplicate") 67 .Attr("computation: func") 68 .Attr("num_replicas: int >= 1") 69 .Attr("topology: string = \"\"") 70 .Attr("device_assignment: list(int) = []") 71 .Attr("computation_shape: list(int) = []") 72 .Attr("Tinputs: list(type) >= 0") 73 .Attr("Tbroadcast_inputs: list(type) >= 0") 74 .Attr("NumVariables: int >= 0") 75 .Attr("Tguaranteed_constants: list(type) >= 0") 76 .Attr("output_types: list(type) >= 0") 77 .Input("inputs: Tinputs") 78 .Input("broadcast_inputs: Tbroadcast_inputs") 79 .Input("variables: NumVariables * resource") 80 .Input("guaranteed_constants: Tguaranteed_constants") 81 .Output("outputs: output_types") 82 .SetShapeFn(shape_inference::UnknownShape) 83 .Doc(R"doc( 84 Runs replicated computations on a distributed TPU system. 85 86 computation: a function containing the computation to run. 87 num_replicas: the number of replicas of the computation to run. 88 topology: A serialized tensorflow.tpu.TopologyProto that describes the TPU 89 topology. 90 computation_shape: a [mesh_dimension] array describing the shape of each 91 computation replica in numbers of cores in the TPU mesh. 92 device_assignment: a flattened array with shape 93 [replica] + computation_shape + [mesh_dimension] that maps the coordinates of 94 logical cores in each replica of a computation to physical coordinates in 95 the TPU topology. 96 Tinputs: the types of the arguments to 'computation'. 97 inputs: the inputs to 'computation', flattened, in replica-major order. 98 Tbroadcast_inputs: the types of the additional arguments to broadcast to all 99 replicas. 100 Tguaranteed_constants: the types of the arguments to 'guaranteed_constants'. 101 broadcast_inputs: additional arguments to broadcast to all replicas. The 102 broadcast inputs are appended to the per-replica inputs when calling 103 computation. 104 guaranteed_constants: arguments which have been guaranteed to not 105 change their values during the session lifetime. These contain tensors marked as 106 constant using the GuaranteeConstOp. 107 output_types: the types of the outputs of 'computation'. 108 outputs: the outputs of 'computation'. 109 )doc"); 110 111 } // namespace tensorflow 112