/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/local_computation_builder.h"
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/thread_annotations.h"

namespace xla {

namespace swig {

// TODO(b/34473877) Ideally XLA would support AllReduce among arbitrary sets of
// device handles instead of needing to set the number of replicas at XLA
// service initialization time.
tensorflow::mutex g_local_client_mutex(tensorflow::LINKER_INITIALIZED);
int g_replica_count GUARDED_BY(g_local_client_mutex) = 1;
LocalClient* g_local_client GUARDED_BY(g_local_client_mutex) = nullptr;

Status InitializeReplicaCount(int replica_count) {
  if (replica_count < 1) {
    return InvalidArgument("Replica count must be >= 1; got %d.",
                           replica_count);
  }
  tensorflow::mutex_lock lock(g_local_client_mutex);
  if (g_local_client != nullptr) {
    return FailedPrecondition(
        "Attempted to set the replica count to %d, but a local XLA service "
        "was previously created with a replica count of %d.",
        replica_count, g_replica_count);
  }
  g_replica_count = replica_count;
  return Status::OK();
}

int GetReplicaCount() {
  tensorflow::mutex_lock lock(g_local_client_mutex);
  return g_replica_count;
}

LocalClient* GetOrCreateLocalClient() {
  tensorflow::mutex_lock lock(g_local_client_mutex);
  if (g_local_client != nullptr) {
    return g_local_client;
  }
  LocalClientOptions options;
  options.set_number_of_replicas(g_replica_count);
  g_local_client = ClientLibrary::GetOrCreateLocalClient(options).ValueOrDie();
  CHECK(g_local_client != nullptr);
  return g_local_client;
}

Status TransferToInfeedLocal(const Literal& literal) {
  VLOG(1) << "Infeeding literal without replica number; shape: "
          << literal.shape();
  LocalClient* client = GetOrCreateLocalClient();
  return client->TransferToInfeedLocal(literal, /*device_ordinal=*/0);
}

Status TransferToInfeedLocalReplica(const Literal& literal,
                                    int replica_number) {
  VLOG(1) << "Infeeding shape " << literal.shape()
          << " to replica number: " << replica_number;
  LocalClient* client = GetOrCreateLocalClient();
  TF_ASSIGN_OR_RETURN(int device_ordinal,
                      client->ReplicaNumberToDeviceOrdinal(replica_number));
  return client->TransferToInfeedLocal(literal, device_ordinal);
}

StatusOr<std::unique_ptr<Literal>> TransferFromOutfeedLocalReplica(
    const Shape& shape, int replica_number) {
  VLOG(1) << "Outfeeding literal from replica number: " << replica_number
          << " shape: " << shape;
  LocalClient* client = GetOrCreateLocalClient();
  TF_ASSIGN_OR_RETURN(int device_ordinal,
                      client->ReplicaNumberToDeviceOrdinal(replica_number));
  return client->TransferFromOutfeedLocal(shape, device_ordinal);
}

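// LocalShapedBuffer wraps a ScopedShapedBuffer, which owns the underlying
// device memory and releases it when destroyed. The wrapper gives the SWIG
// bindings a single concrete handle type for device-resident data.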
LocalShapedBuffer::LocalShapedBuffer(
    std::unique_ptr<ScopedShapedBuffer> shaped_buffer)
    : shaped_buffer_(std::move(shaped_buffer)) {}

const std::unique_ptr<ScopedShapedBuffer>& LocalShapedBuffer::shaped_buffer()
    const {
  return shaped_buffer_;
}

static StatusOr<std::unique_ptr<ScopedShapedBuffer>> ToBuffer(
    LocalClient* client, int device_ordinal, const Literal& arg) {
  return client->LiteralToShapedBuffer(arg, device_ordinal,
                                       client->backend().memory_allocator());
}

/* static */
LocalShapedBuffer* LocalShapedBuffer::FromLiteral(
    const Literal& argument,
    const tensorflow::gtl::optional<Shape>& shape_with_layout) {
  LocalClient* client = GetOrCreateLocalClient();
  std::unique_ptr<ScopedShapedBuffer> buf;
  if (shape_with_layout) {
    std::unique_ptr<Literal> relaid =
        argument.Relayout(shape_with_layout.value());
    buf = ToBuffer(client, /*device_ordinal=*/0, *relaid).ConsumeValueOrDie();
  } else {
    buf = ToBuffer(client, /*device_ordinal=*/0, argument).ConsumeValueOrDie();
  }
  return new LocalShapedBuffer(std::move(buf));
}

std::unique_ptr<Literal> LocalShapedBuffer::ToLiteral() const {
  LocalClient* client = GetOrCreateLocalClient();
  return client->ShapedBufferToLiteral(*shaped_buffer()).ConsumeValueOrDie();
}

CompiledLocalComputation::CompiledLocalComputation(
    std::unique_ptr<LocalExecutable> executable)
    : executable_(std::move(executable)) {}

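// Runs the compiled executable once per replica, in parallel: each replica
// transfers its arguments to its device (re-laying them out first when an
// explicit layout is given), executes, and transfers the result back to the
// host. All replicas run to completion, but only replica 0's literal is
// returned to the caller.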
StatusOr<std::unique_ptr<Literal>> CompiledLocalComputation::Execute(
    const std::vector<Literal>& arguments,
    const std::vector<tensorflow::gtl::optional<Shape>>& shapes_with_layout) {
  LocalClient* client = GetOrCreateLocalClient();

  VLOG(1) << "Execution requested with " << GetReplicaCount() << " replicas.";

  // Each replica populates a StatusOr result, but only replica zero actually
  // retrieves its literal value.
  std::vector<StatusOr<std::unique_ptr<Literal>>> results(GetReplicaCount());
  {
    tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "xlarun",
                                        GetReplicaCount());

    for (int replica = 0; replica < GetReplicaCount(); ++replica) {
      pool.Schedule([this, client, replica, &arguments, &shapes_with_layout,
                     &results] {
        StatusOr<int> device_ordinal_status =
            client->ReplicaNumberToDeviceOrdinal(replica);
        if (!device_ordinal_status.ok()) {
          results[replica] = device_ordinal_status.status();
          return;
        }
        const int device_ordinal = device_ordinal_status.ValueOrDie();
        VLOG(3) << "Replica " << replica
                << " mapped to device ordinal for execution: "
                << device_ordinal;

        // Transfer arguments in
        std::vector<std::unique_ptr<ScopedShapedBuffer>> scoped_buffers;
        scoped_buffers.reserve(arguments.size());
        for (int i = 0; i < arguments.size(); ++i) {
          const Literal& argument = arguments[i];
          const tensorflow::gtl::optional<Shape>& shape_with_layout =
              shapes_with_layout[i];

          StatusOr<std::unique_ptr<ScopedShapedBuffer>> pushed;
          if (shape_with_layout) {
            std::unique_ptr<Literal> relaid =
                argument.Relayout(shape_with_layout.value());
            pushed = ToBuffer(client, device_ordinal, *relaid);
          } else {
            pushed = ToBuffer(client, device_ordinal, argument);
          }
          if (!pushed.ok()) {
            results[replica] = pushed.status();
            return;
          }

          scoped_buffers.push_back(std::move(pushed).ValueOrDie());
        }

        // Execute
        std::vector<const ShapedBuffer*> argument_buffers;
        argument_buffers.reserve(scoped_buffers.size());
        for (auto& buffer : scoped_buffers) {
          argument_buffers.push_back(buffer.get());
        }

        DeviceAssignment device_assignment =
            client->backend()
                .computation_placer()
                ->AssignDevices(GetReplicaCount(), /*computation_count=*/1)
                .ConsumeValueOrDie();

        ExecutableRunOptions options;
        options.set_device_ordinal(device_ordinal);
        options.set_allocator(client->backend().memory_allocator());
        options.set_inter_op_thread_pool(
            client->backend().inter_op_thread_pool());
        options.set_intra_op_thread_pool(
            client->backend().eigen_intra_op_thread_pool_device());
        options.set_device_assignment(&device_assignment);
        StatusOr<std::unique_ptr<ScopedShapedBuffer>> result_buffer_status =
            executable_->Run(argument_buffers, options);
        if (!result_buffer_status.ok()) {
          results[replica] = result_buffer_status.status();
          return;
        }

        // Transfer result out
        results[replica] =
            client->ShapedBufferToLiteral(*result_buffer_status.ValueOrDie());
      });
    }
  }

  for (int replica = 0; replica < GetReplicaCount(); ++replica) {
    const auto& statusor = results[replica];
    if (!statusor.ok()) {
      return InternalError(
          "Failed running replica %d (other replicas may have failed as "
          "well): %s.",
          replica, statusor.status().ToString().c_str());
    }
  }

  return std::move(results[0]);
}

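// Runs the compiled executable with arguments that already reside on the
// device, and leaves the result on the device as a freshly allocated
// LocalShapedBuffer owned by the caller; no host transfers are performed.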
LocalShapedBuffer* CompiledLocalComputation::ExecuteWithShapedBuffers(
    tensorflow::gtl::ArraySlice<LocalShapedBuffer*> argument_handles) {
  LocalClient* client = GetOrCreateLocalClient();

  std::vector<const ShapedBuffer*> argument_buffers;
  argument_buffers.reserve(argument_handles.size());
  for (auto& handle : argument_handles) {
    argument_buffers.push_back(handle->shaped_buffer().get());
  }

  // Execute
  ExecutableRunOptions options;
  options.set_allocator(client->backend().memory_allocator());
  options.set_inter_op_thread_pool(client->backend().inter_op_thread_pool());
  options.set_intra_op_thread_pool(
      client->backend().eigen_intra_op_thread_pool_device());
  std::unique_ptr<ScopedShapedBuffer> result_buffer =
      executable_->Run(argument_buffers, options).ConsumeValueOrDie();

  return new LocalShapedBuffer(std::move(result_buffer));
}

LocalComputation::LocalComputation(Computation computation)
    : computation_(std::move(computation)) {}

StatusOr<CompiledLocalComputation*> LocalComputation::Compile(
    const std::vector<Shape>& argument_shapes,
    const ExecutableBuildOptions* build_options) {
  std::vector<const Shape*> argument_shape_pointers;
  argument_shape_pointers.reserve(argument_shapes.size());
  for (auto& argument_shape : argument_shapes) {
    argument_shape_pointers.push_back(&argument_shape);
  }

  LocalClient* client = GetOrCreateLocalClient();
  ExecutableBuildOptions options;
  if (build_options != nullptr) {
    options = *build_options;
  }
  TF_ASSIGN_OR_RETURN(
      auto local_executable,
      client->Compile(computation_, argument_shape_pointers, options));
  return new CompiledLocalComputation(std::move(local_executable));
}

const Computation& LocalComputation::computation() const {
  return computation_;
}

StatusOr<Shape> LocalComputation::GetReturnValueShape() const {
  TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                      computation_.GetProgramShape());
  return std::move(*program_shape.mutable_result());
}

LocalComputationBuilder::LocalComputationBuilder(
    const string& computation_name)
    : builder_(GetOrCreateLocalClient(), computation_name) {}

void LocalComputationBuilder::SetOpMetadata(const OpMetadata& metadata) {
  builder_.SetOpMetadata(metadata);
}

void LocalComputationBuilder::ClearOpMetadata() { builder_.ClearOpMetadata(); }

StatusOr<LocalComputation*> LocalComputationBuilder::Build() {
  TF_ASSIGN_OR_RETURN(Computation computation, builder_.Build());
  return new LocalComputation(std::move(computation));
}

ComputationDataHandle LocalComputationBuilder::Parameter(
    int64 parameter_number, const Shape& shape, const string& name) {
  return builder_.Parameter(parameter_number, shape, name);
}

std::unique_ptr<Shape> LocalComputationBuilder::GetShape(
    const ComputationDataHandle& operand) {
  return builder_.GetShape(operand).ConsumeValueOrDie();
}

StatusOr<Shape> LocalComputationBuilder::GetReturnValueShape() {
  TF_ASSIGN_OR_RETURN(ProgramShape program_shape, builder_.GetProgramShape());
  return program_shape.result();
}

ComputationDataHandle LocalComputationBuilder::Infeed(const Shape& shape) {
  return builder_.Infeed(shape);
}

void LocalComputationBuilder::Outfeed(const ComputationDataHandle& operand,
                                      const Shape& shape,
                                      const string& outfeed_config) {
  builder_.Outfeed(operand, shape, outfeed_config);
}

ComputationDataHandle LocalComputationBuilder::ConstantLiteral(
    const Literal& literal) {
  return builder_.ConstantLiteral(literal);
}

ComputationDataHandle LocalComputationBuilder::Broadcast(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
  return builder_.Broadcast(operand, broadcast_sizes);
}

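// Like the ops above, each op-building method below forwards to the
// corresponding ComputationBuilder method; the indirection appears to exist
// so that SWIG only needs to wrap this one concrete class.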
ComputationDataHandle LocalComputationBuilder::Pad(
    const ComputationDataHandle& operand,
    const ComputationDataHandle& padding_value,
    const PaddingConfig& padding_config) {
  return builder_.Pad(operand, padding_value, padding_config);
}

ComputationDataHandle LocalComputationBuilder::Reshape(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> dimensions,
    tensorflow::gtl::ArraySlice<int64> new_sizes) {
  return builder_.Reshape(operand, dimensions, new_sizes);
}

ComputationDataHandle LocalComputationBuilder::Collapse(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> dimensions) {
  return builder_.Collapse(operand, dimensions);
}

ComputationDataHandle LocalComputationBuilder::CrossReplicaSum(
    const ComputationDataHandle& operand) {
  return builder_.CrossReplicaSum(operand);
}

ComputationDataHandle LocalComputationBuilder::Slice(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> start_indices,
    tensorflow::gtl::ArraySlice<int64> limit_indices,
    tensorflow::gtl::ArraySlice<int64> strides) {
  return builder_.Slice(operand, start_indices, limit_indices, strides);
}

ComputationDataHandle LocalComputationBuilder::DynamicSlice(
    const ComputationDataHandle& operand,
    const ComputationDataHandle& start_indices,
    tensorflow::gtl::ArraySlice<int64> slice_sizes) {
  return builder_.DynamicSlice(operand, start_indices, slice_sizes);
}

ComputationDataHandle LocalComputationBuilder::DynamicUpdateSlice(
    const ComputationDataHandle& operand, const ComputationDataHandle& update,
    const ComputationDataHandle& start_indices) {
  return builder_.DynamicUpdateSlice(operand, update, start_indices);
}

ComputationDataHandle LocalComputationBuilder::ConcatInDim(
    tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
    int64 dimension) {
  return builder_.ConcatInDim(operands, dimension);
}

ComputationDataHandle
LocalComputationBuilder::SelectAndScatterWithGeneralPadding(
    const ComputationDataHandle& operand, const LocalComputation& select,
    tensorflow::gtl::ArraySlice<int64> window_dimensions,
    tensorflow::gtl::ArraySlice<int64> window_strides,
    tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
    const ComputationDataHandle& source,
    const ComputationDataHandle& init_value, const LocalComputation& scatter) {
  return builder_.SelectAndScatterWithGeneralPadding(
      operand, select.computation(), window_dimensions, window_strides,
      padding, source, init_value, scatter.computation());
}

ComputationDataHandle LocalComputationBuilder::Tuple(
    tensorflow::gtl::ArraySlice<ComputationDataHandle> elements) {
  return builder_.Tuple(elements);
}

ComputationDataHandle LocalComputationBuilder::GetTupleElement(
    const ComputationDataHandle& tuple_data, int64 index) {
  return builder_.GetTupleElement(tuple_data, index);
}

ComputationDataHandle LocalComputationBuilder::Dot(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) {
  return builder_.Dot(lhs, rhs);
}

ComputationDataHandle LocalComputationBuilder::DotGeneral(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    const DotDimensionNumbers& dimension_numbers) {
  return builder_.DotGeneral(lhs, rhs, dimension_numbers);
}

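// ConvGeneralDilated is XLA's most general convolution form: explicit
// per-spatial-dimension window strides, (low, high) padding, and dilation
// factors for both operands, with operand layouts given by dimension_numbers.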
ComputationDataHandle LocalComputationBuilder::ConvGeneralDilated(
    const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
    tensorflow::gtl::ArraySlice<int64> window_strides,
    tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
    tensorflow::gtl::ArraySlice<int64> lhs_dilation,
    tensorflow::gtl::ArraySlice<int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers) {
  return builder_.ConvGeneralDilated(lhs, rhs, window_strides, padding,
                                     lhs_dilation, rhs_dilation,
                                     dimension_numbers);
}

ComputationDataHandle LocalComputationBuilder::ConvertElementType(
    const ComputationDataHandle& operand, PrimitiveType new_element_type) {
  return builder_.ConvertElementType(operand, new_element_type);
}

ComputationDataHandle LocalComputationBuilder::Call(
    const LocalComputation& local_computation,
    tensorflow::gtl::ArraySlice<ComputationDataHandle> operands) {
  return builder_.Call(local_computation.computation(), operands);
}

ComputationDataHandle LocalComputationBuilder::Transpose(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> permutation) {
  return builder_.Transpose(operand, permutation);
}

ComputationDataHandle LocalComputationBuilder::Rev(
    const ComputationDataHandle& operand,
    tensorflow::gtl::ArraySlice<int64> dimensions) {
  return builder_.Rev(operand, dimensions);
}

ComputationDataHandle LocalComputationBuilder::Map(
    tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
    const LocalComputation& local_computation,
    tensorflow::gtl::ArraySlice<int64> dimensions,
    tensorflow::gtl::ArraySlice<ComputationDataHandle> static_operands) {
  return builder_.Map(operands, local_computation.computation(), dimensions,
                      static_operands);
}

ComputationDataHandle LocalComputationBuilder::Reduce(
    const ComputationDataHandle& operand,
    const ComputationDataHandle& init_value,
    const LocalComputation& local_computation,
    tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce) {
  return builder_.Reduce(operand, init_value, local_computation.computation(),
                         dimensions_to_reduce);
}

ComputationDataHandle LocalComputationBuilder::ReduceWindowWithGeneralPadding(
    const ComputationDataHandle& operand,
    const ComputationDataHandle& init_value,
    const LocalComputation& local_computation,
    tensorflow::gtl::ArraySlice<int64> window_dimensions,
    tensorflow::gtl::ArraySlice<int64> window_strides,
    tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
  return builder_.ReduceWindowWithGeneralPadding(
      operand, init_value, local_computation.computation(), window_dimensions,
      window_strides, padding);
}

ComputationDataHandle LocalComputationBuilder::RngNormal(
    const ComputationDataHandle& mu, const ComputationDataHandle& sigma,
    const Shape& shape) {
  return builder_.RngNormal(mu, sigma, shape);
}

ComputationDataHandle LocalComputationBuilder::RngUniform(
    const ComputationDataHandle& a, const ComputationDataHandle& b,
    const Shape& shape) {
  return builder_.RngUniform(a, b, shape);
}

ComputationDataHandle LocalComputationBuilder::While(
    const LocalComputation& condition, const LocalComputation& body,
    const ComputationDataHandle& init) {
  return builder_.While(condition.computation(), body.computation(), init);
}

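// Conditional applies true_computation to true_operand when predicate is
// true, and false_computation to false_operand otherwise.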
ComputationDataHandle LocalComputationBuilder::Conditional(
    const ComputationDataHandle& predicate,
    const ComputationDataHandle& true_operand,
    const LocalComputation& true_computation,
    const ComputationDataHandle& false_operand,
    const LocalComputation& false_computation) {
  return builder_.Conditional(predicate, true_operand,
                              true_computation.computation(), false_operand,
                              false_computation.computation());
}

// Helper macros that stamp out forwarding definitions for element-wise ops.
// _FORWARD_UNOP, _FORWARD_BINOP, and _FORWARD_TRIOP fix the signature for
// unary, binary (with broadcast dimensions), and ternary ops respectively.
#define _FORWARD(method_name, return_sig, args_sig, args) \
  return_sig LocalComputationBuilder::method_name args_sig { \
    return builder_.method_name args; \
  }

#define _FORWARD_UNOP(method_name) \
  _FORWARD(method_name, ComputationDataHandle, \
           (const ComputationDataHandle& operand), (operand))

#define _FORWARD_BINOP(method_name) \
  _FORWARD( \
      method_name, ComputationDataHandle, \
      (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \
       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions), \
      (lhs, rhs, broadcast_dimensions))

#define _FORWARD_TRIOP(method_name) \
  _FORWARD( \
      method_name, ComputationDataHandle, \
      (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \
       const ComputationDataHandle& ehs), \
      (lhs, rhs, ehs))

_FORWARD_TRIOP(Select)
_FORWARD_TRIOP(Clamp)
_FORWARD_BINOP(Eq)
_FORWARD_BINOP(Ne)
_FORWARD_BINOP(Ge)
_FORWARD_BINOP(Gt)
_FORWARD_BINOP(Lt)
_FORWARD_BINOP(Le)
_FORWARD_BINOP(Add)
_FORWARD_BINOP(Sub)
_FORWARD_BINOP(Mul)
_FORWARD_BINOP(Div)
_FORWARD_BINOP(Rem)
_FORWARD_BINOP(Max)
_FORWARD_BINOP(Min)
_FORWARD_BINOP(And)
_FORWARD_BINOP(Or)
_FORWARD_UNOP(Not)
_FORWARD_UNOP(Abs)
_FORWARD_UNOP(Exp)
_FORWARD_UNOP(Floor)
_FORWARD_UNOP(Ceil)
_FORWARD_UNOP(Round)
_FORWARD_UNOP(Log)
_FORWARD_UNOP(Sign)
_FORWARD_UNOP(Cos)
_FORWARD_UNOP(Sin)
_FORWARD_UNOP(Tanh)
_FORWARD_UNOP(SqrtF32)
_FORWARD_UNOP(SquareF32)
_FORWARD_BINOP(Pow)
_FORWARD_UNOP(IsFinite)
_FORWARD_UNOP(ReciprocalF32)
_FORWARD_UNOP(Neg)
_FORWARD_UNOP(Sort)

#undef _FORWARD
#undef _FORWARD_UNOP
#undef _FORWARD_BINOP
#undef _FORWARD_TRIOP

void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer) {
  delete local_shaped_buffer;
}

void DeleteCompiledLocalComputation(CompiledLocalComputation* computation) {
  delete computation;
}

void DeleteLocalComputation(LocalComputation* computation) {
  delete computation;
}

}  // namespace swig

}  // namespace xla