1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/core/framework/common_shape_fns.h" 16 #include "tensorflow/core/framework/attr_value.pb.h" 17 18 namespace tensorflow { 19 20 Status GetWindowedOutputSizeVerboseV2(int64 input_size, int64 filter_size, 21 int64 dilation_rate, int64 stride, 22 Padding padding_type, int64* output_size, 23 int64* padding_before, 24 int64* padding_after) { 25 if (stride <= 0) { 26 return errors::InvalidArgument("Stride must be > 0, but got ", stride); 27 } 28 if (dilation_rate < 1) { 29 return errors::InvalidArgument("Dilation rate must be >= 1, but got ", 30 dilation_rate); 31 } 32 33 // See also the parallel implementation in GetWindowedOutputSizeFromDimsV2. 34 int64 effective_filter_size = (filter_size - 1) * dilation_rate + 1; 35 switch (padding_type) { 36 case Padding::VALID: 37 *output_size = (input_size - effective_filter_size + stride) / stride; 38 *padding_before = *padding_after = 0; 39 break; 40 case Padding::SAME: 41 *output_size = (input_size + stride - 1) / stride; 42 const int64 padding_needed = 43 std::max(0LL, (*output_size - 1) * stride + effective_filter_size - 44 input_size); 45 // For odd values of total padding, add more padding at the 'right' 46 // side of the given dimension. 47 *padding_before = padding_needed / 2; 48 *padding_after = padding_needed - *padding_before; 49 break; 50 } 51 if (*output_size < 0) { 52 return errors::InvalidArgument( 53 "Computed output size would be negative: ", *output_size, 54 " [input_size: ", input_size, 55 ", effective_filter_size: ", effective_filter_size, 56 ", stride: ", stride, "]"); 57 } 58 return Status::OK(); 59 } 60 61 Status GetWindowedOutputSizeVerbose(int64 input_size, int64 filter_size, 62 int64 stride, Padding padding_type, 63 int64* output_size, int64* padding_before, 64 int64* padding_after) { 65 return GetWindowedOutputSizeVerboseV2(input_size, filter_size, 66 /*dilation_rate=*/1, stride, 67 padding_type, output_size, 68 padding_before, padding_after); 69 } 70 71 Status GetWindowedOutputSize(int64 input_size, int64 filter_size, int64 stride, 72 Padding padding_type, int64* output_size, 73 int64* padding_size) { 74 int64 padding_after_unused; 75 return GetWindowedOutputSizeVerbose(input_size, filter_size, stride, 76 padding_type, output_size, padding_size, 77 &padding_after_unused); 78 } 79 80 Status GetWindowedOutputSizeV2(int64 input_size, int64 filter_size, 81 int64 dilation_rate, int64 stride, 82 Padding padding_type, int64* output_size, 83 int64* padding_size) { 84 int64 padding_after_unused; 85 return GetWindowedOutputSizeVerboseV2(input_size, filter_size, dilation_rate, 86 stride, padding_type, output_size, 87 padding_size, &padding_after_unused); 88 } 89 90 Status Get3dOutputSize(const std::array<int64, 3>& input, 91 const std::array<int64, 3>& window, 92 const std::array<int64, 3>& strides, 93 Padding padding_type, std::array<int64, 3>* output_ptr, 94 std::array<int64, 3>* padding_ptr) { 95 for (size_t i = 0; i < input.size(); ++i) { 96 TF_RETURN_IF_ERROR(GetWindowedOutputSize(input[i], window[i], strides[i], 97 padding_type, &(*output_ptr)[i], 98 &(*padding_ptr)[i])); 99 } 100 return Status::OK(); 101 } 102 103 Status Get3dOutputSizeV2(const std::array<int64, 3>& input, 104 const std::array<int64, 3>& window, 105 const std::array<int64, 3>& dilations, 106 const std::array<int64, 3>& strides, 107 Padding padding_type, std::array<int64, 3>* output_ptr, 108 std::array<int64, 3>* padding_ptr) { 109 for (size_t i = 0; i < input.size(); ++i) { 110 TF_RETURN_IF_ERROR(GetWindowedOutputSizeV2( 111 input[i], window[i], dilations[i], strides[i], padding_type, 112 &(*output_ptr)[i], &(*padding_ptr)[i])); 113 } 114 return Status::OK(); 115 } 116 117 namespace shape_inference { 118 119 // The V2 version computes windowed output size with arbitrary dilation_rate, 120 // while the original version only handles the cases where dilation_rates equal 121 // to 1. 122 Status GetWindowedOutputSizeFromDimsV2( 123 shape_inference::InferenceContext* c, 124 shape_inference::DimensionHandle input_size, 125 shape_inference::DimensionOrConstant filter_size, int64 dilation_rate, 126 int64 stride, Padding padding_type, 127 shape_inference::DimensionHandle* output_size) { 128 if (stride <= 0) { 129 return errors::InvalidArgument("Stride must be > 0, but got ", stride); 130 } 131 132 if (dilation_rate < 1) { 133 return errors::InvalidArgument("Dilation rate must be >= 1, but got ", 134 dilation_rate); 135 } 136 137 // See also the parallel implementation in GetWindowedOutputSizeVerbose. 138 switch (padding_type) { 139 case Padding::VALID: 140 if (dilation_rate > 1) { 141 DimensionHandle window_size; 142 TF_RETURN_IF_ERROR( 143 c->Subtract(c->MakeDim(filter_size), 1, &window_size)); 144 TF_RETURN_IF_ERROR( 145 c->Multiply(window_size, dilation_rate, &window_size)); 146 TF_RETURN_IF_ERROR(c->Add(window_size, 1, &window_size)); 147 TF_RETURN_IF_ERROR(c->Subtract(input_size, window_size, output_size)); 148 } else { 149 TF_RETURN_IF_ERROR(c->Subtract(input_size, filter_size, output_size)); 150 } 151 TF_RETURN_IF_ERROR(c->Add(*output_size, stride, output_size)); 152 TF_RETURN_IF_ERROR(c->Divide(*output_size, stride, 153 /*evenly_divisible=*/false, output_size)); 154 break; 155 case Padding::SAME: 156 TF_RETURN_IF_ERROR(c->Add(input_size, stride - 1, output_size)); 157 TF_RETURN_IF_ERROR(c->Divide(*output_size, stride, 158 /*evenly_divisible=*/false, output_size)); 159 break; 160 } 161 return Status::OK(); 162 } 163 164 Status GetWindowedOutputSizeFromDims( 165 shape_inference::InferenceContext* c, 166 shape_inference::DimensionHandle input_size, 167 shape_inference::DimensionOrConstant filter_size, int64 stride, 168 Padding padding_type, shape_inference::DimensionHandle* output_size) { 169 return GetWindowedOutputSizeFromDimsV2(c, input_size, filter_size, 170 /*dilation_rate=*/1, stride, 171 padding_type, output_size); 172 } 173 174 Status UnchangedShape(shape_inference::InferenceContext* c) { 175 c->set_output(0, c->input(0)); 176 return Status::OK(); 177 } 178 179 Status MatMulShape(shape_inference::InferenceContext* c) { 180 ShapeHandle a; 181 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &a)); 182 183 ShapeHandle b; 184 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &b)); 185 186 bool transpose_a, transpose_b; 187 TF_RETURN_IF_ERROR(c->GetAttr("transpose_a", &transpose_a)); 188 TF_RETURN_IF_ERROR(c->GetAttr("transpose_b", &transpose_b)); 189 DimensionHandle output_rows = transpose_a ? c->Dim(a, 1) : c->Dim(a, 0); 190 DimensionHandle output_cols = transpose_b ? c->Dim(b, 0) : c->Dim(b, 1); 191 192 // Validate that the inner shapes are compatible. 193 DimensionHandle inner_a = transpose_a ? c->Dim(a, 0) : c->Dim(a, 1); 194 DimensionHandle inner_b = transpose_b ? c->Dim(b, 1) : c->Dim(b, 0); 195 DimensionHandle merged; 196 TF_RETURN_IF_ERROR(c->Merge(inner_a, inner_b, &merged)); 197 198 c->set_output(0, c->Matrix(output_rows, output_cols)); 199 return Status::OK(); 200 } 201 202 Status BiasAddShape(shape_inference::InferenceContext* c) { 203 ShapeHandle input_shape; 204 205 // Fetch the data_format attribute, which may not exist. 206 string data_format; 207 Status s = c->GetAttr("data_format", &data_format); 208 209 if (s.ok() && data_format == "NCHW") { 210 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape)); 211 } else { 212 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape)); 213 } 214 215 ShapeHandle bias_shape; 216 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &bias_shape)); 217 DimensionHandle bias_dim = c->Dim(bias_shape, 0); 218 219 // If rank unknown, return unknown shape. 220 if (!c->RankKnown(input_shape)) { 221 c->set_output(0, c->UnknownShape()); 222 return Status::OK(); 223 } 224 225 // Output has the same shape as the input, and matches the length of 226 // the bias in its bias dimension. 227 ShapeHandle output_shape; 228 if (s.ok() && data_format == "NCHW") { 229 // Merge the length of bias_shape into the third to last dimension 230 ShapeHandle first; 231 TF_RETURN_IF_ERROR(c->Subshape(input_shape, 0, -3, &first)); 232 233 ShapeHandle last; 234 TF_RETURN_IF_ERROR(c->Subshape(input_shape, -2, &last)); 235 236 DimensionHandle input_bias_dim = c->Dim(input_shape, -3); 237 DimensionHandle merged_bias_dim; 238 TF_RETURN_IF_ERROR(c->Merge(input_bias_dim, bias_dim, &merged_bias_dim)); 239 ShapeHandle merged_bias = c->Vector(merged_bias_dim); 240 241 ShapeHandle temp; 242 TF_RETURN_IF_ERROR(c->Concatenate(first, merged_bias, &temp)); 243 TF_RETURN_IF_ERROR(c->Concatenate(temp, last, &output_shape)); 244 } else { 245 ShapeHandle all_but_bias; 246 TF_RETURN_IF_ERROR(c->Subshape(input_shape, 0, -1, &all_but_bias)); 247 248 DimensionHandle input_bias_dim = c->Dim(input_shape, -1); 249 DimensionHandle merged_bias_dim; 250 TF_RETURN_IF_ERROR(c->Merge(input_bias_dim, bias_dim, &merged_bias_dim)); 251 252 ShapeHandle merged_bias = c->Vector(merged_bias_dim); 253 TF_RETURN_IF_ERROR( 254 c->Concatenate(all_but_bias, merged_bias, &output_shape)); 255 } 256 257 c->set_output(0, output_shape); 258 return Status::OK(); 259 } 260 261 Status BiasAddGradShape(shape_inference::InferenceContext* c) { 262 ShapeHandle input_shape; 263 // Fetch the data_format attribute, which may not exist. 264 string data_format; 265 Status s = c->GetAttr("data_format", &data_format); 266 267 if (s.ok() && data_format == "NCHW") { 268 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape)); 269 c->set_output(0, c->Vector(c->Dim(input_shape, -3))); 270 } else { 271 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape)); 272 c->set_output(0, c->Vector(c->Dim(input_shape, -1))); 273 } 274 275 return Status::OK(); 276 } 277 278 Status CheckFormatConstraintsOnShape(const TensorFormat tensor_format, 279 const ShapeHandle shape_handle, 280 const string& tensor_name, 281 shape_inference::InferenceContext* c) { 282 if (tensor_format == FORMAT_NCHW_VECT_C) { 283 // Check that the vect dim has size 4. 284 const int num_dims = c->Rank(shape_handle); 285 DimensionHandle vect_dim = c->Dim( 286 shape_handle, GetTensorInnerFeatureDimIndex(num_dims, tensor_format)); 287 DimensionHandle unused_vect_dim; 288 TF_RETURN_IF_ERROR(c->WithValue(vect_dim, 4, &unused_vect_dim)); 289 } 290 291 return Status::OK(); 292 } 293 294 Status MakeShapeFromFormat(TensorFormat format, DimensionOrConstant N, 295 const std::vector<DimensionOrConstant>& spatial, 296 DimensionOrConstant C, ShapeHandle* out, 297 shape_inference::InferenceContext* context) { 298 const int num_dims = GetTensorDimsFromSpatialDims(spatial.size(), format); 299 std::vector<DimensionHandle> dims_actual(num_dims); 300 dims_actual[GetTensorBatchDimIndex(num_dims, format)] = context->MakeDim(N); 301 int outer_c_index = GetTensorFeatureDimIndex(num_dims, format); 302 dims_actual[outer_c_index] = context->MakeDim(C); 303 if (format == FORMAT_NCHW_VECT_C) { 304 dims_actual[GetTensorInnerFeatureDimIndex(num_dims, format)] = 305 context->MakeDim(4); 306 } 307 for (int spatial_dim = 0; spatial_dim < spatial.size(); spatial_dim++) { 308 dims_actual[GetTensorSpatialDimIndex(num_dims, format, spatial_dim)] = 309 context->MakeDim(spatial[spatial_dim]); 310 } 311 *out = context->MakeShape(dims_actual); 312 return Status::OK(); 313 } 314 315 Status DimensionsFromShape(ShapeHandle shape, TensorFormat format, 316 DimensionHandle* batch_dim, 317 gtl::MutableArraySlice<DimensionHandle> spatial_dims, 318 DimensionHandle* filter_dim, 319 InferenceContext* context) { 320 const int32 rank = GetTensorDimsFromSpatialDims(spatial_dims.size(), format); 321 // Batch. 322 *batch_dim = context->Dim(shape, GetTensorBatchDimIndex(rank, format)); 323 // Spatial. 324 for (int spatial_dim_index = 0; spatial_dim_index < spatial_dims.size(); 325 ++spatial_dim_index) { 326 spatial_dims[spatial_dim_index] = context->Dim( 327 shape, GetTensorSpatialDimIndex(rank, format, spatial_dim_index)); 328 } 329 // Channel. 330 *filter_dim = context->Dim(shape, GetTensorFeatureDimIndex(rank, format)); 331 if (format == FORMAT_NCHW_VECT_C) { 332 TF_RETURN_IF_ERROR(context->Multiply( 333 *filter_dim, 334 context->Dim(shape, GetTensorInnerFeatureDimIndex(rank, format)), 335 filter_dim)); 336 } 337 return Status::OK(); 338 } 339 340 Status ShapeFromDimensions(DimensionHandle batch_dim, 341 gtl::ArraySlice<DimensionHandle> spatial_dims, 342 DimensionHandle filter_dim, TensorFormat format, 343 InferenceContext* context, ShapeHandle* shape) { 344 const int32 rank = GetTensorDimsFromSpatialDims(spatial_dims.size(), format); 345 std::vector<DimensionHandle> out_dims(rank); 346 347 // Batch. 348 out_dims[tensorflow::GetTensorBatchDimIndex(rank, format)] = batch_dim; 349 // Spatial. 350 for (int spatial_dim_index = 0; spatial_dim_index < spatial_dims.size(); 351 ++spatial_dim_index) { 352 out_dims[tensorflow::GetTensorSpatialDimIndex( 353 rank, format, spatial_dim_index)] = spatial_dims[spatial_dim_index]; 354 } 355 // Channel. 356 if (format == tensorflow::FORMAT_NCHW_VECT_C) { 357 // When format is NCHW_VECT_C, factor the feature map count 358 // into the outer feature count and the inner feature count (=4). 359 TF_RETURN_IF_ERROR(context->Divide( 360 filter_dim, 4, /*evenly_divisible=*/true, 361 &out_dims[tensorflow::GetTensorFeatureDimIndex(rank, format)])); 362 out_dims[GetTensorInnerFeatureDimIndex(rank, format)] = context->MakeDim(4); 363 } else { 364 out_dims[tensorflow::GetTensorFeatureDimIndex(rank, format)] = filter_dim; 365 } 366 367 *shape = context->MakeShape(out_dims); 368 return tensorflow::Status::OK(); 369 } 370 371 Status Conv2DShape(shape_inference::InferenceContext* c) { 372 string data_format_str, filter_format_str; 373 if (!c->GetAttr("data_format", &data_format_str).ok()) { 374 data_format_str = "NHWC"; 375 } 376 if (!c->GetAttr("filter_format", &filter_format_str).ok()) { 377 filter_format_str = "HWIO"; 378 } 379 380 TensorFormat data_format; 381 if (!FormatFromString(data_format_str, &data_format)) { 382 return errors::InvalidArgument("Invalid data format string: ", 383 data_format_str); 384 } 385 FilterTensorFormat filter_format; 386 if (!FilterFormatFromString(filter_format_str, &filter_format)) { 387 return errors::InvalidArgument("Invalid filter format string: ", 388 filter_format_str); 389 } 390 391 constexpr int num_spatial_dims = 2; 392 const int rank = GetTensorDimsFromSpatialDims(num_spatial_dims, data_format); 393 ShapeHandle conv_input_shape; 394 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &conv_input_shape)); 395 TF_RETURN_IF_ERROR(CheckFormatConstraintsOnShape( 396 data_format, conv_input_shape, "conv_input", c)); 397 398 // The filter rank should match the input (4 for NCHW, 5 for NCHW_VECT_C). 399 ShapeHandle filter_shape; 400 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape)); 401 TF_RETURN_IF_ERROR( 402 CheckFormatConstraintsOnShape(data_format, filter_shape, "filter", c)); 403 404 std::vector<int32> dilations; 405 TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations)); 406 407 if (dilations.size() != 4) { 408 return errors::InvalidArgument( 409 "Conv2D requires the dilation attribute to contain 4 values, but got: ", 410 dilations.size()); 411 } 412 413 std::vector<int32> strides; 414 TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); 415 416 // strides.size() should be 4 (NCHW) even if the input is 5 (NCHW_VECT_C). 417 if (strides.size() != 4) { 418 return errors::InvalidArgument("Conv2D on data format ", data_format_str, 419 " requires the stride attribute to contain" 420 " 4 values, but got: ", 421 strides.size()); 422 } 423 424 const int32 stride_rows = GetTensorDim(strides, data_format, 'H'); 425 const int32 stride_cols = GetTensorDim(strides, data_format, 'W'); 426 const int32 dilation_rows = GetTensorDim(dilations, data_format, 'H'); 427 const int32 dilation_cols = GetTensorDim(dilations, data_format, 'W'); 428 429 DimensionHandle batch_size_dim; 430 DimensionHandle input_depth_dim; 431 gtl::InlinedVector<DimensionHandle, 2> input_spatial_dims(2); 432 TF_RETURN_IF_ERROR(DimensionsFromShape(conv_input_shape, data_format, 433 &batch_size_dim, &input_spatial_dims, 434 &input_depth_dim, c)); 435 436 DimensionHandle output_depth_dim = c->Dim( 437 filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'O')); 438 DimensionHandle filter_rows_dim = c->Dim( 439 filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'H')); 440 DimensionHandle filter_cols_dim = c->Dim( 441 filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'W')); 442 DimensionHandle filter_input_depth_dim; 443 if (filter_format == FORMAT_OIHW_VECT_I) { 444 TF_RETURN_IF_ERROR(c->Multiply( 445 c->Dim(filter_shape, 446 GetFilterDimIndex<num_spatial_dims>(filter_format, 'I')), 447 c->Dim(filter_shape, 448 GetFilterTensorInnerInputChannelsDimIndex(rank, filter_format)), 449 &filter_input_depth_dim)); 450 } else { 451 filter_input_depth_dim = c->Dim( 452 filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'I')); 453 } 454 455 // Check that the input tensor and the filter tensor agree on the input 456 // channel count. 457 DimensionHandle unused; 458 TF_RETURN_IF_ERROR( 459 c->Merge(input_depth_dim, filter_input_depth_dim, &unused)); 460 461 Padding padding; 462 TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); 463 464 DimensionHandle output_rows, output_cols; 465 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2( 466 c, input_spatial_dims[0], filter_rows_dim, dilation_rows, stride_rows, 467 padding, &output_rows)); 468 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2( 469 c, input_spatial_dims[1], filter_cols_dim, dilation_cols, stride_cols, 470 padding, &output_cols)); 471 472 ShapeHandle output_shape; 473 TF_RETURN_IF_ERROR( 474 ShapeFromDimensions(batch_size_dim, {output_rows, output_cols}, 475 output_depth_dim, data_format, c, &output_shape)); 476 c->set_output(0, output_shape); 477 return Status::OK(); 478 } 479 480 // TODO(mjanusz): Unify all conv/pooling shape functions. 481 Status Conv3DShape(shape_inference::InferenceContext* c) { 482 ShapeHandle input_shape; 483 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape)); 484 ShapeHandle filter_shape; 485 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 5, &filter_shape)); 486 487 string data_format; 488 Status s = c->GetAttr("data_format", &data_format); 489 490 std::vector<int32> strides; 491 TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); 492 if (strides.size() != 5) { 493 return errors::InvalidArgument( 494 "Conv3D requires the stride attribute to contain 5 values, but got: ", 495 strides.size()); 496 } 497 498 int32 stride_planes, stride_rows, stride_cols; 499 if (s.ok() && data_format == "NCDHW") { 500 // Convert input_shape to NDHWC. 501 auto dim = [&](char dimension) { 502 return c->Dim(input_shape, GetTensorDimIndex<3>(FORMAT_NCHW, dimension)); 503 }; 504 input_shape = 505 c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}}); 506 stride_planes = strides[2]; 507 stride_cols = strides[3]; 508 stride_rows = strides[4]; 509 } else { 510 stride_planes = strides[1]; 511 stride_rows = strides[2]; 512 stride_cols = strides[3]; 513 } 514 515 DimensionHandle batch_size_dim = c->Dim(input_shape, 0); 516 DimensionHandle in_planes_dim = c->Dim(input_shape, 1); 517 DimensionHandle in_rows_dim = c->Dim(input_shape, 2); 518 DimensionHandle in_cols_dim = c->Dim(input_shape, 3); 519 520 DimensionHandle filter_planes_dim = c->Dim(filter_shape, 0); 521 DimensionHandle filter_rows_dim = c->Dim(filter_shape, 1); 522 DimensionHandle filter_cols_dim = c->Dim(filter_shape, 2); 523 DimensionHandle output_depth_dim = c->Dim(filter_shape, 4); 524 525 DimensionHandle unused; 526 TF_RETURN_IF_ERROR( 527 c->Merge(c->Dim(input_shape, 4), c->Dim(filter_shape, 3), &unused)); 528 529 Padding padding; 530 TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); 531 DimensionHandle output_planes, output_rows, output_cols; 532 533 TF_RETURN_IF_ERROR( 534 GetWindowedOutputSizeFromDims(c, in_planes_dim, filter_planes_dim, 535 stride_planes, padding, &output_planes)); 536 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( 537 c, in_rows_dim, filter_rows_dim, stride_rows, padding, &output_rows)); 538 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( 539 c, in_cols_dim, filter_cols_dim, stride_cols, padding, &output_cols)); 540 541 ShapeHandle output_shape; 542 if (data_format == "NCDHW") { 543 output_shape = c->MakeShape({batch_size_dim, output_depth_dim, 544 output_planes, output_rows, output_cols}); 545 } else { 546 output_shape = c->MakeShape({batch_size_dim, output_planes, output_rows, 547 output_cols, output_depth_dim}); 548 } 549 c->set_output(0, output_shape); 550 return Status::OK(); 551 } 552 553 Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) { 554 ShapeHandle input_shape; 555 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); 556 ShapeHandle filter_shape; 557 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &filter_shape)); 558 559 std::vector<int32> strides; 560 TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); 561 562 if (strides.size() != 4) { 563 return errors::InvalidArgument( 564 "DepthwiseConv2D requires the stride attribute to contain 4 values, " 565 "but got: ", 566 strides.size()); 567 } 568 569 string data_format; 570 Status s = c->GetAttr("data_format", &data_format); 571 int32 stride_rows; 572 int32 stride_cols; 573 if (s.ok() && data_format == "NCHW") { 574 // Canonicalize input shape to NHWC so the shape inference code below can 575 // process it. 576 input_shape = 577 c->MakeShape({{c->Dim(input_shape, 0), c->Dim(input_shape, 2), 578 c->Dim(input_shape, 3), c->Dim(input_shape, 1)}}); 579 stride_rows = strides[2]; 580 stride_cols = strides[3]; 581 } else { 582 stride_rows = strides[1]; 583 stride_cols = strides[2]; 584 } 585 586 DimensionHandle batch_size_dim = c->Dim(input_shape, 0); 587 DimensionHandle in_rows_dim = c->Dim(input_shape, 1); 588 DimensionHandle in_cols_dim = c->Dim(input_shape, 2); 589 590 DimensionHandle filter_rows_dim = c->Dim(filter_shape, 0); 591 DimensionHandle filter_cols_dim = c->Dim(filter_shape, 1); 592 DimensionHandle input_depth = c->Dim(filter_shape, 2); 593 DimensionHandle depth_multiplier = c->Dim(filter_shape, 3); 594 595 // Check that the input depths are compatible. 596 TF_RETURN_IF_ERROR( 597 c->Merge(c->Dim(input_shape, 3), input_depth, &input_depth)); 598 599 DimensionHandle output_depth; 600 TF_RETURN_IF_ERROR(c->Multiply(input_depth, depth_multiplier, &output_depth)); 601 602 Padding padding; 603 TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); 604 605 // TODO(mrry,shlens): Raise an error if the stride would cause 606 // information in the input to be ignored. This will require a change 607 // in the kernel implementation. 608 DimensionHandle output_rows, output_cols; 609 610 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( 611 c, in_rows_dim, filter_rows_dim, stride_rows, padding, &output_rows)); 612 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( 613 c, in_cols_dim, filter_cols_dim, stride_cols, padding, &output_cols)); 614 615 ShapeHandle output_shape; 616 if (data_format == "NCHW") { 617 output_shape = 618 c->MakeShape({batch_size_dim, output_depth, output_rows, output_cols}); 619 } else { 620 output_shape = 621 c->MakeShape({batch_size_dim, output_rows, output_cols, output_depth}); 622 } 623 c->set_output(0, output_shape); 624 return Status::OK(); 625 } 626 627 Status AvgPoolShape(shape_inference::InferenceContext* c) { 628 string data_format_str; 629 TensorFormat data_format; 630 Status s = c->GetAttr("data_format", &data_format_str); 631 if (s.ok()) { 632 FormatFromString(data_format_str, &data_format); 633 } else { 634 data_format = FORMAT_NHWC; 635 } 636 637 const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4; 638 ShapeHandle input_shape; 639 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape)); 640 641 TF_RETURN_IF_ERROR( 642 CheckFormatConstraintsOnShape(data_format, input_shape, "input", c)); 643 644 std::vector<int32> strides; 645 TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); 646 if (strides.size() != 4) { 647 return errors::InvalidArgument( 648 "AvgPool requires the stride attribute to contain 4 values, but got: ", 649 strides.size()); 650 } 651 652 std::vector<int32> kernel_sizes; 653 TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes)); 654 if (kernel_sizes.size() != 4) { 655 return errors::InvalidArgument( 656 "AvgPool requires the ksize attribute to contain 4 values, but got: ", 657 kernel_sizes.size()); 658 } 659 660 int32 stride_rows = GetTensorDim(strides, data_format, 'H'); 661 int32 stride_cols = GetTensorDim(strides, data_format, 'W'); 662 int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H'); 663 int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W'); 664 665 constexpr int num_spatial_dims = 2; 666 DimensionHandle batch_size_dim = c->Dim( 667 input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N')); 668 DimensionHandle in_rows_dim = c->Dim( 669 input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H')); 670 DimensionHandle in_cols_dim = c->Dim( 671 input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W')); 672 DimensionHandle depth_dim = c->Dim( 673 input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C')); 674 675 Padding padding; 676 TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); 677 678 // TODO(mrry,shlens): Raise an error if the stride would cause 679 // information in the input to be ignored. This will require a change 680 // in the kernel implementation. 681 682 DimensionHandle output_rows, output_cols; 683 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( 684 c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows)); 685 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( 686 c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols)); 687 688 ShapeHandle output_shape; 689 TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim, 690 {output_rows, output_cols}, depth_dim, 691 &output_shape, c)); 692 c->set_output(0, output_shape); 693 return Status::OK(); 694 } 695 696 Status FusedBatchNormShape(shape_inference::InferenceContext* c) { 697 ShapeHandle x; 698 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &x)); 699 700 bool is_training; 701 TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training)); 702 int number_inputs = (is_training) ? 3 : 5; 703 string data_format; 704 TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format)); 705 DimensionHandle channel_dim = 706 (data_format == "NHWC") ? c->Dim(x, 3) : c->Dim(x, 1); 707 708 // covers scale, offset, and if is_training is false, mean, variance 709 for (int i = 1; i < number_inputs; ++i) { 710 ShapeHandle vec; 711 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec)); 712 TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim)); 713 } 714 715 ShapeHandle y; 716 if (data_format == "NHWC") { 717 TF_RETURN_IF_ERROR(c->ReplaceDim(x, 3, channel_dim, &y)); 718 } else { 719 TF_RETURN_IF_ERROR(c->ReplaceDim(x, 1, channel_dim, &y)); 720 } 721 c->set_output(0, y); 722 ShapeHandle vector_shape = c->Vector(channel_dim); 723 c->set_output(1, vector_shape); 724 c->set_output(2, vector_shape); 725 c->set_output(3, vector_shape); 726 c->set_output(4, vector_shape); 727 return Status::OK(); 728 } 729 730 Status FusedBatchNormGradShape(shape_inference::InferenceContext* c) { 731 ShapeHandle y_backprop; 732 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &y_backprop)); 733 ShapeHandle x; 734 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &x)); 735 736 bool is_training; 737 string data_format; 738 TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training)); 739 TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format)); 740 DimensionHandle channel_dim = 741 (data_format == "NHWC") ? c->Dim(y_backprop, 3) : c->Dim(y_backprop, 1); 742 if (data_format == "NHWC") { 743 TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(x, 3), &channel_dim)); 744 } else { 745 TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(x, 1), &channel_dim)); 746 } 747 748 // covers scale, mean (reserve_space_1), variance (reserve_space_2) 749 for (int i = 2; i < 5; ++i) { 750 ShapeHandle vec; 751 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec)); 752 TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim)); 753 } 754 755 ShapeHandle x_backprop; 756 if (data_format == "NHWC") { 757 TF_RETURN_IF_ERROR(c->ReplaceDim(y_backprop, 3, channel_dim, &x_backprop)); 758 } else { 759 TF_RETURN_IF_ERROR(c->ReplaceDim(y_backprop, 1, channel_dim, &x_backprop)); 760 } 761 c->set_output(0, x_backprop); 762 c->set_output(1, c->Vector(channel_dim)); 763 c->set_output(2, c->Vector(channel_dim)); 764 // Set the correct shapes for reserve_spaces 765 // so that gradients can be performed when 766 // the op is in a symbolic condition. 767 if (is_training) { 768 c->set_output(3, c->Vector(0)); 769 c->set_output(4, c->Vector(0)); 770 } else { 771 c->set_output(3, c->Vector(channel_dim)); 772 c->set_output(4, c->Vector(channel_dim)); 773 } 774 return Status::OK(); 775 } 776 777 Status MaxPoolShape(shape_inference::InferenceContext* c) { 778 string data_format_str; 779 TensorFormat data_format; 780 Status s = c->GetAttr("data_format", &data_format_str); 781 if (s.ok()) { 782 FormatFromString(data_format_str, &data_format); 783 } else { 784 data_format = FORMAT_NHWC; 785 } 786 787 const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4; 788 ShapeHandle input_shape; 789 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape)); 790 791 TF_RETURN_IF_ERROR( 792 CheckFormatConstraintsOnShape(data_format, input_shape, "input", c)); 793 794 std::vector<int32> strides; 795 TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); 796 if (strides.size() != 4) { 797 return errors::InvalidArgument( 798 "MaxPool requires the stride attribute to contain 4 values, but got: ", 799 strides.size()); 800 } 801 802 std::vector<int32> kernel_sizes; 803 TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes)); 804 if (kernel_sizes.size() != 4) { 805 return errors::InvalidArgument( 806 "MaxPool requires the ksize attribute to contain 4 values, but got: ", 807 kernel_sizes.size()); 808 } 809 810 int32 stride_depth = GetTensorDim(strides, data_format, 'C'); 811 int32 stride_rows = GetTensorDim(strides, data_format, 'H'); 812 int32 stride_cols = GetTensorDim(strides, data_format, 'W'); 813 int32 kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C'); 814 int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H'); 815 int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W'); 816 817 constexpr int num_spatial_dims = 2; 818 DimensionHandle batch_size_dim = c->Dim( 819 input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N')); 820 DimensionHandle in_rows_dim = c->Dim( 821 input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H')); 822 DimensionHandle in_cols_dim = c->Dim( 823 input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W')); 824 DimensionHandle in_depth_dim = c->Dim( 825 input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C')); 826 827 Padding padding; 828 TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); 829 830 ShapeHandle output_shape; 831 DimensionHandle output_rows, output_cols, output_depth; 832 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( 833 c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows)); 834 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( 835 c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols)); 836 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( 837 c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth)); 838 839 TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim, 840 {output_rows, output_cols}, 841 output_depth, &output_shape, c)); 842 843 c->set_output(0, output_shape); 844 return Status::OK(); 845 } 846 847 Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) { 848 string data_format_str; 849 TensorFormat data_format; 850 Status s = c->GetAttr("data_format", &data_format_str); 851 if (s.ok()) { 852 FormatFromString(data_format_str, &data_format); 853 } else { 854 data_format = FORMAT_NHWC; 855 } 856 857 const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4; 858 ShapeHandle input_shape; 859 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape)); 860 861 TF_RETURN_IF_ERROR( 862 CheckFormatConstraintsOnShape(data_format, input_shape, "input", c)); 863 864 std::vector<int32> kernel_sizes; 865 std::vector<int32> strides; 866 867 if (c->num_inputs() + 2 == num_inputs) { 868 TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes)); 869 870 TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); 871 } else { 872 // Verify shape of ksize and strides input. 873 ShapeHandle size; 874 DimensionHandle unused; 875 TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 2), 1, &size)); 876 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused)); 877 TF_RETURN_IF_ERROR(c->WithRank(c->input(c->num_inputs() - 1), 1, &size)); 878 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 4, &unused)); 879 880 const Tensor* kernel_sizes_tensor = c->input_tensor(c->num_inputs() - 2); 881 if (kernel_sizes_tensor == nullptr) { 882 c->set_output(0, c->UnknownShape()); 883 return Status::OK(); 884 } 885 kernel_sizes.resize(kernel_sizes_tensor->shape().num_elements()); 886 auto kernel_sizes_vec = kernel_sizes_tensor->flat<int32>(); 887 std::copy_n(&kernel_sizes_vec(0), kernel_sizes.size(), 888 kernel_sizes.begin()); 889 890 const Tensor* strides_tensor = c->input_tensor(c->num_inputs() - 1); 891 if (strides_tensor == nullptr) { 892 c->set_output(0, c->UnknownShape()); 893 return Status::OK(); 894 } 895 strides.resize(strides_tensor->shape().num_elements()); 896 auto strides_vec = strides_tensor->flat<int32>(); 897 std::copy_n(&strides_vec(0), strides.size(), strides.begin()); 898 } 899 900 if (strides.size() != 4) { 901 return errors::InvalidArgument( 902 "MaxPool requires the stride attribute to contain 4 values, but " 903 "got: ", 904 strides.size()); 905 } 906 if (kernel_sizes.size() != 4) { 907 return errors::InvalidArgument( 908 "MaxPool requires the ksize attribute to contain 4 values, but got: ", 909 kernel_sizes.size()); 910 } 911 912 int32 stride_depth = GetTensorDim(strides, data_format, 'C'); 913 int32 stride_rows = GetTensorDim(strides, data_format, 'H'); 914 int32 stride_cols = GetTensorDim(strides, data_format, 'W'); 915 int32 kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C'); 916 int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H'); 917 int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W'); 918 919 constexpr int num_spatial_dims = 2; 920 DimensionHandle batch_size_dim = c->Dim( 921 input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N')); 922 DimensionHandle in_rows_dim = c->Dim( 923 input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H')); 924 DimensionHandle in_cols_dim = c->Dim( 925 input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W')); 926 DimensionHandle in_depth_dim = c->Dim( 927 input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C')); 928 929 Padding padding; 930 TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); 931 932 ShapeHandle output_shape; 933 DimensionHandle output_rows, output_cols, output_depth; 934 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( 935 c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows)); 936 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( 937 c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols)); 938 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( 939 c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth)); 940 941 TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim, 942 {output_rows, output_cols}, 943 output_depth, &output_shape, c)); 944 945 c->set_output(0, output_shape); 946 return Status::OK(); 947 } 948 949 Status Pool3DShape(shape_inference::InferenceContext* c) { 950 ShapeHandle input_shape; 951 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape)); 952 953 string data_format; 954 Status s = c->GetAttr("data_format", &data_format); 955 956 std::vector<int32> strides; 957 TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); 958 if (strides.size() != 5) { 959 return errors::InvalidArgument( 960 "Pool3D ops require the stride attribute to contain 5 values, but " 961 "got: ", 962 strides.size()); 963 } 964 965 std::vector<int32> kernel_sizes; 966 TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes)); 967 if (kernel_sizes.size() != 5) { 968 return errors::InvalidArgument( 969 "Pool3D requires the ksize attribute to contain 5 values, but got: ", 970 kernel_sizes.size()); 971 } 972 973 int32 stride_planes, stride_rows, stride_cols; 974 int32 kernel_planes, kernel_rows, kernel_cols; 975 976 if (s.ok() && data_format == "NCDHW") { 977 // Convert input_shape to NDHWC. 978 auto dim = [&](char dimension) { 979 return c->Dim(input_shape, GetTensorDimIndex<3>(FORMAT_NCHW, dimension)); 980 }; 981 input_shape = 982 c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}}); 983 stride_planes = strides[2]; 984 stride_rows = strides[3]; 985 stride_cols = strides[4]; 986 kernel_planes = kernel_sizes[2]; 987 kernel_rows = kernel_sizes[3]; 988 kernel_cols = kernel_sizes[4]; 989 } else { 990 stride_planes = strides[1]; 991 stride_rows = strides[2]; 992 stride_cols = strides[3]; 993 kernel_planes = kernel_sizes[1]; 994 kernel_rows = kernel_sizes[2]; 995 kernel_cols = kernel_sizes[3]; 996 } 997 998 DimensionHandle batch_size_dim = c->Dim(input_shape, 0); 999 DimensionHandle in_planes_dim = c->Dim(input_shape, 1); 1000 DimensionHandle in_rows_dim = c->Dim(input_shape, 2); 1001 DimensionHandle in_cols_dim = c->Dim(input_shape, 3); 1002 DimensionHandle output_depth_dim = c->Dim(input_shape, 4); 1003 1004 Padding padding; 1005 TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); 1006 1007 // TODO(mrry,shlens): Raise an error if the stride would cause 1008 // information in the input to be ignored. This will require a change 1009 // in the kernel implementation. 1010 DimensionHandle output_planes, output_rows, output_cols; 1011 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( 1012 c, in_planes_dim, kernel_planes, stride_planes, padding, &output_planes)); 1013 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( 1014 c, in_rows_dim, kernel_rows, stride_rows, padding, &output_rows)); 1015 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( 1016 c, in_cols_dim, kernel_cols, stride_cols, padding, &output_cols)); 1017 1018 ShapeHandle output_shape; 1019 if (data_format == "NCDHW") { 1020 output_shape = c->MakeShape({batch_size_dim, output_depth_dim, 1021 output_planes, output_rows, output_cols}); 1022 } else { 1023 output_shape = c->MakeShape({batch_size_dim, output_planes, output_rows, 1024 output_cols, output_depth_dim}); 1025 } 1026 1027 c->set_output(0, output_shape); 1028 return Status::OK(); 1029 } 1030 1031 Status UnknownShape(shape_inference::InferenceContext* c) { 1032 for (int i = 0; i < c->num_outputs(); ++i) { 1033 c->set_output(i, c->UnknownShape()); 1034 } 1035 return Status::OK(); 1036 } 1037 1038 template <typename T> 1039 Status ReductionShapeHelper(const Tensor* reduction_indices_t, 1040 const int32 input_rank, 1041 std::set<int64>& true_indices) { 1042 auto reduction_indices = reduction_indices_t->flat<T>(); 1043 for (int i = 0; i < reduction_indices_t->NumElements(); ++i) { 1044 const T reduction_index = reduction_indices(i); 1045 if (reduction_index < -input_rank || reduction_index >= input_rank) { 1046 return errors::InvalidArgument("Invalid reduction dimension ", 1047 reduction_index, " for input with ", 1048 input_rank, " dimensions."); 1049 } 1050 1051 auto wrapped_index = reduction_index; 1052 if (wrapped_index < 0) { 1053 wrapped_index += input_rank; 1054 } 1055 1056 true_indices.insert(wrapped_index); 1057 } 1058 return Status::OK(); 1059 } 1060 1061 Status ReductionShape(InferenceContext* c) { 1062 ShapeHandle input = c->input(0); 1063 1064 ShapeHandle indices; 1065 // Older versions of TensorFlow accidentally allowed higher rank tensors like 1066 // [[1,2]] or [[1],[2]] to represent axis=[1,2]. 1067 if (c->graph_def_version() < 21) { 1068 indices = c->input(1); 1069 } else { 1070 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &indices)); 1071 } 1072 1073 bool keep_dims; 1074 TF_RETURN_IF_ERROR(c->GetAttr("keep_dims", &keep_dims)); 1075 1076 const Tensor* reduction_indices_t = c->input_tensor(1); 1077 if (reduction_indices_t == nullptr || !c->RankKnown(input)) { 1078 // If we do not have the reduction values at runtime, or the 1079 // rank of the input, we don't know the output shape. 1080 1081 if (keep_dims && c->RankKnown(input)) { 1082 // output rank matches input input if <keep_dims>. 1083 c->set_output(0, c->UnknownShapeOfRank(c->Rank(input))); 1084 return Status::OK(); 1085 } else { 1086 return shape_inference::UnknownShape(c); 1087 } 1088 } 1089 1090 const int32 input_rank = c->Rank(input); 1091 std::set<int64> true_indices; 1092 if (reduction_indices_t->dtype() == DataType::DT_INT32) { 1093 TF_RETURN_IF_ERROR(ReductionShapeHelper<int32>(reduction_indices_t, 1094 input_rank, true_indices)); 1095 } else if (reduction_indices_t->dtype() == DataType::DT_INT64) { 1096 TF_RETURN_IF_ERROR(ReductionShapeHelper<int64>(reduction_indices_t, 1097 input_rank, true_indices)); 1098 } else { 1099 return errors::InvalidArgument( 1100 "reduction_indices can only be int32 or int64"); 1101 } 1102 1103 std::vector<DimensionHandle> dims; 1104 for (int i = 0; i < input_rank; ++i) { 1105 if (true_indices.count(i) > 0) { 1106 if (keep_dims) { 1107 dims.emplace_back(c->MakeDim(1)); 1108 } 1109 } else { 1110 dims.emplace_back(c->Dim(input, i)); 1111 } 1112 } 1113 1114 c->set_output(0, c->MakeShape(dims)); 1115 return Status::OK(); 1116 } 1117 1118 Status ConcatShapeHelper(InferenceContext* c, int start_value_index, 1119 int end_value_index, int dim_index) { 1120 ShapeHandle unused; 1121 TF_RETURN_IF_ERROR(c->WithRank(c->input(dim_index), 0, &unused)); 1122 const Tensor* concat_dim_t = c->input_tensor(dim_index); 1123 if (concat_dim_t == nullptr) { 1124 // Return an unknown shape with same rank as inputs, or an unknown rank 1125 // if no input's rank is known. 1126 1127 // Find rank. 1128 int32 rank = InferenceContext::kUnknownRank; 1129 for (int i = start_value_index; i < end_value_index; ++i) { 1130 if (rank == InferenceContext::kUnknownRank) rank = c->Rank(c->input(i)); 1131 if (rank != InferenceContext::kUnknownRank) { 1132 break; 1133 } 1134 } 1135 if (rank == InferenceContext::kUnknownRank) { 1136 c->set_output(0, c->UnknownShape()); 1137 return Status::OK(); 1138 } else if (rank == 0) { 1139 return errors::InvalidArgument( 1140 "Can't concatenate scalars (use tf.stack instead)"); 1141 } else { 1142 for (int i = start_value_index; i < end_value_index; ++i) { 1143 // Check that all the inputs are of the correct rank. 1144 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), rank, &unused)); 1145 } 1146 } 1147 // Build result of <rank> different unknown dims. 1148 std::vector<DimensionHandle> dims; 1149 dims.reserve(rank); 1150 for (int i = 0; i < rank; ++i) dims.push_back(c->UnknownDim()); 1151 c->set_output(0, c->MakeShape(dims)); 1152 return Status::OK(); 1153 } 1154 1155 // Merge all the non-concat dims, and sum the concat dim to make an output 1156 // shape. 1157 const int32 concat_dim = concat_dim_t->scalar<int32>()(); 1158 1159 // Minimum required number of dimensions. 1160 const int min_rank = concat_dim < 0 ? -concat_dim : concat_dim + 1; 1161 1162 ShapeHandle output_before; 1163 ShapeHandle output_after; 1164 1165 ShapeHandle input = c->input(end_value_index - 1); 1166 TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input)); 1167 TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &output_before)); 1168 DimensionHandle output_middle = c->Dim(input, concat_dim); 1169 if (concat_dim == -1) { 1170 output_after = c->Scalar(); // no dimensions. 1171 } else { 1172 TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &output_after)); 1173 } 1174 1175 for (int i = end_value_index - 2; i >= start_value_index; --i) { 1176 ShapeHandle before; 1177 ShapeHandle after; 1178 input = c->input(i); 1179 TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, min_rank, &input)); 1180 TF_RETURN_IF_ERROR(c->Subshape(input, 0, concat_dim, &before)); 1181 DimensionHandle middle = c->Dim(input, concat_dim); 1182 if (concat_dim == -1) { 1183 after = c->Scalar(); 1184 } else { 1185 TF_RETURN_IF_ERROR(c->Subshape(input, concat_dim + 1, &after)); 1186 } 1187 1188 TF_RETURN_IF_ERROR(c->Merge(before, output_before, &output_before)); 1189 TF_RETURN_IF_ERROR(c->Add(output_middle, middle, &output_middle)); 1190 TF_RETURN_IF_ERROR(c->Merge(after, output_after, &output_after)); 1191 } 1192 1193 ShapeHandle s; 1194 TF_RETURN_IF_ERROR( 1195 c->Concatenate(output_before, c->Vector(output_middle), &s)); 1196 TF_RETURN_IF_ERROR(c->Concatenate(s, output_after, &s)); 1197 c->set_output(0, s); 1198 return Status::OK(); 1199 } 1200 1201 Status ConcatShape(InferenceContext* c, int num_inputs_to_concat) { 1202 return ConcatShapeHelper(c, 1 /* start_value_index */, 1203 1 + num_inputs_to_concat /* end_value_index */, 1204 0 /* dim_index */); 1205 } 1206 1207 Status ConcatV2Shape(InferenceContext* c) { 1208 return ConcatShapeHelper(c, 0 /* start_value_index */, 1209 c->num_inputs() - 1 /* end_value_index */, 1210 c->num_inputs() - 1 /* dim_index */); 1211 } 1212 1213 Status BroadcastBinaryOpShapeFn(InferenceContext* c) { 1214 ShapeHandle shape_x = c->input(0); 1215 ShapeHandle shape_y = c->input(1); 1216 if (!c->RankKnown(shape_x) || !c->RankKnown(shape_y)) { 1217 c->set_output(0, c->UnknownShape()); 1218 return Status::OK(); 1219 } 1220 const int32 rank_x = c->Rank(shape_x); 1221 const int32 rank_y = c->Rank(shape_y); 1222 const int32 rank_out = std::max(rank_x, rank_y); 1223 1224 // To compute the broadcast dimensions, we zip together shape_x and shape_y 1225 // and 1226 // pad with 1 to make them the same length. 1227 std::vector<DimensionHandle> dims; 1228 DimensionHandle dim_one; 1229 if (rank_x != rank_y) dim_one = c->MakeDim(1); 1230 for (int i = 0; i < rank_out; ++i) { 1231 const auto dim_x = i < (rank_out - rank_x) 1232 ? dim_one 1233 : c->Dim(shape_x, i - (rank_out - rank_x)); 1234 const bool dim_y_is_one = (i < (rank_out - rank_y)); 1235 const auto dim_y = 1236 dim_y_is_one ? dim_one : c->Dim(shape_y, i - (rank_out - rank_y)); 1237 if (!c->ValueKnown(dim_x) || !c->ValueKnown(dim_y)) { 1238 // One or both dimensions is unknown. 1239 // 1240 // - If either dimension is greater than 1, we assume that the program is 1241 // correct, and the other dimension will be broadcast to match it. 1242 // TODO(cwhipkey): For shape inference, if we eliminate the shape checks 1243 // in C++ op code, we must still assert that the unknown dim is either 1 1244 // or the same as the known dim. 1245 // - If either dimension is 1, the other dimension is the output. 1246 if (c->Value(dim_x) > 1) { 1247 dims.push_back(dim_x); 1248 } else if (c->Value(dim_y) > 1) { 1249 dims.push_back(dim_y); 1250 } else if (c->Value(dim_x) == 1) { 1251 dims.push_back(dim_y); 1252 } else if (c->Value(dim_y) == 1) { 1253 dims.push_back(dim_x); 1254 } else if (dim_y.SameHandle(dim_x)) { 1255 dims.push_back(dim_x); 1256 } else { 1257 dims.push_back(c->UnknownDim()); 1258 } 1259 } else if (c->Value(dim_x) == 1 || c->Value(dim_y) == 1) { 1260 if (c->Value(dim_x) == 1 && !dim_y_is_one) { 1261 // We will broadcast dim_x to dim_y. 1262 dims.push_back(dim_y); 1263 } else { 1264 DCHECK_EQ(c->Value(dim_y), 1); 1265 // We will broadcast dim_y to dim_x. 1266 dims.push_back(dim_x); 1267 } 1268 } else { 1269 DimensionHandle dim; 1270 TF_RETURN_IF_ERROR(c->Merge(dim_x, dim_y, &dim)); 1271 dims.push_back(dim); 1272 } 1273 } 1274 1275 c->set_output(0, c->MakeShape(dims)); 1276 return Status::OK(); 1277 } 1278 1279 Status RandomShape(shape_inference::InferenceContext* c) { 1280 shape_inference::ShapeHandle out; 1281 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); 1282 c->set_output(0, out); 1283 return Status::OK(); 1284 } 1285 1286 Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape, 1287 ShapeHandle values_shape, ShapeHandle shape_shape) { 1288 // Validate ranks. 1289 ShapeHandle unused_shape; 1290 TF_RETURN_IF_ERROR(c->WithRank(indices_shape, 2, &unused_shape)); 1291 TF_RETURN_IF_ERROR(c->WithRank(values_shape, 1, &unused_shape)); 1292 TF_RETURN_IF_ERROR(c->WithRank(shape_shape, 1, &unused_shape)); 1293 1294 // Number of elements in indices and values must match. 1295 DimensionHandle num_index_elements_dim = c->Dim(indices_shape, 0); 1296 if (c->ValueKnown(num_index_elements_dim)) { 1297 DimensionHandle num_values_elements_dim = c->Dim(values_shape, 0); 1298 if (c->ValueKnown(num_values_elements_dim)) { 1299 int64 num_index_elements = c->Value(num_index_elements_dim); 1300 int64 num_values_elements = c->Value(num_values_elements_dim); 1301 if (num_index_elements != num_values_elements) { 1302 return errors::InvalidArgument("Number of elements in index (", 1303 num_index_elements, ") and values (", 1304 num_values_elements, ") do not match."); 1305 } 1306 } 1307 } 1308 1309 // Rank embedded in indices must match shape. 1310 DimensionHandle index_rank_dim = c->Dim(indices_shape, 1); 1311 if (c->ValueKnown(index_rank_dim)) { 1312 DimensionHandle shape_rank_dim = c->Dim(shape_shape, 0); 1313 if (c->ValueKnown(shape_rank_dim)) { 1314 int64 index_rank = c->Value(index_rank_dim); 1315 int32 shape_rank = c->Value(shape_rank_dim); 1316 if (index_rank != shape_rank) { 1317 return errors::InvalidArgument("Index rank (", index_rank, 1318 ") and shape rank (", shape_rank, 1319 ") do not match."); 1320 } 1321 } 1322 } 1323 1324 return Status::OK(); 1325 } 1326 1327 Status ScatterNdUpdateShape(InferenceContext* c) { 1328 ShapeHandle input_shape = c->input(0); 1329 if (c->input_handle_shapes_and_types(0) != nullptr) { 1330 input_shape = (*c->input_handle_shapes_and_types(0))[0].shape; 1331 } 1332 ShapeHandle indices_shape; 1333 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape)); 1334 ShapeHandle updates_shape; 1335 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape)); 1336 1337 if (c->Value(c->NumElements(input_shape)) == 0 && 1338 (c->Value(c->NumElements(indices_shape)) > 0 || 1339 c->Value(c->NumElements(updates_shape)) > 0)) { 1340 return errors::InvalidArgument( 1341 "Indices and updates specified for empty output shape"); 1342 } 1343 1344 if (c->RankKnown(indices_shape) && c->RankKnown(updates_shape)) { 1345 const int64 num_outer_dims = c->Rank(indices_shape) - 1; 1346 const DimensionHandle index_size = c->Dim(indices_shape, -1); 1347 1348 // We can only do more validation if the last dimension of indices 1349 // is a known value. 1350 if (c->ValueKnown(index_size)) { 1351 const int64 ix = c->Value(index_size); 1352 ShapeHandle unused; 1353 ShapeHandle prefix_indices; 1354 TF_RETURN_IF_ERROR( 1355 c->Subshape(indices_shape, 0, num_outer_dims, &prefix_indices)); 1356 ShapeHandle prefix_updates; 1357 TF_RETURN_IF_ERROR( 1358 c->Subshape(updates_shape, 0, num_outer_dims, &prefix_updates)); 1359 1360 Status s = c->Merge(prefix_indices, prefix_updates, &unused); 1361 if (!s.ok()) { 1362 return errors::InvalidArgument( 1363 "The outer ", num_outer_dims, 1364 " dimensions of indices.shape=", c->DebugString(indices_shape), 1365 " must match the outer ", num_outer_dims, 1366 " dimensions of updates.shape=", c->DebugString(updates_shape), 1367 ": ", s.error_message()); 1368 } 1369 1370 ShapeHandle input_suffix; 1371 TF_RETURN_IF_ERROR(c->Subshape(input_shape, ix, &input_suffix)); 1372 ShapeHandle suffix_updates; 1373 TF_RETURN_IF_ERROR( 1374 c->Subshape(updates_shape, num_outer_dims, &suffix_updates)); 1375 s = c->Merge(input_suffix, suffix_updates, &unused); 1376 if (!s.ok()) { 1377 return errors::InvalidArgument( 1378 "The inner ", c->Rank(input_shape) - ix, 1379 " dimensions of input.shape=", c->DebugString(input_shape), 1380 " must match the inner ", c->Rank(updates_shape) - num_outer_dims, 1381 " dimensions of updates.shape=", c->DebugString(updates_shape), 1382 ": ", s.error_message()); 1383 } 1384 } 1385 } 1386 1387 if (c->input_handle_shapes_and_types(0) == nullptr) { 1388 c->set_output(0, input_shape); 1389 } 1390 return Status::OK(); 1391 } 1392 1393 Status ExplicitShape(InferenceContext* c) { 1394 PartialTensorShape shape; 1395 TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); 1396 ShapeHandle output_shape; 1397 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &output_shape)); 1398 c->set_output(0, output_shape); 1399 return Status::OK(); 1400 } 1401 1402 } // namespace shape_inference 1403 1404 } // namespace tensorflow 1405