1 // Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // ============================================================================ 15 16 #include "tensorflow/core/framework/common_shape_fns.h" 17 #include "tensorflow/core/framework/node_def_util.h" 18 #include "tensorflow/core/framework/op.h" 19 #include "tensorflow/core/framework/resource_mgr.h" 20 #include "tensorflow/core/framework/shape_inference.h" 21 22 using ::tensorflow::shape_inference::InferenceContext; 23 using ::tensorflow::shape_inference::ShapeAndType; 24 using ::tensorflow::shape_inference::ShapeHandle; 25 26 namespace tensorflow { 27 28 namespace { 29 30 Status ValidateVariableResourceHandle(InferenceContext* c, 31 ShapeAndType* shape_and_type) { 32 auto* handle_data = c->input_handle_shapes_and_types(0); 33 if (handle_data == nullptr || handle_data->empty()) { 34 shape_and_type->shape = c->UnknownShape(); 35 shape_and_type->dtype = DT_INVALID; 36 } else { 37 *shape_and_type = (*handle_data)[0]; 38 DataType value_dtype; 39 TF_RETURN_IF_ERROR(c->GetAttr("dtype", &value_dtype)); 40 if (shape_and_type->dtype != value_dtype) { 41 return errors::InvalidArgument( 42 "Trying to read variable with wrong dtype. " 43 "Expected ", 44 DataTypeString(shape_and_type->dtype), " got ", 45 DataTypeString(value_dtype)); 46 } 47 } 48 return Status::OK(); 49 } 50 51 Status ReadVariableShapeFn(InferenceContext* c) { 52 ShapeAndType shape_and_type; 53 TF_RETURN_IF_ERROR(ValidateVariableResourceHandle(c, &shape_and_type)); 54 c->set_output(0, shape_and_type.shape); 55 return Status::OK(); 56 } 57 58 } // namespace 59 60 REGISTER_OP("VarHandleOp") 61 .Attr("container: string = ''") 62 .Attr("shared_name: string = ''") 63 .Attr("dtype: type") 64 .Attr("shape: shape") 65 .Output("resource: resource") 66 .SetIsStateful() 67 .SetShapeFn([](InferenceContext* c) { 68 c->set_output(0, c->Scalar()); 69 DataType t; 70 TF_RETURN_IF_ERROR(c->GetAttr("dtype", &t)); 71 PartialTensorShape p; 72 TF_RETURN_IF_ERROR(c->GetAttr("shape", &p)); 73 ShapeHandle s; 74 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s)); 75 c->set_output_handle_shapes_and_types(0, 76 std::vector<ShapeAndType>{{s, t}}); 77 78 return Status::OK(); 79 }); 80 81 REGISTER_OP("ReadVariableOp") 82 .Input("resource: resource") 83 .Output("value: dtype") 84 .Attr("dtype: type") 85 .SetShapeFn(ReadVariableShapeFn); 86 87 REGISTER_OP("DestroyResourceOp") 88 .Input("resource: resource") 89 .Attr("ignore_lookup_error: bool = true") 90 .SetIsStateful() 91 .SetShapeFn(shape_inference::NoOutputs); 92 93 Status CreateAssignShapeFn(InferenceContext* c) { 94 ShapeAndType handle_shape_and_type; 95 TF_RETURN_IF_ERROR(ValidateVariableResourceHandle(c, &handle_shape_and_type)); 96 97 ShapeHandle value_shape = c->input(1); 98 ShapeHandle unused; 99 TF_RETURN_IF_ERROR( 100 c->Merge(handle_shape_and_type.shape, value_shape, &unused)); 101 return Status::OK(); 102 } 103 104 REGISTER_OP("AssignVariableOp") 105 .Input("resource: resource") 106 .Input("value: dtype") 107 .Attr("dtype: type") 108 .SetShapeFn(CreateAssignShapeFn); 109 110 REGISTER_OP("AssignAddVariableOp") 111 .Input("resource: resource") 112 .Input("value: dtype") 113 .Attr("dtype: type") 114 .SetShapeFn(CreateAssignShapeFn); 115 116 REGISTER_OP("AssignSubVariableOp") 117 .Input("resource: resource") 118 .Input("value: dtype") 119 .Attr("dtype: type") 120 .SetShapeFn(CreateAssignShapeFn); 121 122 REGISTER_OP("VarIsInitializedOp") 123 .Input("resource: resource") 124 .Output("is_initialized: bool") 125 .SetShapeFn(tensorflow::shape_inference::ScalarShape); 126 127 Status VariableShapeShapeFn(InferenceContext* c) { 128 auto* handle_data = c->input_handle_shapes_and_types(0); 129 if (handle_data == nullptr || handle_data->empty()) { 130 return errors::InvalidArgument("Handle doesn't have shape information."); 131 } 132 ShapeHandle var_shape = (*handle_data)[0].shape; 133 int64 rank = c->RankKnown(var_shape) ? c->Rank(var_shape) 134 : InferenceContext::kUnknownDim; 135 c->set_output(0, c->Vector(rank)); 136 return Status::OK(); 137 } 138 139 REGISTER_OP("VariableShape") 140 .Input("input: resource") 141 .Output("output: out_type") 142 .Attr("out_type: {int32, int64} = DT_INT32") 143 .SetShapeFn(VariableShapeShapeFn); 144 145 REGISTER_OP("ResourceGather") 146 .Input("resource: resource") 147 .Input("indices: Tindices") 148 .Attr("validate_indices: bool = true") 149 .Output("output: dtype") 150 .Attr("dtype: type") 151 .Attr("Tindices: {int32,int64}") 152 .SetShapeFn([](InferenceContext* c) { 153 ShapeAndType handle_shape_and_type; 154 TF_RETURN_IF_ERROR( 155 ValidateVariableResourceHandle(c, &handle_shape_and_type)); 156 157 ShapeHandle unused; 158 TF_RETURN_IF_ERROR( 159 c->WithRankAtLeast(handle_shape_and_type.shape, 1, &unused)); 160 ShapeHandle params_subshape; 161 TF_RETURN_IF_ERROR( 162 c->Subshape(handle_shape_and_type.shape, 1, ¶ms_subshape)); 163 ShapeHandle indices_shape = c->input(1); 164 ShapeHandle out; 165 TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, params_subshape, &out)); 166 c->set_output(0, out); 167 return Status::OK(); 168 }); 169 170 REGISTER_OP("ResourceScatterAdd") 171 .Input("resource: resource") 172 .Input("indices: Tindices") 173 .Input("updates: dtype") 174 .Attr("dtype: numbertype") 175 .Attr("Tindices: {int32, int64}") 176 .SetShapeFn([](InferenceContext* c) { 177 ShapeAndType handle_shape_and_type; 178 TF_RETURN_IF_ERROR( 179 ValidateVariableResourceHandle(c, &handle_shape_and_type)); 180 ShapeHandle var_shape = handle_shape_and_type.shape; 181 ShapeHandle indices_shape = c->input(1); 182 183 ShapeHandle unused_updates_shape; 184 ShapeHandle concat; 185 ShapeHandle var_subshape; 186 TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape)); 187 TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat)); 188 TF_RETURN_IF_ERROR(c->Merge(c->input(2), concat, &unused_updates_shape)); 189 return Status::OK(); 190 }); 191 192 REGISTER_OP("ResourceScatterUpdate") 193 .Input("resource: resource") 194 .Input("indices: Tindices") 195 .Input("updates: dtype") 196 .Attr("dtype: type") 197 .Attr("Tindices: {int32, int64}") 198 .SetShapeFn([](InferenceContext* c) { 199 ShapeAndType handle_shape_and_type; 200 TF_RETURN_IF_ERROR( 201 ValidateVariableResourceHandle(c, &handle_shape_and_type)); 202 ShapeHandle var_shape = handle_shape_and_type.shape; 203 ShapeHandle indices_shape = c->input(1); 204 205 ShapeHandle unused_updates_shape; 206 ShapeHandle concat; 207 ShapeHandle var_subshape; 208 TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape)); 209 TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat)); 210 TF_RETURN_IF_ERROR(c->Merge(c->input(2), concat, &unused_updates_shape)); 211 return Status::OK(); 212 }); 213 214 REGISTER_OP("CriticalSectionOp") 215 .Attr("container: string = ''") 216 .Attr("shared_name: string = ''") 217 .Output("resource: resource") 218 .SetIsStateful() 219 .SetShapeFn([](InferenceContext* c) { 220 c->set_output(0, c->Scalar()); 221 return Status::OK(); 222 }); 223 224 REGISTER_OP("ExecuteInCriticalSection") 225 .Input("critical_section: resource") 226 .Input("arguments: Targuments") 227 .Output("outputs: output_types") 228 .Attr("f: func") 229 .Attr("Targuments: list(type) >= 0") 230 .Attr("output_types: list(type) >= 0") 231 .Attr("output_shapes: list(shape) >= 0") 232 .SetShapeFn([](InferenceContext* c) { 233 std::vector<PartialTensorShape> output_shapes; 234 TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); 235 for (int i = 0; i < output_shapes.size(); ++i) { 236 ShapeHandle s; 237 TF_RETURN_IF_ERROR( 238 c->MakeShapeFromPartialTensorShape(output_shapes[i], &s)); 239 c->set_output(i, s); 240 } 241 return Status::OK(); 242 }); 243 244 } // namespace tensorflow 245