1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/core/framework/common_shape_fns.h" 16 #include "tensorflow/core/framework/attr_value.pb.h" 17 #include "tensorflow/core/lib/core/errors.h" 18 19 namespace tensorflow { 20 21 Status GetWindowedOutputSizeVerboseV2(int64 input_size, int64 filter_size, 22 int64 dilation_rate, int64 stride, 23 Padding padding_type, int64* output_size, 24 int64* padding_before, 25 int64* padding_after) { 26 if (stride <= 0) { 27 return errors::InvalidArgument("Stride must be > 0, but got ", stride); 28 } 29 if (dilation_rate < 1) { 30 return errors::InvalidArgument("Dilation rate must be >= 1, but got ", 31 dilation_rate); 32 } 33 34 // See also the parallel implementation in GetWindowedOutputSizeFromDimsV2. 35 int64 effective_filter_size = (filter_size - 1) * dilation_rate + 1; 36 switch (padding_type) { 37 case Padding::VALID: 38 *output_size = (input_size - effective_filter_size + stride) / stride; 39 *padding_before = *padding_after = 0; 40 break; 41 case Padding::EXPLICIT: 42 *output_size = (input_size + *padding_before + *padding_after - 43 effective_filter_size + stride) / 44 stride; 45 break; 46 case Padding::SAME: 47 *output_size = (input_size + stride - 1) / stride; 48 const int64 padding_needed = 49 std::max(int64{0}, (*output_size - 1) * stride + 50 effective_filter_size - input_size); 51 // For odd values of total padding, add more padding at the 'right' 52 // side of the given dimension. 53 *padding_before = padding_needed / 2; 54 *padding_after = padding_needed - *padding_before; 55 break; 56 } 57 if (*output_size < 0) { 58 return errors::InvalidArgument( 59 "Computed output size would be negative: ", *output_size, 60 " [input_size: ", input_size, 61 ", effective_filter_size: ", effective_filter_size, 62 ", stride: ", stride, "]"); 63 } 64 return Status::OK(); 65 } 66 67 Status GetWindowedOutputSizeVerbose(int64 input_size, int64 filter_size, 68 int64 stride, Padding padding_type, 69 int64* output_size, int64* padding_before, 70 int64* padding_after) { 71 return GetWindowedOutputSizeVerboseV2(input_size, filter_size, 72 /*dilation_rate=*/1, stride, 73 padding_type, output_size, 74 padding_before, padding_after); 75 } 76 77 Status GetWindowedOutputSize(int64 input_size, int64 filter_size, int64 stride, 78 Padding padding_type, int64* output_size, 79 int64* padding_size) { 80 if (padding_type == Padding::EXPLICIT) { 81 return errors::Internal( 82 "GetWindowedOutputSize does not handle EXPLICIT padding; call " 83 "GetWindowedOutputSizeVerbose instead"); 84 } 85 int64 padding_after_unused; 86 return GetWindowedOutputSizeVerbose(input_size, filter_size, stride, 87 padding_type, output_size, padding_size, 88 &padding_after_unused); 89 } 90 91 Status GetWindowedOutputSizeV2(int64 input_size, int64 filter_size, 92 int64 dilation_rate, int64 stride, 93 Padding padding_type, int64* output_size, 94 int64* padding_size) { 95 if (padding_type == Padding::EXPLICIT) { 96 return errors::Internal( 97 "GetWindowedOutputSizeV2 does not handle EXPLICIT padding; call " 98 "GetWindowedOutputSizeVerboseV2 instead"); 99 } 100 int64 padding_after_unused; 101 return GetWindowedOutputSizeVerboseV2(input_size, filter_size, dilation_rate, 102 stride, padding_type, output_size, 103 padding_size, &padding_after_unused); 104 } 105 106 Status Get3dOutputSize(const std::array<int64, 3>& input, 107 const std::array<int64, 3>& window, 108 const std::array<int64, 3>& strides, 109 Padding padding_type, std::array<int64, 3>* output_ptr, 110 std::array<int64, 3>* padding_ptr) { 111 for (size_t i = 0; i < input.size(); ++i) { 112 TF_RETURN_IF_ERROR(GetWindowedOutputSize(input[i], window[i], strides[i], 113 padding_type, &(*output_ptr)[i], 114 &(*padding_ptr)[i])); 115 } 116 return Status::OK(); 117 } 118 119 Status Get3dOutputSizeV2(const std::array<int64, 3>& input, 120 const std::array<int64, 3>& window, 121 const std::array<int64, 3>& dilations, 122 const std::array<int64, 3>& strides, 123 Padding padding_type, std::array<int64, 3>* output_ptr, 124 std::array<int64, 3>* padding_ptr) { 125 for (size_t i = 0; i < input.size(); ++i) { 126 TF_RETURN_IF_ERROR(GetWindowedOutputSizeV2( 127 input[i], window[i], dilations[i], strides[i], padding_type, 128 &(*output_ptr)[i], &(*padding_ptr)[i])); 129 } 130 return Status::OK(); 131 } 132 133 namespace shape_inference { 134 135 // The V2 version computes windowed output size with arbitrary dilation_rate, 136 // while the original version only handles the cases where dilation_rates equal 137 // to 1. 138 Status GetWindowedOutputSizeFromDimsV2( 139 shape_inference::InferenceContext* c, 140 shape_inference::DimensionHandle input_size, 141 shape_inference::DimensionOrConstant filter_size, int64 dilation_rate, 142 int64 stride, Padding padding_type, int64 padding_before, 143 int64 padding_after, shape_inference::DimensionHandle* output_size) { 144 if (stride <= 0) { 145 return errors::InvalidArgument("Stride must be > 0, but got ", stride); 146 } 147 148 if (dilation_rate < 1) { 149 return errors::InvalidArgument("Dilation rate must be >= 1, but got ", 150 dilation_rate); 151 } 152 153 // See also the parallel implementation in GetWindowedOutputSizeVerbose. 154 switch (padding_type) { 155 case Padding::VALID: 156 padding_before = padding_after = 0; 157 TF_FALLTHROUGH_INTENDED; 158 case Padding::EXPLICIT: 159 TF_RETURN_IF_ERROR( 160 c->Add(input_size, padding_before + padding_after, &input_size)); 161 if (dilation_rate > 1) { 162 DimensionHandle window_size; 163 TF_RETURN_IF_ERROR( 164 c->Subtract(c->MakeDim(filter_size), 1, &window_size)); 165 TF_RETURN_IF_ERROR( 166 c->Multiply(window_size, dilation_rate, &window_size)); 167 TF_RETURN_IF_ERROR(c->Add(window_size, 1, &window_size)); 168 TF_RETURN_IF_ERROR(c->Subtract(input_size, window_size, output_size)); 169 } else { 170 TF_RETURN_IF_ERROR(c->Subtract(input_size, filter_size, output_size)); 171 } 172 TF_RETURN_IF_ERROR(c->Add(*output_size, stride, output_size)); 173 TF_RETURN_IF_ERROR(c->Divide(*output_size, stride, 174 /*evenly_divisible=*/false, output_size)); 175 break; 176 case Padding::SAME: 177 TF_RETURN_IF_ERROR(c->Add(input_size, stride - 1, output_size)); 178 TF_RETURN_IF_ERROR(c->Divide(*output_size, stride, 179 /*evenly_divisible=*/false, output_size)); 180 break; 181 } 182 return Status::OK(); 183 } 184 185 Status GetWindowedOutputSizeFromDims( 186 shape_inference::InferenceContext* c, 187 shape_inference::DimensionHandle input_size, 188 shape_inference::DimensionOrConstant filter_size, int64 stride, 189 Padding padding_type, shape_inference::DimensionHandle* output_size) { 190 if (padding_type == Padding::EXPLICIT) { 191 return errors::Internal( 192 "GetWindowedOutputSizeFromDims does not handle EXPLICIT padding; call " 193 "GetWindowedOutputSizeFromDimsV2 instead"); 194 } 195 return GetWindowedOutputSizeFromDimsV2(c, input_size, filter_size, 196 /*dilation_rate=*/1, stride, 197 padding_type, 198 // Give dummy values of -1 to 199 // padding_before and padding_after, 200 // since explicit padding is not used. 201 -1, -1, output_size); 202 } 203 204 Status UnchangedShape(shape_inference::InferenceContext* c) { 205 c->set_output(0, c->input(0)); 206 auto* handle_data = c->input_handle_shapes_and_types(0); 207 if (handle_data != nullptr) { 208 c->set_output_handle_shapes_and_types(0, *handle_data); 209 } 210 return Status::OK(); 211 } 212 213 Status MatMulShape(shape_inference::InferenceContext* c) { 214 ShapeHandle a; 215 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &a)); 216 217 ShapeHandle b; 218 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &b)); 219 220 bool transpose_a, transpose_b; 221 TF_RETURN_IF_ERROR(c->GetAttr("transpose_a", &transpose_a)); 222 TF_RETURN_IF_ERROR(c->GetAttr("transpose_b", &transpose_b)); 223 DimensionHandle output_rows = transpose_a ? c->Dim(a, 1) : c->Dim(a, 0); 224 DimensionHandle output_cols = transpose_b ? c->Dim(b, 0) : c->Dim(b, 1); 225 226 // Validate that the inner shapes are compatible. 227 DimensionHandle inner_a = transpose_a ? c->Dim(a, 0) : c->Dim(a, 1); 228 DimensionHandle inner_b = transpose_b ? c->Dim(b, 1) : c->Dim(b, 0); 229 DimensionHandle merged; 230 TF_RETURN_IF_ERROR(c->Merge(inner_a, inner_b, &merged)); 231 232 c->set_output(0, c->Matrix(output_rows, output_cols)); 233 return Status::OK(); 234 } 235 236 Status BiasAddShape(shape_inference::InferenceContext* c) { 237 ShapeHandle input_shape; 238 239 // Fetch the data_format attribute, which may not exist. 240 string data_format; 241 Status s = c->GetAttr("data_format", &data_format); 242 243 if (s.ok() && data_format == "NCHW") { 244 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape)); 245 } else { 246 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape)); 247 } 248 249 ShapeHandle bias_shape; 250 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &bias_shape)); 251 DimensionHandle bias_dim = c->Dim(bias_shape, 0); 252 253 // If rank unknown, return unknown shape. 254 if (!c->RankKnown(input_shape)) { 255 c->set_output(0, c->UnknownShape()); 256 return Status::OK(); 257 } 258 259 // Output has the same shape as the input, and matches the length of 260 // the bias in its bias dimension. 261 ShapeHandle output_shape; 262 if (s.ok() && data_format == "NCHW") { 263 // Merge the length of bias_shape into the third to last dimension 264 ShapeHandle first; 265 TF_RETURN_IF_ERROR(c->Subshape(input_shape, 0, 1, &first)); 266 267 ShapeHandle last; 268 TF_RETURN_IF_ERROR(c->Subshape(input_shape, 2, &last)); 269 270 DimensionHandle input_bias_dim = c->Dim(input_shape, 1); 271 DimensionHandle merged_bias_dim; 272 TF_RETURN_IF_ERROR(c->Merge(input_bias_dim, bias_dim, &merged_bias_dim)); 273 ShapeHandle merged_bias = c->Vector(merged_bias_dim); 274 275 ShapeHandle temp; 276 TF_RETURN_IF_ERROR(c->Concatenate(first, merged_bias, &temp)); 277 TF_RETURN_IF_ERROR(c->Concatenate(temp, last, &output_shape)); 278 } else { 279 ShapeHandle all_but_bias; 280 TF_RETURN_IF_ERROR(c->Subshape(input_shape, 0, -1, &all_but_bias)); 281 282 DimensionHandle input_bias_dim = c->Dim(input_shape, -1); 283 DimensionHandle merged_bias_dim; 284 TF_RETURN_IF_ERROR(c->Merge(input_bias_dim, bias_dim, &merged_bias_dim)); 285 286 ShapeHandle merged_bias = c->Vector(merged_bias_dim); 287 TF_RETURN_IF_ERROR( 288 c->Concatenate(all_but_bias, merged_bias, &output_shape)); 289 } 290 291 c->set_output(0, output_shape); 292 return Status::OK(); 293 } 294 295 Status BiasAddGradShape(shape_inference::InferenceContext* c) { 296 ShapeHandle input_shape; 297 // Fetch the data_format attribute, which may not exist. 298 string data_format; 299 Status s = c->GetAttr("data_format", &data_format); 300 301 if (s.ok() && data_format == "NCHW") { 302 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape)); 303 c->set_output(0, c->Vector(c->Dim(input_shape, 1))); 304 } else { 305 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape)); 306 c->set_output(0, c->Vector(c->Dim(input_shape, -1))); 307 } 308 309 return Status::OK(); 310 } 311 312 Status CheckFormatConstraintsOnShape(const TensorFormat tensor_format, 313 const ShapeHandle shape_handle, 314 const string& tensor_name, 315 shape_inference::InferenceContext* c) { 316 if (tensor_format == FORMAT_NCHW_VECT_C) { 317 // Check that the vect dim has size 4. 318 const int num_dims = c->Rank(shape_handle); 319 DimensionHandle vect_dim = c->Dim( 320 shape_handle, GetTensorInnerFeatureDimIndex(num_dims, tensor_format)); 321 DimensionHandle unused_vect_dim; 322 TF_RETURN_IF_ERROR(c->WithValue(vect_dim, 4, &unused_vect_dim)); 323 } 324 325 return Status::OK(); 326 } 327 328 Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N, 329 const std::vector<DimensionOrConstant>& spatial, 330 DimensionOrConstant C, ShapeHandle* out, 331 shape_inference::InferenceContext* context) { 332 const int num_dims = GetTensorDimsFromSpatialDims(spatial.size(), format); 333 std::vector<DimensionHandle> dims_actual(num_dims); 334 dims_actual[GetTensorBatchDimIndex(num_dims, format)] = context->MakeDim(N); 335 int outer_c_index = GetTensorFeatureDimIndex(num_dims, format); 336 dims_actual[outer_c_index] = context->MakeDim(C); 337 if (format == FORMAT_NCHW_VECT_C) { 338 dims_actual[GetTensorInnerFeatureDimIndex(num_dims, format)] = 339 context->MakeDim(4); 340 } else if (format == FORMAT_NHWC_VECT_W) { 341 dims_actual[GetTensorInnerWidthDimIndex(num_dims, format)] = 342 context->MakeDim(4); 343 } 344 for (int spatial_dim = 0; spatial_dim < spatial.size(); spatial_dim++) { 345 dims_actual[GetTensorSpatialDimIndex(num_dims, format, spatial_dim)] = 346 context->MakeDim(spatial[spatial_dim]); 347 } 348 *out = context->MakeShape(dims_actual); 349 return Status::OK(); 350 } 351 352 Status DimensionsFromShape(ShapeHandle shape, TensorFormat format, 353 DimensionHandle* batch_dim, 354 gtl::MutableArraySlice<DimensionHandle> spatial_dims, 355 DimensionHandle* filter_dim, 356 InferenceContext* context) { 357 const int32 rank = GetTensorDimsFromSpatialDims(spatial_dims.size(), format); 358 // Batch. 359 *batch_dim = context->Dim(shape, GetTensorBatchDimIndex(rank, format)); 360 // Spatial. 361 for (int spatial_dim_index = 0; spatial_dim_index < spatial_dims.size(); 362 ++spatial_dim_index) { 363 spatial_dims[spatial_dim_index] = context->Dim( 364 shape, GetTensorSpatialDimIndex(rank, format, spatial_dim_index)); 365 } 366 // Channel. 367 *filter_dim = context->Dim(shape, GetTensorFeatureDimIndex(rank, format)); 368 if (format == FORMAT_NCHW_VECT_C) { 369 TF_RETURN_IF_ERROR(context->Multiply( 370 *filter_dim, 371 context->Dim(shape, GetTensorInnerFeatureDimIndex(rank, format)), 372 filter_dim)); 373 } 374 return Status::OK(); 375 } 376 377 Status ShapeFromDimensions(DimensionHandle batch_dim, 378 gtl::ArraySlice<DimensionHandle> spatial_dims, 379 DimensionHandle filter_dim, TensorFormat format, 380 InferenceContext* context, ShapeHandle* shape) { 381 const int32 rank = GetTensorDimsFromSpatialDims(spatial_dims.size(), format); 382 std::vector<DimensionHandle> out_dims(rank); 383 384 // Batch. 385 out_dims[tensorflow::GetTensorBatchDimIndex(rank, format)] = batch_dim; 386 // Spatial. 387 for (int spatial_dim_index = 0; spatial_dim_index < spatial_dims.size(); 388 ++spatial_dim_index) { 389 out_dims[tensorflow::GetTensorSpatialDimIndex( 390 rank, format, spatial_dim_index)] = spatial_dims[spatial_dim_index]; 391 } 392 // Channel. 393 if (format == tensorflow::FORMAT_NCHW_VECT_C) { 394 // When format is NCHW_VECT_C, factor the feature map count 395 // into the outer feature count and the inner feature count (=4). 396 TF_RETURN_IF_ERROR(context->Divide( 397 filter_dim, 4, /*evenly_divisible=*/true, 398 &out_dims[tensorflow::GetTensorFeatureDimIndex(rank, format)])); 399 out_dims[GetTensorInnerFeatureDimIndex(rank, format)] = context->MakeDim(4); 400 } else { 401 out_dims[tensorflow::GetTensorFeatureDimIndex(rank, format)] = filter_dim; 402 } 403 404 *shape = context->MakeShape(out_dims); 405 return tensorflow::Status::OK(); 406 } 407 408 namespace { 409 410 Status Conv2DShapeImpl(shape_inference::InferenceContext* c, 411 bool supports_explicit_padding) { 412 string data_format_str, filter_format_str; 413 if (!c->GetAttr("data_format", &data_format_str).ok()) { 414 data_format_str = "NHWC"; 415 } 416 if (!c->GetAttr("filter_format", &filter_format_str).ok()) { 417 filter_format_str = "HWIO"; 418 } 419 420 TensorFormat data_format; 421 if (!FormatFromString(data_format_str, &data_format)) { 422 return errors::InvalidArgument("Invalid data format string: ", 423 data_format_str); 424 } 425 FilterTensorFormat filter_format; 426 if (!FilterFormatFromString(filter_format_str, &filter_format)) { 427 return errors::InvalidArgument("Invalid filter format string: ", 428 filter_format_str); 429 } 430 431 constexpr int num_spatial_dims = 2; 432 const int rank = GetTensorDimsFromSpatialDims(num_spatial_dims, data_format); 433 ShapeHandle conv_input_shape; 434 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &conv_input_shape)); 435 TF_RETURN_IF_ERROR(CheckFormatConstraintsOnShape( 436 data_format, conv_input_shape, "conv_input", c)); 437 438 // The filter rank should match the input (4 for NCHW, 5 for NCHW_VECT_C). 439 ShapeHandle filter_shape; 440 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape)); 441 TF_RETURN_IF_ERROR( 442 CheckFormatConstraintsOnShape(data_format, filter_shape, "filter", c)); 443 444 std::vector<int32> dilations; 445 TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations)); 446 447 if (dilations.size() != 4) { 448 return errors::InvalidArgument( 449 "Conv2D requires the dilation attribute to contain 4 values, but got: ", 450 dilations.size()); 451 } 452 453 std::vector<int32> strides; 454 TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); 455 456 // strides.size() should be 4 (NCHW) even if the input is 5 (NCHW_VECT_C). 457 if (strides.size() != 4) { 458 return errors::InvalidArgument("Conv2D on data format ", data_format_str, 459 " requires the stride attribute to contain" 460 " 4 values, but got: ", 461 strides.size()); 462 } 463 464 const int32 stride_rows = GetTensorDim(strides, data_format, 'H'); 465 const int32 stride_cols = GetTensorDim(strides, data_format, 'W'); 466 const int32 dilation_rows = GetTensorDim(dilations, data_format, 'H'); 467 const int32 dilation_cols = GetTensorDim(dilations, data_format, 'W'); 468 469 DimensionHandle batch_size_dim; 470 DimensionHandle input_depth_dim; 471 gtl::InlinedVector<DimensionHandle, 2> input_spatial_dims(2); 472 TF_RETURN_IF_ERROR(DimensionsFromShape( 473 conv_input_shape, data_format, &batch_size_dim, 474 absl::MakeSpan(input_spatial_dims), &input_depth_dim, c)); 475 476 DimensionHandle output_depth_dim = c->Dim( 477 filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'O')); 478 DimensionHandle filter_rows_dim = c->Dim( 479 filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'H')); 480 DimensionHandle filter_cols_dim = c->Dim( 481 filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'W')); 482 DimensionHandle filter_input_depth_dim; 483 if (filter_format == FORMAT_OIHW_VECT_I) { 484 TF_RETURN_IF_ERROR(c->Multiply( 485 c->Dim(filter_shape, 486 GetFilterDimIndex<num_spatial_dims>(filter_format, 'I')), 487 c->Dim(filter_shape, 488 GetFilterTensorInnerInputChannelsDimIndex(rank, filter_format)), 489 &filter_input_depth_dim)); 490 } else { 491 filter_input_depth_dim = c->Dim( 492 filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'I')); 493 } 494 495 // Check that the input tensor and the filter tensor agree on the input 496 // channel count. 497 DimensionHandle unused; 498 TF_RETURN_IF_ERROR( 499 c->Merge(input_depth_dim, filter_input_depth_dim, &unused)); 500 501 Padding padding; 502 TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); 503 504 std::vector<int64> explicit_paddings; 505 if (supports_explicit_padding) { 506 Status s = c->GetAttr("explicit_paddings", &explicit_paddings); 507 // Use the default value, which is an empty list, if the attribute is not 508 // found. Otherwise return the error to the caller. 509 if (!s.ok() && !errors::IsNotFound(s)) { 510 return s; 511 } 512 TF_RETURN_IF_ERROR(CheckValidPadding(padding, explicit_paddings, 513 /*num_dims=*/4, data_format)); 514 } else { 515 DCHECK(padding != Padding::EXPLICIT); 516 } 517 518 DimensionHandle output_rows, output_cols; 519 int64 pad_rows_before = -1, pad_rows_after = -1; 520 int64 pad_cols_before = -1, pad_cols_after = -1; 521 if (padding == Padding::EXPLICIT) { 522 GetExplicitPaddingForDim(explicit_paddings, data_format, 'H', 523 &pad_rows_before, &pad_rows_after); 524 GetExplicitPaddingForDim(explicit_paddings, data_format, 'W', 525 &pad_cols_before, &pad_cols_after); 526 } 527 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2( 528 c, input_spatial_dims[0], filter_rows_dim, dilation_rows, stride_rows, 529 padding, pad_rows_before, pad_rows_after, &output_rows)); 530 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2( 531 c, input_spatial_dims[1], filter_cols_dim, dilation_cols, stride_cols, 532 padding, pad_cols_before, pad_cols_after, &output_cols)); 533 534 ShapeHandle output_shape; 535 TF_RETURN_IF_ERROR( 536 ShapeFromDimensions(batch_size_dim, {output_rows, output_cols}, 537 output_depth_dim, data_format, c, &output_shape)); 538 c->set_output(0, output_shape); 539 return Status::OK(); 540 } 541 542 } // namespace 543 544 // Shape function for Conv2D-like operations that support explicit padding. 545 Status Conv2DShapeWithExplicitPadding(shape_inference::InferenceContext* c) { 546 return Conv2DShapeImpl(c, true); 547 } 548 549 // Shape function for Conv2D-like operations that do not support explicit 550 // padding. 551 Status Conv2DShape(shape_inference::InferenceContext* c) { 552 return Conv2DShapeImpl(c, false); 553 } 554 555 // TODO(mjanusz): Unify all conv/pooling shape functions. 556 Status Conv3DShape(shape_inference::InferenceContext* c) { 557 ShapeHandle input_shape; 558 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape)); 559 ShapeHandle filter_shape; 560 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 5, &filter_shape)); 561 562 string data_format; 563 Status s = c->GetAttr("data_format", &data_format); 564 565 std::vector<int32> dilations; 566 TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations)); 567 568 if (dilations.size() != 5) { 569 return errors::InvalidArgument( 570 "Conv3D requires the dilation attribute to contain 5 values, but got: ", 571 dilations.size()); 572 } 573 574 std::vector<int32> strides; 575 TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); 576 if (strides.size() != 5) { 577 return errors::InvalidArgument( 578 "Conv3D requires the stride attribute to contain 5 values, but got: ", 579 strides.size()); 580 } 581 582 int32 stride_planes, stride_rows, stride_cols; 583 int32 dilation_planes, dilation_rows, dilation_cols; 584 if (s.ok() && data_format == "NCDHW") { 585 // Convert input_shape to NDHWC. 586 auto dim = [&](char dimension) { 587 return c->Dim(input_shape, GetTensorDimIndex<3>(FORMAT_NCHW, dimension)); 588 }; 589 input_shape = 590 c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}}); 591 stride_planes = strides[2]; 592 stride_rows = strides[3]; 593 stride_cols = strides[4]; 594 dilation_planes = dilations[2]; 595 dilation_cols = dilations[3]; 596 dilation_rows = dilations[4]; 597 } else { 598 stride_planes = strides[1]; 599 stride_rows = strides[2]; 600 stride_cols = strides[3]; 601 dilation_planes = dilations[1]; 602 dilation_cols = dilations[2]; 603 dilation_rows = dilations[3]; 604 } 605 606 DimensionHandle batch_size_dim = c->Dim(input_shape, 0); 607 DimensionHandle in_planes_dim = c->Dim(input_shape, 1); 608 DimensionHandle in_rows_dim = c->Dim(input_shape, 2); 609 DimensionHandle in_cols_dim = c->Dim(input_shape, 3); 610 611 DimensionHandle filter_planes_dim = c->Dim(filter_shape, 0); 612 DimensionHandle filter_rows_dim = c->Dim(filter_shape, 1); 613 DimensionHandle filter_cols_dim = c->Dim(filter_shape, 2); 614 DimensionHandle output_depth_dim = c->Dim(filter_shape, 4); 615 616 DimensionHandle unused; 617 TF_RETURN_IF_ERROR( 618 c->Merge(c->Dim(input_shape, 4), c->Dim(filter_shape, 3), &unused)); 619 620 Padding padding; 621 TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); 622 DimensionHandle output_planes, output_rows, output_cols; 623 624 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2( 625 c, in_planes_dim, filter_planes_dim, dilation_planes, stride_planes, 626 padding, -1, -1, &output_planes)); 627 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2( 628 c, in_rows_dim, filter_rows_dim, dilation_rows, stride_rows, padding, -1, 629 -1, &output_rows)); 630 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2( 631 c, in_cols_dim, filter_cols_dim, dilation_cols, stride_cols, padding, -1, 632 -1, &output_cols)); 633 634 ShapeHandle output_shape; 635 if (data_format == "NCDHW") { 636 output_shape = c->MakeShape({batch_size_dim, output_depth_dim, 637 output_planes, output_rows, output_cols}); 638 } else { 639 output_shape = c->MakeShape({batch_size_dim, output_planes, output_rows, 640 output_cols, output_depth_dim}); 641 } 642 c->set_output(0, output_shape); 643 return Status::OK(); 644 } 645 646 Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) { 647 ShapeHandle input_shape; 648 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); 649 ShapeHandle filter_shape; 650 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &filter_shape)); 651 652 std::vector<int32> strides; 653 TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); 654 655 if (strides.size() != 4) { 656 return errors::InvalidArgument( 657 "DepthwiseConv2D requires the stride attribute to contain 4 values, " 658 "but got: ", 659 strides.size()); 660 } 661 662 string data_format; 663 Status s = c->GetAttr("data_format", &data_format); 664 int32 stride_rows; 665 int32 stride_cols; 666 if (s.ok() && data_format == "NCHW") { 667 // Canonicalize input shape to NHWC so the shape inference code below can 668 // process it. 669 input_shape = 670 c->MakeShape({{c->Dim(input_shape, 0), c->Dim(input_shape, 2), 671 c->Dim(input_shape, 3), c->Dim(input_shape, 1)}}); 672 stride_rows = strides[2]; 673 stride_cols = strides[3]; 674 } else { 675 stride_rows = strides[1]; 676 stride_cols = strides[2]; 677 } 678 679 DimensionHandle batch_size_dim = c->Dim(input_shape, 0); 680 DimensionHandle in_rows_dim = c->Dim(input_shape, 1); 681 DimensionHandle in_cols_dim = c->Dim(input_shape, 2); 682 683 DimensionHandle filter_rows_dim = c->Dim(filter_shape, 0); 684 DimensionHandle filter_cols_dim = c->Dim(filter_shape, 1); 685 DimensionHandle input_depth = c->Dim(filter_shape, 2); 686 DimensionHandle depth_multiplier = c->Dim(filter_shape, 3); 687 688 // Check that the input depths are compatible. 689 TF_RETURN_IF_ERROR( 690 c->Merge(c->Dim(input_shape, 3), input_depth, &input_depth)); 691 692 DimensionHandle output_depth; 693 TF_RETURN_IF_ERROR(c->Multiply(input_depth, depth_multiplier, &output_depth)); 694 695 Padding padding; 696 TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); 697 698 // TODO(mrry,shlens): Raise an error if the stride would cause 699 // information in the input to be ignored. This will require a change 700 // in the kernel implementation. 701 DimensionHandle output_rows, output_cols; 702 703 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( 704 c, in_rows_dim, filter_rows_dim, stride_rows, padding, &output_rows)); 705 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( 706 c, in_cols_dim, filter_cols_dim, stride_cols, padding, &output_cols)); 707 708 ShapeHandle output_shape; 709 if (data_format == "NCHW") { 710 output_shape = 711 c->MakeShape({batch_size_dim, output_depth, output_rows, output_cols}); 712 } else { 713 output_shape = 714 c->MakeShape({batch_size_dim, output_rows, output_cols, output_depth}); 715 } 716 c->set_output(0, output_shape); 717 return Status::OK(); 718 } 719 720 Status AvgPoolShape(shape_inference::InferenceContext* c) { 721 string data_format_str; 722 TensorFormat data_format; 723 Status s = c->GetAttr("data_format", &data_format_str); 724 if (s.ok()) { 725 FormatFromString(data_format_str, &data_format); 726 } else { 727 data_format = FORMAT_NHWC; 728 } 729 730 const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4; 731 ShapeHandle input_shape; 732 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape)); 733 734 TF_RETURN_IF_ERROR( 735 CheckFormatConstraintsOnShape(data_format, input_shape, "input", c)); 736 737 std::vector<int32> strides; 738 TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); 739 if (strides.size() != 4) { 740 return errors::InvalidArgument( 741 "AvgPool requires the stride attribute to contain 4 values, but got: ", 742 strides.size()); 743 } 744 745 std::vector<int32> kernel_sizes; 746 TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes)); 747 if (kernel_sizes.size() != 4) { 748 return errors::InvalidArgument( 749 "AvgPool requires the ksize attribute to contain 4 values, but got: ", 750 kernel_sizes.size()); 751 } 752 753 int32 stride_rows = GetTensorDim(strides, data_format, 'H'); 754 int32 stride_cols = GetTensorDim(strides, data_format, 'W'); 755 int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H'); 756 int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W'); 757 758 constexpr int num_spatial_dims = 2; 759 DimensionHandle batch_size_dim = c->Dim( 760 input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N')); 761 DimensionHandle in_rows_dim = c->Dim( 762 input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H')); 763 DimensionHandle in_cols_dim = c->Dim( 764 input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W')); 765 DimensionHandle depth_dim = c->Dim( 766 input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C')); 767 768 Padding padding; 769 TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); 770 771 // TODO(mrry,shlens): Raise an error if the stride would cause 772 // information in the input to be ignored. This will require a change 773 // in the kernel implementation. 774 775 DimensionHandle output_rows, output_cols; 776 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( 777 c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows)); 778 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( 779 c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols)); 780 781 ShapeHandle output_shape; 782 TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim, 783 {output_rows, output_cols}, depth_dim, 784 &output_shape, c)); 785 c->set_output(0, output_shape); 786 return Status::OK(); 787 } 788 789 Status FusedBatchNormShape(shape_inference::InferenceContext* c) { 790 ShapeHandle x; 791 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &x)); 792 793 bool is_training; 794 TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training)); 795 int number_inputs = (is_training) ? 3 : 5; 796 string data_format_str; 797 TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str)); 798 TensorFormat data_format; 799 if (!FormatFromString(data_format_str, &data_format)) { 800 return errors::InvalidArgument("Invalid data format string: ", 801 data_format_str); 802 } 803 int channel_dim_index = GetTensorFeatureDimIndex(4, data_format); 804 DimensionHandle channel_dim = c->Dim(x, channel_dim_index); 805 806 // covers scale, offset, and if is_training is false, mean, variance 807 for (int i = 1; i < number_inputs; ++i) { 808 ShapeHandle vec; 809 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec)); 810 TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim)); 811 } 812 813 ShapeHandle y; 814 TF_RETURN_IF_ERROR(c->ReplaceDim(x, channel_dim_index, channel_dim, &y)); 815 c->set_output(0, y); 816 ShapeHandle vector_shape = c->Vector(channel_dim); 817 c->set_output(1, vector_shape); 818 c->set_output(2, vector_shape); 819 c->set_output(3, vector_shape); 820 c->set_output(4, vector_shape); 821 return Status::OK(); 822 } 823 824 Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) { 825 ShapeHandle y_backprop; 826 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &y_backprop)); 827 ShapeHandle x; 828 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &x)); 829 830 bool is_training; 831 TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training)); 832 string data_format_str; 833 TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format_str)); 834 TensorFormat data_format; 835 if (!FormatFromString(data_format_str, &data_format)) { 836 return errors::InvalidArgument("Invalid data format string: ", 837 data_format_str); 838 } 839 int channel_dim_index = GetTensorFeatureDimIndex(4, data_format); 840 DimensionHandle channel_dim = c->Dim(y_backprop, channel_dim_index); 841 TF_RETURN_IF_ERROR( 842 c->Merge(channel_dim, c->Dim(x, channel_dim_index), &channel_dim)); 843 844 // covers scale, mean (reserve_space_1), variance (reserve_space_2) 845 for (int i = 2; i < 5; ++i) { 846 ShapeHandle vec; 847 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec)); 848 TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim)); 849 } 850 851 ShapeHandle x_backprop; 852 TF_RETURN_IF_ERROR( 853 c->ReplaceDim(y_backprop, channel_dim_index, channel_dim, &x_backprop)); 854 c->set_output(0, x_backprop); 855 c->set_output(1, c->Vector(channel_dim)); 856 c->set_output(2, c->Vector(channel_dim)); 857 // Set the correct shapes for reserve_spaces 858 // so that gradients can be performed when 859 // the op is in a symbolic condition. 860 if (is_training) { 861 c->set_output(3, c->Vector(0)); 862 c->set_output(4, c->Vector(0)); 863 } else { 864 c->set_output(3, c->Vector(channel_dim)); 865 c->set_output(4, c->Vector(channel_dim)); 866 } 867 return Status::OK(); 868 } 869 870 Status MaxPoolShape(shape_inference::InferenceContext* c) { 871 string data_format_str; 872 TensorFormat data_format; 873 Status s = c->GetAttr("data_format", &data_format_str); 874 if (s.ok()) { 875 FormatFromString(data_format_str, &data_format); 876 } else { 877 data_format = FORMAT_NHWC; 878 } 879 880 const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4; 881 ShapeHandle input_shape; 882 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape)); 883 884 TF_RETURN_IF_ERROR( 885 CheckFormatConstraintsOnShape(data_format, input_shape, "input", c)); 886 887 std::vector<int32> strides; 888 TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); 889 if (strides.size() != 4) { 890 return errors::InvalidArgument( 891 "MaxPool requires the stride attribute to contain 4 values, but got: ", 892 strides.size()); 893 } 894 895 std::vector<int32> kernel_sizes; 896 TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes)); 897 if (kernel_sizes.size() != 4) { 898 return errors::InvalidArgument( 899 "MaxPool requires the ksize attribute to contain 4 values, but got: ", 900 kernel_sizes.size()); 901 } 902 903 int32 stride_depth = GetTensorDim(strides, data_format, 'C'); 904 int32 stride_rows = GetTensorDim(strides, data_format, 'H'); 905 int32 stride_cols = GetTensorDim(strides, data_format, 'W'); 906 int32 kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C'); 907 int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H'); 908 int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W'); 909 910 constexpr int num_spatial_dims = 2; 911 DimensionHandle batch_size_dim = c->Dim( 912 input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N')); 913 DimensionHandle in_rows_dim = c->Dim( 914 input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H')); 915 DimensionHandle in_cols_dim = c->Dim( 916 input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W')); 917 DimensionHandle in_depth_dim = c->Dim( 918 input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C')); 919 920 Padding padding; 921 TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); 922 923 ShapeHandle output_shape; 924 DimensionHandle output_rows, output_cols, output_depth; 925 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( 926 c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows)); 927 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( 928 c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols)); 929 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( 930 c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth)); 931 932 TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim, 933 {output_rows, output_cols}, 934 output_depth, &output_shape, c)); 935 936 c->set_output(0, output_shape); 937 return Status::OK(); 938 } 939 940 Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) { 941 string data_format_str; 942 TensorFormat data_format; 943 Status s = c->GetAttr("data_format", &data_format_str); 944 if (s.ok()) { 945 FormatFromString(data_format_str, &data_format); 946 } else { 947 data_format = FORMAT_NHWC; 948 } 949 950 const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4; 951 ShapeHandle input_shape; 952 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape)); 953 954 TF_RETURN_IF_ERROR( 955 CheckFormatConstraintsOnShape(data_format, input_shape, "input", c)); 956 957 std::vector<int32> kernel_sizes; 958 std::vector<int32> strides; 959 960 if (c->num_inputs() + 2 == num_inputs) { 961 TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes)); 962 963 TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); 964 } else { 965 // Verify shape of ksize and strides input. 966 ShapeHandle size; 967 DimensionHandle unused; 968 TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 2), 1, &size)); 969 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused)); 970 TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &size)); 971 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused)); 972 973 const Tensor* kernel_sizes_tensor = c->input_tensor(c->num_inputs() - 2); 974 if (kernel_sizes_tensor == nullptr) { 975 c->set_output(0, c->UnknownShape()); 976 return Status::OK(); 977 } 978 kernel_sizes.resize(kernel_sizes_tensor->shape().num_elements()); 979 auto kernel_sizes_vec = kernel_sizes_tensor->flat<int32>(); 980 std::copy_n(&kernel_sizes_vec(0), kernel_sizes.size(), 981 kernel_sizes.begin()); 982 983 const Tensor* strides_tensor = c->input_tensor(c->num_inputs() - 1); 984 if (strides_tensor == nullptr) { 985 c->set_output(0, c->UnknownShape()); 986 return Status::OK(); 987 } 988 strides.resize(strides_tensor->shape().num_elements()); 989 auto strides_vec = strides_tensor->flat<int32>(); 990 std::copy_n(&strides_vec(0), strides.size(), strides.begin()); 991 } 992 993 if (strides.size() != 4) { 994 return errors::InvalidArgument( 995 "MaxPool requires the stride attribute to contain 4 values, but " 996 "got: ", 997 strides.size()); 998 } 999 if (kernel_sizes.size() != 4) { 1000 return errors::InvalidArgument( 1001 "MaxPool requires the ksize attribute to contain 4 values, but got: ", 1002 kernel_sizes.size()); 1003 } 1004 1005 int32 stride_depth = GetTensorDim(strides, data_format, 'C'); 1006 int32 stride_rows = GetTensorDim(strides, data_format, 'H'); 1007 int32 stride_cols = GetTensorDim(strides, data_format, 'W'); 1008 int32 kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C'); 1009 int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H'); 1010 int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W'); 1011 1012 constexpr int num_spatial_dims = 2; 1013 DimensionHandle batch_size_dim = c->Dim( 1014 input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N')); 1015 DimensionHandle in_rows_dim = c->Dim( 1016 input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H')); 1017 DimensionHandle in_cols_dim = c->Dim( 1018 input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W')); 1019 DimensionHandle in_depth_dim = c->Dim( 1020 input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C')); 1021 1022 Padding padding; 1023 TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); 1024 1025 ShapeHandle output_shape; 1026 DimensionHandle output_rows, output_cols, output_depth; 1027 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( 1028 c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows)); 1029 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( 1030 c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols)); 1031 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( 1032 c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth)); 1033 1034 TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim, 1035 {output_rows, output_cols}, 1036 output_depth, &output_shape, c)); 1037 1038 c->set_output(0, output_shape); 1039 return Status::OK(); 1040 } 1041 1042 Status Pool3DShape(shape_inference::InferenceContext* c) { 1043 ShapeHandle input_shape; 1044 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape)); 1045 1046 string data_format; 1047 Status s = c->GetAttr("data_format", &data_format); 1048 1049 std::vector<int32> strides; 1050 TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); 1051 if (strides.size() != 5) { 1052 return errors::InvalidArgument( 1053 "Pool3D ops require the stride attribute to contain 5 values, but " 1054 "got: ", 1055 strides.size()); 1056 } 1057 1058 std::vector<int32> kernel_sizes; 1059 TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes)); 1060 if (kernel_sizes.size() != 5) { 1061 return errors::InvalidArgument( 1062 "Pool3D requires the ksize attribute to contain 5 values, but got: ", 1063 kernel_sizes.size()); 1064 } 1065 1066 int32 stride_planes, stride_rows, stride_cols; 1067 int32 kernel_planes, kernel_rows, kernel_cols; 1068 1069 if (s.ok() && data_format == "NCDHW") { 1070 // Convert input_shape to NDHWC. 1071 auto dim = [&](char dimension) { 1072 return c->Dim(input_shape, GetTensorDimIndex<3>(FORMAT_NCHW, dimension)); 1073 }; 1074 input_shape = 1075 c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}}); 1076 stride_planes = strides[2]; 1077 stride_rows = strides[3]; 1078 stride_cols = strides[4]; 1079 kernel_planes = kernel_sizes[2]; 1080 kernel_rows = kernel_sizes[3]; 1081 kernel_cols = kernel_sizes[4]; 1082 } else { 1083 stride_planes = strides[1]; 1084 stride_rows = strides[2]; 1085 stride_cols = strides[3]; 1086 kernel_planes = kernel_sizes[1]; 1087 kernel_rows = kernel_sizes[2]; 1088 kernel_cols = kernel_sizes[3]; 1089 } 1090 1091 DimensionHandle batch_size_dim = c->Dim(input_shape, 0); 1092 DimensionHandle in_planes_dim = c->Dim(input_shape, 1); 1093 DimensionHandle in_rows_dim = c->Dim(input_shape, 2); 1094 DimensionHandle in_cols_dim = c->Dim(input_shape, 3); 1095 DimensionHandle output_depth_dim = c->Dim(input_shape, 4); 1096 1097 Padding padding; 1098 TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); 1099 1100 // TODO(mrry,shlens): Raise an error if the stride would cause 1101 // information in the input to be ignored. This will require a change 1102 // in the kernel implementation. 1103 DimensionHandle output_planes, output_rows, output_cols; 1104 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( 1105 c, in_planes_dim, kernel_planes, stride_planes, padding, &output_planes)); 1106 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( 1107 c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows)); 1108 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( 1109 c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols)); 1110 1111 ShapeHandle output_shape; 1112 if (data_format == "NCDHW") { 1113 output_shape = c->MakeShape({batch_size_dim, output_depth_dim, 1114 output_planes, output_rows, output_cols}); 1115 } else { 1116 output_shape = c->MakeShape({batch_size_dim, output_planes, output_rows, 1117 output_cols, output_depth_dim}); 1118 } 1119 1120 c->set_output(0, output_shape); 1121 return Status::OK(); 1122 } 1123 1124 Status UnknownShape(shape_inference::InferenceContext* c) { 1125 for (int i = 0; i < c->num_outputs(); ++i) { 1126 c->set_output(i, c->UnknownShape()); 1127 } 1128 return Status::OK(); 1129 } 1130 1131 template <typename T> 1132 Status ReductionShapeHelper(const Tensor* reduction_indices_t, 1133 const int32 input_rank, 1134 std::set<int64>* true_indices) { 1135 auto reduction_indices = reduction_indices_t->flat<T>(); 1136 for (int i = 0; i < reduction_indices_t->NumElements(); ++i) { 1137 const T reduction_index = reduction_indices(i); 1138 if (reduction_index < -input_rank || reduction_index >= input_rank) { 1139 return errors::InvalidArgument("Invalid reduction dimension ", 1140 reduction_index, " for input with ", 1141 input_rank, " dimensions."); 1142 } 1143 1144 auto wrapped_index = reduction_index; 1145 if (wrapped_index < 0) { 1146 wrapped_index += input_rank; 1147 } 1148 1149 true_indices->insert(wrapped_index); 1150 } 1151 return Status::OK(); 1152 } 1153 1154 Status ReductionShape(InferenceContext* c) { 1155 ShapeHandle input = c->input(0); 1156 1157 ShapeHandle indices; 1158 // Older versions of TensorFlow accidentally allowed higher rank tensors like 1159 // [[1,2]] or [[1],[2]] to represent axis=[1,2]. 1160 if (c->graph_def_version() < 21) { 1161 indices = c->input(1); 1162 } else { 1163 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &indices)); 1164 } 1165 1166 bool keep_dims; 1167 TF_RETURN_IF_ERROR(c->GetAttr("keep_dims", &keep_dims)); 1168 1169 const Tensor* reduction_indices_t = c->input_tensor(1); 1170 if (reduction_indices_t == nullptr || !c->RankKnown(input)) { 1171 // If we do not have the reduction values at runtime, or the 1172 // rank of the input, we don't know the output shape. 1173 1174 if (keep_dims && c->RankKnown(input)) { 1175 // output rank matches input input if <keep_dims>. 1176 c->set_output(0, c->UnknownShapeOfRank(c->Rank(input))); 1177 return Status::OK(); 1178 } else { 1179 return shape_inference::UnknownShape(c); 1180 } 1181 } 1182 1183 const int32 input_rank = c->Rank(input); 1184 std::set<int64> true_indices; 1185 if (reduction_indices_t->dtype() == DataType::DT_INT32) { 1186 TF_RETURN_IF_ERROR(ReductionShapeHelper<int32>(reduction_indices_t, 1187 input_rank, &true_indices)); 1188 } else if (reduction_indices_t->dtype() == DataType::DT_INT64) { 1189 TF_RETURN_IF_ERROR(ReductionShapeHelper<int64>(reduction_indices_t, 1190 input_rank, &true_indices)); 1191 } else { 1192 return errors::InvalidArgument( 1193 "reduction_indices can only be int32 or int64"); 1194 } 1195 1196 std::vector<DimensionHandle> dims; 1197 for (int i = 0; i < input_rank; ++i) { 1198 if (true_indices.count(i) > 0) { 1199 if (keep_dims) { 1200 dims.emplace_back(c->MakeDim(1)); 1201 } 1202 } else { 1203 dims.emplace_back(c->Dim(input, i)); 1204 } 1205 } 1206 1207 c->set_output(0, c->MakeShape(dims)); 1208 return Status::OK(); 1209 } 1210 1211 Status ConcatShapeHelper(InferenceContext* c, int start_value_index, 1212 int end_value_index, int dim_index) { 1213 ShapeHandle unused; 1214 TF_RETURN_IF_ERROR(c->WithRank(c->input(dim_index), 0, &unused)); 1215 const Tensor* concat_dim_t = c->input_tensor(dim_index); 1216 if (concat_dim_t == nullptr) { 1217 // Return an unknown shape with same rank as inputs, or an unknown rank 1218 // if no input's rank is known. 1219 1220 // Find rank. 1221 int32 rank = InferenceContext::kUnknownRank; 1222 for (int i = start_value_index; i < end_value_index; ++i) { 1223 if (rank == InferenceContext::kUnknownRank) rank = c->Rank(c->input(i)); 1224 if (rank != InferenceContext::kUnknownRank) { 1225 break; 1226 } 1227 } 1228 if (rank == InferenceContext::kUnknownRank) { 1229 c->set_output(0, c->UnknownShape()); 1230 return Status::OK(); 1231 } else if (rank == 0) { 1232 return errors::InvalidArgument( 1233 "Can't concatenate scalars (use tf.stack instead)"); 1234 } else { 1235 for (int i = start_value_index; i < end_value_index; ++i) { 1236 // Check that all the inputs are of the correct rank. 1237 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), rank, &unused)); 1238 } 1239 } 1240 // Build result of <rank> different unknown dims. 1241 std::vector<DimensionHandle> dims; 1242 dims.reserve(rank); 1243 for (int i = 0; i < rank; ++i) dims.push_back(c->UnknownDim()); 1244 c->set_output(0, c->MakeShape(dims)); 1245 return Status::OK(); 1246 } 1247 1248 // Merge all the non-concat dims, and sum the concat dim to make an output 1249 // shape. 1250 const int32 concat_dim = concat_dim_t->scalar<int32>()(); 1251 1252 // Minimum required number of dimensions. 1253 const int min_rank = concat_dim < 0 ? -concat_dim : concat_dim + 1; 1254 1255 ShapeHandle output_before; 1256 ShapeHandle output_after; 1257 1258 ShapeHandle input = c->input(end_value_index - 1); 1259 TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input)); 1260 TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &output_before)); 1261 DimensionHandle output_middle = c->Dim(input, concat_dim); 1262 if (concat_dim == -1) { 1263 output_after = c->Scalar(); // no dimensions. 1264 } else { 1265 TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &output_after)); 1266 } 1267 1268 for (int i = end_value_index - 2; i >= start_value_index; --i) { 1269 ShapeHandle before; 1270 ShapeHandle after; 1271 input = c->input(i); 1272 TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input)); 1273 TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &before)); 1274 DimensionHandle middle = c->Dim(input, concat_dim); 1275 if (concat_dim == -1) { 1276 after = c->Scalar(); 1277 } else { 1278 TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &after)); 1279 } 1280 1281 TF_RETURN_IF_ERROR(c->Merge(before, output_before, &output_before)); 1282 TF_RETURN_IF_ERROR(c->Add(output_middle, middle, &output_middle)); 1283 TF_RETURN_IF_ERROR(c->Merge(after, output_after, &output_after)); 1284 } 1285 1286 ShapeHandle s; 1287 TF_RETURN_IF_ERROR( 1288 c->Concatenate(output_before, c->Vector(output_middle), &s)); 1289 TF_RETURN_IF_ERROR(c->Concatenate(s, output_after, &s)); 1290 c->set_output(0, s); 1291 return Status::OK(); 1292 } 1293 1294 Status ConcatShape(InferenceContext* c, int num_inputs_to_concat) { 1295 return ConcatShapeHelper(c, 1 /* start_value_index */, 1296 1 + num_inputs_to_concat /* end_value_index */, 1297 0 /* dim_index */); 1298 } 1299 1300 Status ConcatV2Shape(InferenceContext* c) { 1301 return ConcatShapeHelper(c, 0 /* start_value_index */, 1302 c->num_inputs() - 1 /* end_value_index */, 1303 c->num_inputs() - 1 /* dim_index */); 1304 } 1305 1306 Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat) { 1307 return ConcatShapeHelper(c, 0 /* start_value_index */, 1308 num_inputs_to_concat /* end_value_index */, 1309 num_inputs_to_concat /* dim_index */); 1310 } 1311 1312 Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c, 1313 ShapeHandle shape_x, 1314 ShapeHandle shape_y, 1315 ShapeHandle* out) { 1316 CHECK_NOTNULL(out); 1317 if (!c->RankKnown(shape_x) || !c->RankKnown(shape_y)) { 1318 *out = c->UnknownShape(); 1319 return Status::OK(); 1320 } 1321 const int32 rank_x = c->Rank(shape_x); 1322 const int32 rank_y = c->Rank(shape_y); 1323 const int32 rank_out = std::max(rank_x, rank_y); 1324 1325 // To compute the broadcast dimensions, we zip together shape_x and shape_y 1326 // and 1327 // pad with 1 to make them the same length. 1328 std::vector<DimensionHandle> dims; 1329 DimensionHandle dim_one; 1330 if (rank_x != rank_y) dim_one = c->MakeDim(1); 1331 for (int i = 0; i < rank_out; ++i) { 1332 const auto dim_x = i < (rank_out - rank_x) 1333 ? dim_one 1334 : c->Dim(shape_x, i - (rank_out - rank_x)); 1335 const bool dim_y_is_one = (i < (rank_out - rank_y)); 1336 const auto dim_y = 1337 dim_y_is_one ? dim_one : c->Dim(shape_y, i - (rank_out - rank_y)); 1338 if (!c->ValueKnown(dim_x) || !c->ValueKnown(dim_y)) { 1339 // One or both dimensions is unknown. 1340 // 1341 // - If either dimension is greater than 1, we assume that the program is 1342 // correct, and the other dimension will be broadcast to match it. 1343 // TODO(cwhipkey): For shape inference, if we eliminate the shape checks 1344 // in C++ op code, we must still assert that the unknown dim is either 1 1345 // or the same as the known dim. 1346 // - If either dimension is 1, the other dimension is the output. 1347 if (c->Value(dim_x) > 1) { 1348 dims.push_back(dim_x); 1349 } else if (c->Value(dim_y) > 1) { 1350 dims.push_back(dim_y); 1351 } else if (c->Value(dim_x) == 1) { 1352 dims.push_back(dim_y); 1353 } else if (c->Value(dim_y) == 1) { 1354 dims.push_back(dim_x); 1355 } else if (dim_y.SameHandle(dim_x)) { 1356 dims.push_back(dim_x); 1357 } else { 1358 dims.push_back(c->UnknownDim()); 1359 } 1360 } else if (c->Value(dim_x) == 1 || c->Value(dim_y) == 1) { 1361 if (c->Value(dim_x) == 1 && !dim_y_is_one) { 1362 // We will broadcast dim_x to dim_y. 1363 dims.push_back(dim_y); 1364 } else { 1365 DCHECK_EQ(c->Value(dim_y), 1); 1366 // We will broadcast dim_y to dim_x. 1367 dims.push_back(dim_x); 1368 } 1369 } else { 1370 DimensionHandle dim; 1371 TF_RETURN_IF_ERROR(c->Merge(dim_x, dim_y, &dim)); 1372 dims.push_back(dim); 1373 } 1374 } 1375 1376 *out = c->MakeShape(dims); 1377 return Status::OK(); 1378 } 1379 1380 Status RandomShape(shape_inference::InferenceContext* c) { 1381 shape_inference::ShapeHandle out; 1382 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); 1383 c->set_output(0, out); 1384 return Status::OK(); 1385 } 1386 1387 namespace { 1388 1389 // This SliceHelper processes the output shape of the `slice` 1390 // when the tensor of `sizes` is available. 1391 template <typename T> 1392 Status SliceHelper(InferenceContext* c, ShapeHandle begin_value, 1393 const Tensor* sizes_value, 1394 std::vector<DimensionHandle>* dims) { 1395 auto sizes_vec = sizes_value->vec<T>(); 1396 for (int i = 0; i < sizes_value->NumElements(); ++i) { 1397 DimensionHandle dim = c->Dim(c->input(0), i); 1398 if (sizes_vec(i) != -1) { 1399 auto dim_val = c->Value(dim); 1400 if (sizes_vec(i) < 0) { 1401 return errors::InvalidArgument( 1402 "Out of bounds slicing on dimension ", i, " of length ", dim_val, 1403 ": sizes vector cannot be < -1, but was ", sizes_vec(i)); 1404 } 1405 1406 dims->emplace_back(c->MakeDim(sizes_vec(i))); 1407 } else { 1408 DimensionHandle result; 1409 TF_RETURN_IF_ERROR(c->Subtract(dim, c->Dim(begin_value, i), &result)); 1410 dims->emplace_back(result); 1411 } 1412 } 1413 1414 return Status::OK(); 1415 } 1416 } // namespace 1417 1418 Status SliceShape(InferenceContext* c) { 1419 ShapeHandle input = c->input(0); 1420 ShapeHandle begin_shape; 1421 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape)); 1422 ShapeHandle sizes_shape; 1423 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sizes_shape)); 1424 1425 // Merge to check compatibility of begin and sizes tensors. 1426 TF_RETURN_IF_ERROR(c->Merge(begin_shape, sizes_shape, &begin_shape)); 1427 1428 DimensionHandle ndims = c->Dim(begin_shape, 0); 1429 if (c->ValueKnown(ndims)) { 1430 TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(ndims), &input)); 1431 } 1432 1433 // NOTE(mrry): Use MakeShapeFromShapeTensor to handle partially-known 1434 // values, even though the `begin` value does not represent a shape. 1435 ShapeHandle begin_value; 1436 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &begin_value)); 1437 1438 // We check the tensor value here and will only use 1439 // `MakeShapeFromShapeTensor` when `sizes_value` is null. 1440 // The reason is that `sizes` might contain -1, which can't 1441 // be represented (-1 in the ShapeHandle would mean "unknown"). 1442 const Tensor* sizes_value = c->input_tensor(2); 1443 1444 if (sizes_value != nullptr) { 1445 TF_RETURN_IF_ERROR( 1446 c->WithRank(begin_value, sizes_value->NumElements(), &begin_value)); 1447 std::vector<DimensionHandle> dims; 1448 // If the begin and sizes tensors are available, then 1449 // we can be precise about the shape of the output. 1450 if (sizes_value->dtype() == DT_INT64) { 1451 TF_RETURN_IF_ERROR( 1452 SliceHelper<int64>(c, begin_value, sizes_value, &dims)); 1453 } else { 1454 TF_RETURN_IF_ERROR( 1455 SliceHelper<int32>(c, begin_value, sizes_value, &dims)); 1456 } 1457 c->set_output(0, c->MakeShape(dims)); 1458 return Status::OK(); 1459 } else { 1460 // In case `sizes` is not available (`sizes_value` is null), 1461 // we could try to use `MakeShapeFromShapeTensor` here. 1462 // If sizes contain -1, we will simply consider it as `Unknown`. 1463 // This is less than ideal but still an improvement of shape inference. 1464 // The following is an example that returns [None, 1, None] with this 1465 // code path: 1466 // z = tf.zeros((1, 2, 3)) 1467 // m = tf.slice(z, [0, 0, 0], [tf.constant(1) + 0, 1, -1]) 1468 // m.get_shape().as_list() 1469 ShapeHandle sizes_value; 1470 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &sizes_value)); 1471 if (c->RankKnown(sizes_value)) { 1472 TF_RETURN_IF_ERROR( 1473 c->WithRank(begin_value, c->Rank(sizes_value), &begin_value)); 1474 std::vector<DimensionHandle> dims; 1475 dims.reserve(c->Rank(sizes_value)); 1476 for (int i = 0; i < c->Rank(sizes_value); ++i) { 1477 dims.emplace_back(c->Dim(sizes_value, i)); 1478 } 1479 c->set_output(0, c->MakeShape(dims)); 1480 return Status::OK(); 1481 } 1482 // We might know the rank of the input. 1483 if (c->RankKnown(input)) { 1484 c->set_output(0, c->UnknownShapeOfRank(c->Rank(input))); 1485 return Status::OK(); 1486 } else { 1487 return shape_inference::UnknownShape(c); 1488 } 1489 } 1490 1491 return Status::OK(); 1492 } 1493 1494 Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape, 1495 ShapeHandle values_shape, ShapeHandle shape_shape) { 1496 // Validate ranks. 1497 ShapeHandle unused_shape; 1498 TF_RETURN_IF_ERROR(c->WithRank(indices_shape, 2, &unused_shape)); 1499 TF_RETURN_IF_ERROR(c->WithRank(values_shape, 1, &unused_shape)); 1500 TF_RETURN_IF_ERROR(c->WithRank(shape_shape, 1, &unused_shape)); 1501 1502 // Number of elements in indices and values must match. 1503 DimensionHandle num_index_elements_dim = c->Dim(indices_shape, 0); 1504 if (c->ValueKnown(num_index_elements_dim)) { 1505 DimensionHandle num_values_elements_dim = c->Dim(values_shape, 0); 1506 if (c->ValueKnown(num_values_elements_dim)) { 1507 int64 num_index_elements = c->Value(num_index_elements_dim); 1508 int64 num_values_elements = c->Value(num_values_elements_dim); 1509 if (num_index_elements != num_values_elements) { 1510 return errors::InvalidArgument("Number of elements in index (", 1511 num_index_elements, ") and values (", 1512 num_values_elements, ") do not match."); 1513 } 1514 } 1515 } 1516 1517 // Rank embedded in indices must match shape. 1518 DimensionHandle index_rank_dim = c->Dim(indices_shape, 1); 1519 if (c->ValueKnown(index_rank_dim)) { 1520 DimensionHandle shape_rank_dim = c->Dim(shape_shape, 0); 1521 if (c->ValueKnown(shape_rank_dim)) { 1522 int64 index_rank = c->Value(index_rank_dim); 1523 int32 shape_rank = c->Value(shape_rank_dim); 1524 if (index_rank != shape_rank) { 1525 return errors::InvalidArgument("Index rank (", index_rank, 1526 ") and shape rank (", shape_rank, 1527 ") do not match."); 1528 } 1529 } 1530 } 1531 1532 return Status::OK(); 1533 } 1534 1535 Status ScatterNdUpdateShape(InferenceContext* c) { 1536 ShapeHandle input_shape = c->input(0); 1537 if (c->input_handle_shapes_and_types(0) != nullptr) { 1538 // This is called for tf.scatter_nd_update; input is a Variable handle. 1539 const auto& shape_and_type = *(c->input_handle_shapes_and_types(0)); 1540 if (shape_and_type.size() == 1) { 1541 input_shape = shape_and_type[0].shape; 1542 } 1543 } 1544 ShapeHandle indices_shape; 1545 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape)); 1546 ShapeHandle updates_shape; 1547 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape)); 1548 1549 if (c->Value(c->NumElements(input_shape)) == 0 && 1550 (c->Value(c->NumElements(indices_shape)) > 0 || 1551 c->Value(c->NumElements(updates_shape)) > 0)) { 1552 return errors::InvalidArgument( 1553 "Indices and updates specified for empty output shape"); 1554 } 1555 1556 if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) { 1557 const int64 num_outer_dims = c->Rank(indices_shape) - 1; 1558 const DimensionHandle index_size = c->Dim(indices_shape, -1); 1559 1560 // We can only do more validation if the last dimension of indices 1561 // is a known value. 1562 if (c->ValueKnown(index_size)) { 1563 const int64 ix = c->Value(index_size); 1564 ShapeHandle unused; 1565 ShapeHandle prefix_indices; 1566 TF_RETURN_IF_ERROR( 1567 c->Subshape(indices_shape, 0, num_outer_dims, &prefix_indices)); 1568 ShapeHandle prefix_updates; 1569 TF_RETURN_IF_ERROR( 1570 c->Subshape(updates_shape, 0, num_outer_dims, &prefix_updates)); 1571 1572 Status s = c->Merge(prefix_indices, prefix_updates, &unused); 1573 if (!s.ok()) { 1574 return errors::InvalidArgument( 1575 "The outer ", num_outer_dims, 1576 " dimensions of indices.shape=", c->DebugString(indices_shape), 1577 " must match the outer ", num_outer_dims, 1578 " dimensions of updates.shape=", c->DebugString(updates_shape), 1579 ": ", s.error_message()); 1580 } 1581 1582 ShapeHandle input_suffix; 1583 TF_RETURN_IF_ERROR(c->Subshape(input_shape, ix, &input_suffix)); 1584 ShapeHandle suffix_updates; 1585 TF_RETURN_IF_ERROR( 1586 c->Subshape(updates_shape, num_outer_dims, &suffix_updates)); 1587 s = c->Merge(input_suffix, suffix_updates, &unused); 1588 if (!s.ok()) { 1589 return errors::InvalidArgument( 1590 "The inner ", c->Rank(input_shape) - ix, 1591 " dimensions of input.shape=", c->DebugString(input_shape), 1592 " must match the inner ", c->Rank(updates_shape) - num_outer_dims, 1593 " dimensions of updates.shape=", c->DebugString(updates_shape), 1594 ": ", s.error_message()); 1595 } 1596 } 1597 } 1598 1599 if (c->input_handle_shapes_and_types(0) == nullptr && c->num_outputs() > 0) { 1600 // This is called for tf.scatter_nd; output is a tensor with this shape. 1601 c->set_output(0, input_shape); 1602 } 1603 return Status::OK(); 1604 } 1605 1606 Status ExplicitShape(InferenceContext* c) { 1607 PartialTensorShape shape; 1608 TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); 1609 ShapeHandle output_shape; 1610 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &output_shape)); 1611 c->set_output(0, output_shape); 1612 return Status::OK(); 1613 } 1614 1615 Status ExplicitShapes(InferenceContext* c) { 1616 std::vector<PartialTensorShape> shapes; 1617 TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes)); 1618 if (shapes.empty()) { 1619 return errors::Internal("shapes attribute is empty"); 1620 } 1621 for (int i = 0; i < shapes.size(); ++i) { 1622 ShapeHandle output_shape; 1623 TF_RETURN_IF_ERROR( 1624 c->MakeShapeFromPartialTensorShape(shapes[i], &output_shape)); 1625 c->set_output(i, output_shape); 1626 } 1627 return Status::OK(); 1628 } 1629 1630 Status SparseReduceShapeFn(InferenceContext* c) { 1631 // Input 0: input_indices 1632 // Input 1: input_values 1633 // Input 2: input_shape 1634 // Input 3: reduction_axes 1635 // Attr: keep_dims 1636 bool keep_dims = false; 1637 TF_RETURN_IF_ERROR(c->GetAttr("keep_dims", &keep_dims)); 1638 1639 const Tensor* shape_tensor = c->input_tensor(2); 1640 const Tensor* axes_tensor = c->input_tensor(3); 1641 if (shape_tensor != nullptr && axes_tensor != nullptr) { 1642 auto shape_vec = shape_tensor->flat<int64>(); 1643 auto axes_vec = axes_tensor->flat<int32>(); 1644 1645 int64 ndims = shape_vec.size(); 1646 std::unordered_set<int64> axes; 1647 for (int i = 0; i < axes_vec.size(); i++) { 1648 axes.insert((axes_vec(i) + ndims) % ndims); 1649 } 1650 1651 std::vector<DimensionHandle> dims; 1652 if (keep_dims) { 1653 dims.reserve(ndims); 1654 for (int d = 0; d < ndims; ++d) { 1655 if (axes.find(d) == axes.end()) { 1656 dims.push_back(c->MakeDim(shape_vec(d))); 1657 } else { 1658 dims.push_back(c->MakeDim(1)); 1659 } 1660 } 1661 } else { 1662 for (int d = 0; d < ndims; ++d) { 1663 if (axes.find(d) == axes.end()) { 1664 dims.push_back(c->MakeDim(shape_vec(d))); 1665 } 1666 } 1667 } 1668 1669 c->set_output(0, c->MakeShape(dims)); 1670 return Status::OK(); 1671 } 1672 return UnknownShape(c); 1673 } 1674 1675 } // namespace shape_inference 1676 1677 } // namespace tensorflow 1678