Home | History | Annotate | Download | only in ops
      1 // Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 // ============================================================================
     15 
     16 #include "tensorflow/core/framework/common_shape_fns.h"
     17 #include "tensorflow/core/framework/node_def_util.h"
     18 #include "tensorflow/core/framework/op.h"
     19 #include "tensorflow/core/framework/resource_mgr.h"
     20 #include "tensorflow/core/framework/shape_inference.h"
     21 
     22 using ::tensorflow::shape_inference::InferenceContext;
     23 using ::tensorflow::shape_inference::ShapeAndType;
     24 using ::tensorflow::shape_inference::ShapeHandle;
     25 
     26 namespace tensorflow {
     27 
     28 namespace {
     29 
     30 Status ValidateVariableResourceHandle(InferenceContext* c,
     31                                       ShapeAndType* shape_and_type) {
     32   auto* handle_data = c->input_handle_shapes_and_types(0);
     33   if (handle_data == nullptr || handle_data->empty()) {
     34     shape_and_type->shape = c->UnknownShape();
     35     shape_and_type->dtype = DT_INVALID;
     36   } else {
     37     *shape_and_type = (*handle_data)[0];
     38     DataType value_dtype;
     39     TF_RETURN_IF_ERROR(c->GetAttr("dtype", &value_dtype));
     40     if (shape_and_type->dtype != value_dtype) {
     41       return errors::InvalidArgument(
     42           "Trying to read variable with wrong dtype. "
     43           "Expected ",
     44           DataTypeString(shape_and_type->dtype), " got ",
     45           DataTypeString(value_dtype));
     46     }
     47   }
     48   return Status::OK();
     49 }
     50 
     51 Status ReadVariableShapeFn(InferenceContext* c) {
     52   ShapeAndType shape_and_type;
     53   TF_RETURN_IF_ERROR(ValidateVariableResourceHandle(c, &shape_and_type));
     54   c->set_output(0, shape_and_type.shape);
     55   return Status::OK();
     56 }
     57 
     58 }  // namespace
     59 
     60 REGISTER_OP("VarHandleOp")
     61     .Attr("container: string = ''")
     62     .Attr("shared_name: string = ''")
     63     .Attr("dtype: type")
     64     .Attr("shape: shape")
     65     .Output("resource: resource")
     66     .SetIsStateful()
     67     .SetShapeFn([](InferenceContext* c) {
     68       c->set_output(0, c->Scalar());
     69       DataType t;
     70       TF_RETURN_IF_ERROR(c->GetAttr("dtype", &t));
     71       PartialTensorShape p;
     72       TF_RETURN_IF_ERROR(c->GetAttr("shape", &p));
     73       ShapeHandle s;
     74       TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s));
     75       c->set_output_handle_shapes_and_types(0,
     76                                             std::vector<ShapeAndType>{{s, t}});
     77 
     78       return Status::OK();
     79     });
     80 
     81 REGISTER_OP("ReadVariableOp")
     82     .Input("resource: resource")
     83     .Output("value: dtype")
     84     .Attr("dtype: type")
     85     .SetShapeFn(ReadVariableShapeFn);
     86 
     87 REGISTER_OP("DestroyResourceOp")
     88     .Input("resource: resource")
     89     .Attr("ignore_lookup_error: bool = true")
     90     .SetIsStateful()
     91     .SetShapeFn(shape_inference::NoOutputs);
     92 
     93 Status CreateAssignShapeFn(InferenceContext* c) {
     94   ShapeAndType handle_shape_and_type;
     95   TF_RETURN_IF_ERROR(ValidateVariableResourceHandle(c, &handle_shape_and_type));
     96 
     97   ShapeHandle value_shape = c->input(1);
     98   ShapeHandle unused;
     99   TF_RETURN_IF_ERROR(
    100       c->Merge(handle_shape_and_type.shape, value_shape, &unused));
    101   return Status::OK();
    102 }
    103 
    104 REGISTER_OP("AssignVariableOp")
    105     .Input("resource: resource")
    106     .Input("value: dtype")
    107     .Attr("dtype: type")
    108     .SetShapeFn(CreateAssignShapeFn);
    109 
    110 REGISTER_OP("AssignAddVariableOp")
    111     .Input("resource: resource")
    112     .Input("value: dtype")
    113     .Attr("dtype: type")
    114     .SetShapeFn(CreateAssignShapeFn);
    115 
    116 REGISTER_OP("AssignSubVariableOp")
    117     .Input("resource: resource")
    118     .Input("value: dtype")
    119     .Attr("dtype: type")
    120     .SetShapeFn(CreateAssignShapeFn);
    121 
    122 REGISTER_OP("VarIsInitializedOp")
    123     .Input("resource: resource")
    124     .Output("is_initialized: bool")
    125     .SetShapeFn(tensorflow::shape_inference::ScalarShape);
    126 
    127 Status VariableShapeShapeFn(InferenceContext* c) {
    128   auto* handle_data = c->input_handle_shapes_and_types(0);
    129   if (handle_data == nullptr || handle_data->empty()) {
    130     return errors::InvalidArgument("Handle doesn't have shape information.");
    131   }
    132   ShapeHandle var_shape = (*handle_data)[0].shape;
    133   int64 rank = c->RankKnown(var_shape) ? c->Rank(var_shape)
    134                                        : InferenceContext::kUnknownDim;
    135   c->set_output(0, c->Vector(rank));
    136   return Status::OK();
    137 }
    138 
    139 REGISTER_OP("VariableShape")
    140     .Input("input: resource")
    141     .Output("output: out_type")
    142     .Attr("out_type: {int32, int64} = DT_INT32")
    143     .SetShapeFn(VariableShapeShapeFn);
    144 
    145 REGISTER_OP("ResourceGather")
    146     .Input("resource: resource")
    147     .Input("indices: Tindices")
    148     .Attr("validate_indices: bool = true")
    149     .Output("output: dtype")
    150     .Attr("dtype: type")
    151     .Attr("Tindices: {int32,int64}")
    152     .SetShapeFn([](InferenceContext* c) {
    153       ShapeAndType handle_shape_and_type;
    154       TF_RETURN_IF_ERROR(
    155           ValidateVariableResourceHandle(c, &handle_shape_and_type));
    156 
    157       ShapeHandle unused;
    158       TF_RETURN_IF_ERROR(
    159           c->WithRankAtLeast(handle_shape_and_type.shape, 1, &unused));
    160       ShapeHandle params_subshape;
    161       TF_RETURN_IF_ERROR(
    162           c->Subshape(handle_shape_and_type.shape, 1, &params_subshape));
    163       ShapeHandle indices_shape = c->input(1);
    164       ShapeHandle out;
    165       TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, params_subshape, &out));
    166       c->set_output(0, out);
    167       return Status::OK();
    168     });
    169 
    170 REGISTER_OP("ResourceScatterAdd")
    171     .Input("resource: resource")
    172     .Input("indices: Tindices")
    173     .Input("updates: dtype")
    174     .Attr("dtype: numbertype")
    175     .Attr("Tindices: {int32, int64}")
    176     .SetShapeFn([](InferenceContext* c) {
    177       ShapeAndType handle_shape_and_type;
    178       TF_RETURN_IF_ERROR(
    179           ValidateVariableResourceHandle(c, &handle_shape_and_type));
    180       ShapeHandle var_shape = handle_shape_and_type.shape;
    181       ShapeHandle indices_shape = c->input(1);
    182 
    183       ShapeHandle unused_updates_shape;
    184       ShapeHandle concat;
    185       ShapeHandle var_subshape;
    186       TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape));
    187       TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat));
    188       TF_RETURN_IF_ERROR(c->Merge(c->input(2), concat, &unused_updates_shape));
    189       return Status::OK();
    190     });
    191 
    192 REGISTER_OP("ResourceScatterUpdate")
    193     .Input("resource: resource")
    194     .Input("indices: Tindices")
    195     .Input("updates: dtype")
    196     .Attr("dtype: type")
    197     .Attr("Tindices: {int32, int64}")
    198     .SetShapeFn([](InferenceContext* c) {
    199       ShapeAndType handle_shape_and_type;
    200       TF_RETURN_IF_ERROR(
    201           ValidateVariableResourceHandle(c, &handle_shape_and_type));
    202       ShapeHandle var_shape = handle_shape_and_type.shape;
    203       ShapeHandle indices_shape = c->input(1);
    204 
    205       ShapeHandle unused_updates_shape;
    206       ShapeHandle concat;
    207       ShapeHandle var_subshape;
    208       TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape));
    209       TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat));
    210       TF_RETURN_IF_ERROR(c->Merge(c->input(2), concat, &unused_updates_shape));
    211       return Status::OK();
    212     });
    213 
    214 REGISTER_OP("CriticalSectionOp")
    215     .Attr("container: string = ''")
    216     .Attr("shared_name: string = ''")
    217     .Output("resource: resource")
    218     .SetIsStateful()
    219     .SetShapeFn([](InferenceContext* c) {
    220       c->set_output(0, c->Scalar());
    221       return Status::OK();
    222     });
    223 
    224 REGISTER_OP("ExecuteInCriticalSection")
    225     .Input("critical_section: resource")
    226     .Input("arguments: Targuments")
    227     .Output("outputs: output_types")
    228     .Attr("f: func")
    229     .Attr("Targuments: list(type) >= 0")
    230     .Attr("output_types: list(type) >= 0")
    231     .Attr("output_shapes: list(shape) >= 0")
    232     .SetShapeFn([](InferenceContext* c) {
    233       std::vector<PartialTensorShape> output_shapes;
    234       TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
    235       for (int i = 0; i < output_shapes.size(); ++i) {
    236         ShapeHandle s;
    237         TF_RETURN_IF_ERROR(
    238             c->MakeShapeFromPartialTensorShape(output_shapes[i], &s));
    239         c->set_output(i, s);
    240       }
    241       return Status::OK();
    242     });
    243 
    244 }  // namespace tensorflow
    245