Home | History | Annotate | Download | only in ops
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/framework/common_shape_fns.h"
     17 #include "tensorflow/core/framework/op.h"
     18 #include "tensorflow/core/framework/shape_inference.h"
     19 
     20 namespace tensorflow {
     21 
     22 using shape_inference::InferenceContext;
     23 using shape_inference::ShapeHandle;
     24 
     25 REGISTER_OP("TPUReplicateMetadata")
     26     .Attr("num_replicas: int >= 0")
     27     .Attr("topology: string = \"\"")
     28     .Attr("device_assignment: list(int) = []")
     29     .Attr("computation_shape: list(int) = []")
     30     .SetShapeFn(shape_inference::UnknownShape);
     31 
     32 REGISTER_OP("TPUReplicatedInput")
     33     .Input("inputs: N * T")
     34     .Output("output: T")
     35     .Attr("N: int >= 1")
     36     .Attr("T: type")
     37     .SetShapeFn([](InferenceContext* c) {
     38       ShapeHandle cur = c->input(c->num_inputs() - 1);
     39       for (int i = c->num_inputs() - 2; i >= 0; --i) {
     40         TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
     41                                         "From merging shape ", i,
     42                                         " with other shapes.");
     43       }
     44       c->set_output(0, cur);
     45       return Status::OK();
     46     })
     47     .Doc(
     48         "Operator that connects N unreplicated inputs to an N-way "
     49         "replicated TPU computation.");
     50 
     51 REGISTER_OP("TPUReplicatedOutput")
     52     .Input("input: T")
     53     .Output("outputs: num_replicas * T")
     54     .Attr("num_replicas: int >= 1")
     55     .Attr("T: type")
     56     .SetShapeFn([](InferenceContext* c) {
     57       for (int i = 0; i < c->num_outputs(); ++i) {
     58         c->set_output(i, c->input(0));
     59       }
     60       return Status::OK();
     61     })
     62     .Doc(
     63         "Operator that connects the output of an N-way replicated TPU "
     64         "computation to N separate outputs.");
     65 
     66 REGISTER_OP("TPUReplicate")
     67     .Attr("computation: func")
     68     .Attr("num_replicas: int >= 1")
     69     .Attr("topology: string = \"\"")
     70     .Attr("device_assignment: list(int) = []")
     71     .Attr("computation_shape: list(int) = []")
     72     .Attr("Tinputs: list(type) >= 0")
     73     .Attr("Tbroadcast_inputs: list(type) >= 0")
     74     .Attr("NumVariables: int >= 0")
     75     .Attr("Tguaranteed_constants: list(type) >= 0")
     76     .Attr("output_types: list(type) >= 0")
     77     .Input("inputs: Tinputs")
     78     .Input("broadcast_inputs: Tbroadcast_inputs")
     79     .Input("variables: NumVariables * resource")
     80     .Input("guaranteed_constants: Tguaranteed_constants")
     81     .Output("outputs: output_types")
     82     .SetShapeFn(shape_inference::UnknownShape)
     83     .Doc(R"doc(
     84 Runs replicated computations on a distributed TPU system.
     85 
     86 computation: a function containing the computation to run.
     87 num_replicas: the number of replicas of the computation to run.
     88 topology: A serialized tensorflow.tpu.TopologyProto that describes the TPU
     89 topology.
     90 computation_shape: a [mesh_dimension] array describing the shape of each
     91   computation replica in numbers of cores in the TPU mesh.
     92 device_assignment: a flattened array with shape
     93   [replica] + computation_shape + [mesh_dimension] that maps the coordinates of
     94   logical cores in each replica of a computation to physical coordinates in
     95   the TPU topology.
     96 Tinputs: the types of the arguments to 'computation'.
     97 inputs: the inputs to 'computation', flattened, in replica-major order.
     98 Tbroadcast_inputs: the types of the additional arguments to broadcast to all
     99   replicas.
    100 Tguaranteed_constants: the types of the arguments to 'guaranteed_constants'.
    101 broadcast_inputs: additional arguments to broadcast to all replicas. The
    102   broadcast inputs are appended to the per-replica inputs when calling
    103   computation.
    104 guaranteed_constants: arguments which have been guaranteed to not
    105 change their values during the session lifetime. These contain tensors marked as
    106 constant using the GuaranteeConstOp.
    107 output_types: the types of the outputs of 'computation'.
    108 outputs: the outputs of 'computation'.
    109 )doc");
    110 
    111 }  // namespace tensorflow
    112