1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/tf2xla/lib/util.h" 17 #include "tensorflow/compiler/tf2xla/type_util.h" 18 #include "tensorflow/compiler/tf2xla/xla_helpers.h" 19 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 20 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 21 #include "tensorflow/compiler/xla/util.h" 22 #include "tensorflow/core/framework/op_kernel.h" 23 24 namespace tensorflow { 25 namespace { 26 27 // Create a diagonal / batch diagonal matrix with 'input' on the diagonal. 28 xla::StatusOr<xla::ComputationDataHandle> CreateDiagonal( 29 const xla::ComputationDataHandle& input, int64 last_dim_size, 30 tensorflow::gtl::ArraySlice<int64> other_dims, XlaOpKernelContext* ctx, 31 xla::ComputationBuilder* builder) { 32 // Create two matrices that have the following forms, and compare them: 33 // 34 // [[0, 0, 0, 0] [[0, 1, 2, 3] 35 // [1, 1, 1, 1] [0, 1, 2, 3] 36 // [2, 2, 2, 2] [0, 1, 2, 3] 37 // [3, 3, 3, 3]] [0, 1, 2, 3]] 38 // 39 // This produces a predicate matrix of the right size, with "true" on the 40 // diagonal. 41 xla::ComputationDataHandle iota; 42 TF_RETURN_IF_ERROR( 43 XlaHelpers::Iota(builder, DataType::DT_INT32, last_dim_size, &iota)); 44 xla::ComputationDataHandle iota_broadcast = 45 builder->Broadcast(iota, {last_dim_size}); 46 xla::ComputationDataHandle mask = builder->Eq(iota_broadcast, iota, {0}); 47 48 // If this is a batched diagonal, broadcast the mask across the other 49 // dimensions. 50 if (!other_dims.empty()) { 51 mask = builder->Broadcast(mask, other_dims); 52 } 53 54 // Broadcast the input, and then use the mask computed above to select the 55 // diagonal: 56 // e.g, in 2D: 57 // [[t, f, f] [[1, 1, 1] [[0, 0, 0] [[1, 0, 0] 58 // select( [f, t, f] , [4, 4, 4] , [0, 0, 0] ) = [0, 4, 0] 59 // [f, f, t]] [9, 9, 9]] [0, 0, 0]] [0, 0, 9]] 60 // 61 // Broadcasting the input is less-than-trivial, since we need to broadcast 62 // into a "middle" dimension. We can do this with a reshape + implicit 63 // broadcast. 64 // TODO(b/30112114): Replace with in-dim broadcast when those are supported. 65 std::vector<int64> broadcast_dims(other_dims.begin(), other_dims.end()); 66 broadcast_dims.push_back(1LL); 67 broadcast_dims.push_back(last_dim_size); 68 xla::ComputationDataHandle input_broadcast = 69 builder->Reshape(input, broadcast_dims); 70 71 broadcast_dims[broadcast_dims.size() - 2] = last_dim_size; 72 xla::PrimitiveType element_type; 73 TF_RETURN_IF_ERROR( 74 DataTypeToPrimitiveType(ctx->input_type(0), &element_type)); 75 auto broadcast_shape = 76 xla::ShapeUtil::MakeShape(element_type, broadcast_dims); 77 xla::ComputationDataHandle zeros = Zeros(builder, broadcast_shape); 78 79 input_broadcast = builder->Add(input_broadcast, zeros); 80 return builder->Select(mask, input_broadcast, zeros); 81 } 82 83 class DiagOp : public XlaOpKernel { 84 public: 85 explicit DiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 86 87 void Compile(XlaOpKernelContext* ctx) override { 88 xla::ComputationBuilder* builder = ctx->builder(); 89 90 OP_REQUIRES(ctx, ctx->num_inputs() >= 1, 91 errors::InvalidArgument("Diag op must have at an input")); 92 const TensorShape input_shape = ctx->InputShape(0); 93 94 auto dims = input_shape.dim_sizes(); 95 OP_REQUIRES(ctx, !dims.empty(), 96 errors::InvalidArgument("Expected 1 <= dims, got shape ", 97 input_shape.DebugString())); 98 99 xla::ComputationDataHandle input = ctx->Input(0); 100 101 // Picture: 102 // tf.diag([1, 2, 3, 4]) ==> [[1, 0, 0, 0] 103 // [0, 2, 0, 0] 104 // [0, 0, 3, 0] 105 // [0, 0, 0, 4]] 106 107 // Flattens the input to 1D. 108 int64 size = input_shape.num_elements(); 109 input = builder->Reshape(input, {size}); 110 111 // Create an R2 with the R1 diagonal. 112 auto diag_or_status = 113 CreateDiagonal(input, size, /*other_dims=*/{}, ctx, builder); 114 OP_REQUIRES_OK(ctx, diag_or_status.status()); 115 xla::ComputationDataHandle diag = diag_or_status.ValueOrDie(); 116 117 // Reshapes to the final shape. 118 std::vector<int64> new_dims(dims.size() * 2); 119 std::copy(dims.begin(), dims.end(), new_dims.begin()); 120 std::copy(dims.begin(), dims.end(), new_dims.begin() + dims.size()); 121 diag = builder->Reshape(diag, new_dims); 122 123 ctx->SetOutput(0, diag); 124 } 125 }; 126 127 REGISTER_XLA_OP(Name("Diag"), DiagOp); 128 129 class DiagPartOp : public XlaOpKernel { 130 public: 131 explicit DiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 132 133 void Compile(XlaOpKernelContext* ctx) override { 134 xla::ComputationBuilder* builder = ctx->builder(); 135 136 const TensorShape input_shape = ctx->InputShape(0); 137 auto dims = input_shape.dim_sizes(); 138 139 int num_dims = dims.size(); 140 const int out_dims = num_dims / 2; 141 142 OP_REQUIRES(ctx, 2 <= num_dims, 143 errors::InvalidArgument("Expected 2 <= dims, got shape ", 144 input_shape.DebugString())); 145 OP_REQUIRES(ctx, num_dims % 2 == 0, 146 errors::InvalidArgument("The input tensor must have even rank; " 147 "got shape ", 148 input_shape.DebugString())); 149 int64 new_size = 1; 150 std::vector<int64> new_dims; 151 for (int i = 0; i < out_dims; i++) { 152 OP_REQUIRES( 153 ctx, dims[i] == dims[i + out_dims], 154 errors::InvalidArgument("Invalid shape ", input_shape.DebugString(), 155 ": dimensions ", i, " and ", i + out_dims, 156 " do not match.")); 157 new_size *= dims[i]; 158 new_dims.push_back(dims[i]); 159 } 160 161 xla::ComputationDataHandle diag = ctx->Input(0); 162 163 // TODO(b/30878775): use Slice with strides when supported, in place of 164 // the Pad -> Reshape -> Slice. 165 166 // Picture: 167 // [[1, 0, 0, 0] pad and reshape to [[1, 0, 0, 0, 0], 168 // [0, 2, 0, 0] =================> [2, 0, 0, 0, 0], 169 // [0, 0, 3, 0] [3, 0, 0, 0, 0], 170 // [0, 0, 0, 4]] [4, 0, 0, 0, 0]] 171 // and then slice out the first column. 172 173 // Flattens the input to 1D. 174 int64 size = input_shape.num_elements(); 175 diag = builder->Reshape(diag, {size}); 176 177 // Adds padding after the last element of 'new_size'. 178 xla::PaddingConfig config; 179 auto* dim = config.add_dimensions(); 180 dim->set_edge_padding_high(new_size); 181 auto zero = XlaHelpers::Zero(builder, input_type(0)); 182 diag = builder->Pad(diag, zero, config); 183 184 // Reshapes so the diagonal is now in the first column. 185 diag = builder->Reshape(diag, {new_size, new_size + 1}); 186 187 // Slices out the first column and reshapes to the final shape. 188 diag = builder->Slice(diag, {0, 0}, {new_size, 1}, {1, 1}); 189 diag = builder->Reshape(diag, new_dims); 190 191 ctx->SetOutput(0, diag); 192 } 193 }; 194 195 REGISTER_XLA_OP(Name("DiagPart"), DiagPartOp); 196 197 class MatrixDiagOp : public XlaOpKernel { 198 public: 199 explicit MatrixDiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 200 201 void Compile(XlaOpKernelContext* ctx) override { 202 xla::ComputationBuilder* builder = ctx->builder(); 203 204 OP_REQUIRES(ctx, ctx->num_inputs() >= 1, 205 errors::InvalidArgument("MatrixDiag op must have at an input")); 206 const TensorShape input_shape = ctx->InputShape(0); 207 208 auto dims = input_shape.dim_sizes(); 209 OP_REQUIRES(ctx, !dims.empty(), 210 errors::InvalidArgument("Expected 1 <= dims, got shape ", 211 input_shape.DebugString())); 212 213 xla::ComputationDataHandle diag = ctx->Input(0); 214 215 int last_dim = dims.size() - 1; 216 int64 last_dim_size = input_shape.dim_size(last_dim); 217 tensorflow::gtl::ArraySlice<int64> other_dims(dims); 218 other_dims.pop_back(); 219 220 auto diag_or_status = 221 CreateDiagonal(diag, last_dim_size, other_dims, ctx, builder); 222 OP_REQUIRES_OK(ctx, diag_or_status.status()); 223 diag = diag_or_status.ValueOrDie(); 224 ctx->SetOutput(0, diag); 225 } 226 }; 227 228 REGISTER_XLA_OP(Name("MatrixDiag"), MatrixDiagOp); 229 230 class MatrixDiagPartOp : public XlaOpKernel { 231 public: 232 explicit MatrixDiagPartOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 233 234 void Compile(XlaOpKernelContext* ctx) override { 235 xla::ComputationBuilder* builder = ctx->builder(); 236 237 const TensorShape input_shape = ctx->InputShape(0); 238 auto dims = input_shape.dim_sizes(); 239 240 OP_REQUIRES(ctx, 2 <= dims.size(), 241 errors::InvalidArgument("Expected 2 <= dims, got shape ", 242 input_shape.DebugString())); 243 244 xla::ComputationDataHandle diag = ctx->Input(0); 245 246 int last_dim = dims.size() - 1; 247 int64 last_dim_size = dims[last_dim]; 248 249 // The smaller of the last two dimension sizes. 250 int64 smaller_dim_size = std::min(dims[last_dim - 1], dims[last_dim]); 251 252 // TODO(b/30878775): use Slice with strides when supported, in place of 253 // the Pad -> Reshape -> Slice. 254 255 // Picture: for each 2D matrix in the tensor's last two dimensions: 256 // [[1, 0, 0, 0] pad and reshape to [[1, 0, 0, 0, 0], 257 // [0, 2, 0, 0] =================> [2, 0, 0, 0, 0], 258 // [0, 0, 3, 0]] [3, 0, 0, 0, 0], 259 // and then slice out the first column. 260 // 261 // Another example, with tall and narrow input. 262 // [[1, 0] pad and reshape to [[1, 0, 0], 263 // [0, 2] =================> [2, 0, 0]] 264 // [0, 0] 265 // [0, 0]] 266 267 // Collapses the last two dimensions. 268 std::vector<int64> flattened_dims(dims.begin(), dims.end() - 1); 269 flattened_dims.back() *= dims.back(); 270 diag = builder->Reshape(diag, flattened_dims); 271 272 // Slices or pads the last dimension to 'target_size'. 273 int64 actual_size = flattened_dims.back(); 274 int64 target_size = smaller_dim_size * (last_dim_size + 1); 275 if (actual_size < target_size) { 276 xla::PaddingConfig config = 277 xla::MakeNoPaddingConfig(flattened_dims.size()); 278 auto* dim = config.mutable_dimensions(flattened_dims.size() - 1); 279 dim->set_edge_padding_high(target_size - actual_size); 280 auto zero = XlaHelpers::Zero(builder, input_type(0)); 281 diag = builder->Pad(diag, zero, config); 282 } else if (actual_size > target_size) { 283 std::vector<int64> start(flattened_dims.size(), 0); 284 std::vector<int64> limits(flattened_dims.begin(), flattened_dims.end()); 285 std::vector<int64> strides(flattened_dims.size(), 1); 286 limits[flattened_dims.size() - 1] = target_size; 287 diag = builder->Slice(diag, start, limits, strides); 288 } 289 290 // Reshape so the target values are in the first position of the last 291 // dimension. 292 std::vector<int64> unflattened_dims(dims.begin(), dims.end()); 293 dims[last_dim - 1] = smaller_dim_size; 294 dims[last_dim] = last_dim_size + 1; 295 diag = builder->Reshape(diag, dims); 296 297 // Slices out the first column and reshapes to the final shape. 298 std::vector<int64> start(dims.size(), 0); 299 std::vector<int64> limits(dims.begin(), dims.end()); 300 std::vector<int64> strides(dims.size(), 1); 301 limits[last_dim] = 1; 302 diag = builder->Slice(diag, start, limits, strides); 303 304 // Collapses away the last dimension. 305 dims.pop_back(); 306 diag = builder->Reshape(diag, dims); 307 308 ctx->SetOutput(0, diag); 309 } 310 }; 311 312 REGISTER_XLA_OP(Name("MatrixDiagPart"), MatrixDiagPartOp); 313 314 } // namespace 315 } // namespace tensorflow 316