1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_ 18 19 #include <functional> 20 #include <map> 21 #include <memory> 22 #include <string> 23 #include <vector> 24 25 #include "tensorflow/compiler/xla/service/hlo_computation.h" 26 #include "tensorflow/compiler/xla/service/session.pb.h" 27 #include "tensorflow/compiler/xla/service/versioned_computation_handle.h" 28 #include "tensorflow/compiler/xla/statusor.h" 29 #include "tensorflow/compiler/xla/types.h" 30 #include "tensorflow/compiler/xla/xla.pb.h" 31 #include "tensorflow/compiler/xla/xla_data.pb.h" 32 #include "tensorflow/core/platform/macros.h" 33 #include "tensorflow/core/platform/mutex.h" 34 #include "tensorflow/core/platform/thread_annotations.h" 35 #include "tensorflow/core/platform/types.h" 36 37 namespace xla { 38 39 // A UserComputation is the built-up computation that users create via the 40 // XLA Service interface. 41 // 42 // The XLA service adds instructions to a user computation via this 43 // interface. The state of the computation is stored as a SessionComputation 44 // proto which holds a record of all operation-building requests received by the 45 // XLA service. 46 // 47 // UserComputations are lowered to HloComputations which are passed to the high 48 // level compiler interface. 49 class UserComputation { 50 public: 51 // Factory used when restoring a computation from serialized session 52 // computation (computation snapshot) data. Remaps any references to 53 // computation handle via the old_to_new mapping. 54 // 55 // An error will occur if the old_to_new mapping cannot resolve a reference to 56 // a computation that is present in session_computation. 57 static StatusOr<std::unique_ptr<UserComputation>> MakeWithRemapping( 58 const SessionComputation& session_computation, 59 const ComputationHandle& handle, 60 const std::map<int64, ComputationHandle>& old_to_new); 61 62 // Creates an empty computation with the given name and computation handle. 63 explicit UserComputation(const string& name, const ComputationHandle& handle); 64 65 // Enqueues a parameter-retrieving instruction onto this user computation. 66 // Returns an error status if the parameter number is already registered with 67 // different values. 68 StatusOr<ComputationDataHandle> AddParameterInstruction( 69 const ParameterRequest& parameter_request); 70 71 // Enqueues a pad instruction onto this user computation. 72 StatusOr<ComputationDataHandle> AddPadInstruction( 73 const PadRequest& pad_request); 74 75 // Enqueues a tracing instruction onto this user computation. 76 // Returns an error status if the operand cannot be resolved. 77 Status AddTraceInstruction(const TraceRequest& trace_request); 78 79 // Enqueues a random number generation instruction onto this user computation. 80 StatusOr<ComputationDataHandle> AddRngInstruction( 81 const RngRequest& rng_request); 82 83 // Enqueues a unary instruction onto this user computation. 84 // Returns an error status if the operand index is out of bounds. 85 StatusOr<ComputationDataHandle> AddUnaryInstruction( 86 const UnaryOpRequest& unary_request); 87 88 // Enqueues a batch norm training instruction onto this user computation. 89 StatusOr<ComputationDataHandle> AddBatchNormTrainingInstruction( 90 const BatchNormTrainingRequest& batch_norm_training_request); 91 92 // Enqueues a batch norm inference instruction onto this user computation. 93 StatusOr<ComputationDataHandle> AddBatchNormInferenceInstruction( 94 const BatchNormInferenceRequest& batch_norm_inference_request); 95 96 // Enqueues a batch norm grad instruction onto this user computation. 97 StatusOr<ComputationDataHandle> AddBatchNormGradInstruction( 98 const BatchNormGradRequest& batch_norm_grad_request); 99 100 // Enqueues a binary instruction onto this user computation. 101 // Returns an error status if the operand indices are out of bounds. 102 StatusOr<ComputationDataHandle> AddBinaryInstruction( 103 const BinaryOpRequest& binary_request); 104 105 // Enqueues a ternary instruction onto this user computation. 106 // Returns an error status if the operand indices are out of bounds. 107 StatusOr<ComputationDataHandle> AddTernaryInstruction( 108 const TernaryOpRequest& ternary_request); 109 110 // Enqueues a variadic instruction onto this user computation. 111 // Returns an error status if the operand indices are out of bounds. 112 StatusOr<ComputationDataHandle> AddVariadicInstruction( 113 const VariadicOpRequest& variadic_request); 114 115 // Enqueues a constant instruction onto this user computation. 116 StatusOr<ComputationDataHandle> AddConstantInstruction( 117 const ConstantRequest& constant_request); 118 119 // Enqueues a get tuple element instruction onto this user computation. 120 StatusOr<ComputationDataHandle> AddGetTupleElementInstruction( 121 const GetTupleElementRequest& get_tuple_element_request); 122 123 // Enqueues a map instruction onto this user computation. 124 StatusOr<ComputationDataHandle> AddMapInstruction( 125 const MapRequest& map_request, 126 const UserComputation& to_apply_computation); 127 128 // Enqueues a reduce-precision instruction onto this user computation. 129 StatusOr<ComputationDataHandle> AddReducePrecisionInstruction( 130 const ReducePrecisionRequest& reduce_precision_request); 131 132 // Enqueues a convolution instruction onto this user computation. 133 StatusOr<ComputationDataHandle> AddConvolveInstruction( 134 const ConvolveRequest& convolve_request); 135 136 // Enqueues an FFT instruction onto this user computation. 137 StatusOr<ComputationDataHandle> AddFftInstruction( 138 const FftRequest& fft_request); 139 140 // Enqueues a cross replica sum instruction onto this user computation. 141 StatusOr<ComputationDataHandle> AddCrossReplicaSumInstruction( 142 const CrossReplicaSumRequest& cross_replica_sum_request); 143 144 // Enqueues an infeed instruction onto this user computation. 145 StatusOr<ComputationDataHandle> AddInfeedInstruction( 146 const InfeedRequest& infeed_request); 147 148 // Enqueues an outfeed instruction onto this user computation. 149 StatusOr<ComputationDataHandle> AddOutfeedInstruction( 150 const OutfeedRequest& outfeed_request); 151 152 // Enqueues a host compute instruction onto this user computation. 153 StatusOr<ComputationDataHandle> AddHostComputeInstruction( 154 const HostComputeRequest& host_compute_request); 155 156 // Enqueues a call instruction onto this user computation. 157 StatusOr<ComputationDataHandle> AddCallInstruction( 158 const CallRequest& call_request, 159 const UserComputation& to_apply_computation); 160 161 // Enqueues a custom call instruction onto this user computation. 162 StatusOr<ComputationDataHandle> AddCustomCallInstruction( 163 const CustomCallRequest& custom_call_request); 164 165 // Enqueues a dot instruction onto this user computation. 166 StatusOr<ComputationDataHandle> AddDotInstruction( 167 const DotRequest& dot_request); 168 169 // Enqueues a broadcast instruction onto this user computation. 170 StatusOr<ComputationDataHandle> AddBroadcastInstruction( 171 const BroadcastRequest& broadcast_request); 172 173 // Enqueues a reshape instruction onto this user computation. 174 StatusOr<ComputationDataHandle> AddReshapeInstruction( 175 const ReshapeRequest& reshape_request); 176 177 // Enqueues a transpose instruction onto this user computation. 178 StatusOr<ComputationDataHandle> AddTransposeInstruction( 179 const TransposeRequest& transpose_request); 180 181 // Enqueues a slice instruction onto this user computation. 182 StatusOr<ComputationDataHandle> AddSliceInstruction( 183 const SliceRequest& slice_request); 184 185 // Enqueues a dynamic slice instruction onto this user computation. 186 StatusOr<ComputationDataHandle> AddDynamicSliceInstruction( 187 const DynamicSliceRequest& dynamic_slice_request); 188 189 // Enqueues a dynamic update slice instruction onto this user computation. 190 StatusOr<ComputationDataHandle> AddDynamicUpdateSliceInstruction( 191 const DynamicUpdateSliceRequest& dynamic_update_slice_request); 192 193 // Enqueues a concatenate instruction onto this user computation. 194 StatusOr<ComputationDataHandle> AddConcatenateInstruction( 195 const ConcatenateRequest& concatenate_request); 196 197 // Enqueues a convert instruction onto this user computation. 198 StatusOr<ComputationDataHandle> AddConvertInstruction( 199 const ConvertRequest& convert_request); 200 201 // Enqueues a bitcast element instruction onto this user computation. 202 StatusOr<ComputationDataHandle> AddBitcastConvertInstruction( 203 const ConvertRequest& convert_request); 204 205 // Enqueues a reduce instruction onto this user computation. 206 StatusOr<ComputationDataHandle> AddReduceInstruction( 207 const ReduceRequest& reduce_request, 208 const UserComputation& to_apply_computation); 209 210 // Enqueues a windowed reduce instruction onto this user computation. 211 StatusOr<ComputationDataHandle> AddReduceWindowInstruction( 212 const ReduceWindowRequest& reduce_window_request, 213 const UserComputation& to_apply_computation); 214 215 // Enqueues a select-and-scatter instruction onto this user 216 // computation. 217 StatusOr<ComputationDataHandle> AddSelectAndScatterInstruction( 218 const SelectAndScatterRequest& select_and_scatter_request, 219 const UserComputation& select_computation, 220 const UserComputation& scatter_computation); 221 222 // Enqueues a reverse instruction onto this user computation. 223 StatusOr<ComputationDataHandle> AddReverseInstruction( 224 const ReverseRequest& reverse_request); 225 226 // Enqueues a while instruction onto this user computation. 227 StatusOr<ComputationDataHandle> AddWhileInstruction( 228 const WhileRequest& while_request, 229 const UserComputation& condition_computation, 230 const UserComputation& body_computation); 231 232 // Enqueues a conditional instruction on this user computation. 233 StatusOr<ComputationDataHandle> AddConditionalInstruction( 234 const ConditionalRequest& conditional_request, 235 const UserComputation& true_computation, 236 const UserComputation& false_computation); 237 238 // Enqueues a Send instruction onto this user computation. 239 Status AddSendInstruction(const SendRequest& send_request); 240 241 // Enqueues a Recv instruction onto this user computation. 242 StatusOr<ComputationDataHandle> AddRecvInstruction( 243 const RecvRequest& recv_request); 244 245 // Enqueues a Gather instruction onto this user computation. 246 StatusOr<ComputationDataHandle> AddGatherInstruction( 247 const GatherRequest& gather_request); 248 249 // Returns the user-provided name of this user computation, which is provided 250 // via the XLA computation-building API. 251 const string& name() const { return name_; } 252 253 // Subsequent executions of this computation will compute the value 254 // represented by handle, rather than the last expression enqueued 255 // on the computation. 256 Status SetReturnValue(const ComputationDataHandle& handle); 257 258 // Return a versioned handle for this computation. 259 VersionedComputationHandle GetVersionedHandle() const; 260 261 // Return a versioned handle for this computation with a version equal to the 262 // point at which given operation was added to the computation. 263 VersionedComputationHandle GetVersionedHandleAtOperation( 264 const ComputationDataHandle& operation) const; 265 266 // Return a version value representing the current state of the 267 // computation. 268 VersionedComputationHandle::Version version() const; 269 270 // Computes and returns the program shape for the user computation -- gathers 271 // parameters and result type into a single proto. A shared_ptr is used 272 // because the returned pointer refers to an internally cached value which may 273 // be discarded by the UserComputation object. This avoid unnecessary copies. 274 // 275 // If the parameter space is not dense (i.e. there are holes in the parameter 276 // numbers provided) then an error status is returned. 277 StatusOr<std::shared_ptr<const ProgramShape>> ComputeProgramShape( 278 VersionedComputationHandle::Version version) const; 279 280 // Returns true if the given data handle does not depend on any parameter with 281 // index higher then num_parameters. That is, the value can be computed at 282 // compile time if we know the first num_parameters arguments. 283 StatusOr<bool> IsConstant(const ComputationDataHandle& handle, 284 int64 num_parameters); 285 286 // Returns the output shape of the operation indicated by the given handle. 287 StatusOr<Shape> GetShape(const ComputationDataHandle& handle); 288 289 // Sets metadata on the Hlo instruction referenced by the given handle. 290 Status SetOpMetadata(const ComputationDataHandle& handle, 291 const OpMetadata& metadata); 292 293 // Sets the device assignment on the Hlo instruction referenced by 'handle'. 294 Status SetOpSharding(const ComputationDataHandle& handle, 295 const OpSharding& sharding); 296 297 // Builds a HLO computation from the UserComputation. The parameter "resolver" 298 // is a function which returns a pointer to the HloComputation corresponding 299 // to the given ComputationHandle at the given version. The resolver is used 300 // for operations, such as map, which call other computations and need a 301 // pointer to the called HloComputation to construct the respective HLO 302 // instructions. If include_unreachable_instructions is true, then 303 // instructions which are not reachable from the root are lowered into 304 // HloInstructions. 305 using HloComputationResolver = 306 std::function<HloComputation*(const VersionedComputationHandle& handle)>; 307 StatusOr<std::unique_ptr<HloComputation>> BuildHloComputation( 308 VersionedComputationHandle::Version version, 309 HloComputationResolver hlo_resolver, const DebugOptions& debug_options, 310 bool include_unreachable_instructions = true) const; 311 312 // Return a vector containing the embedded computations used by this 313 // UserComputation. Only embedded computations which are called directly by 314 // this UserComputation are included. That is, the transitive closure of 315 // embedded computations is not included. 316 std::vector<VersionedComputationHandle> GetEmbeddedComputations( 317 VersionedComputationHandle::Version version) const; 318 319 // Returns the number of OperationRequest objects in this UserComputation. 320 // The 'version' of a computation is identical to the number of 321 // OperationRequests in the UserComputation. 322 int64 request_count(VersionedComputationHandle::Version version) const { 323 return version; 324 } 325 326 // Returns a copy of the internal session state for this computation -- this 327 // is useful for serializing the guts of a user computation, though references 328 // to other handles (e.g. referred-to computations) must be handled with care 329 // in the serialization / de-serialization process. 330 SessionComputation CloneSessionComputation( 331 VersionedComputationHandle::Version version) const; 332 333 // Warning: typically we don't want to look up computation data handles until 334 // the computation is finished being built, for consistency purposes. We 335 // expose this routine for error reporting purposes so that we can provide 336 // more meaningful error messages from the XLA service layer. 337 // 338 // Returns the operation request that the handle comes from. 339 StatusOr<const OperationRequest*> LookUpRequestForErrorReporting( 340 const ComputationDataHandle& handle) const; 341 342 // Retrieves the parameter metadata for the given parameter number. 343 // 344 // If the parameter number is invalid for this computation, nullopt is 345 // returned. When the return value has_value(), nullptr will never be 346 // the held value. 347 tensorflow::gtl::optional<const OpMetadata*> ParameterMetadata( 348 int parameter_number) const; 349 350 private: 351 // Warning: dangerous mutating operation that doesn't respect versioning. 352 // This is only used at initialization time when constructing from a 353 // SessionComputation a la MakeWithRemapping. 354 // 355 // Remaps references to old computations (with handle values in the keys of 356 // old_to_new) to the computation handle given in the values. This is useful 357 // when loading computations from snapshots, to finish initialization, before 358 // the user computation is released into the wild. 359 Status RemapEmbeddedComputations( 360 const std::map<int64, ComputationHandle>& old_to_new) 361 EXCLUSIVE_LOCKS_REQUIRED(mutex_); 362 363 // Returns the OperationRequest corresponding to the given handle. 364 StatusOr<const OperationRequest*> LookUpRequest( 365 const ComputationDataHandle& handle) const 366 EXCLUSIVE_LOCKS_REQUIRED(mutex_); 367 368 // Creates a new ComputationDataHandle with the next available handle value. 369 ComputationDataHandle CreateComputationDataHandle() 370 EXCLUSIVE_LOCKS_REQUIRED(mutex_); 371 372 // Checks whether the parameter numbers of the parameter operations are 373 // contiguous starting from zero. Returns appropriate error status if not. 374 Status CheckParametersAreContiguous( 375 VersionedComputationHandle::Version version) const 376 EXCLUSIVE_LOCKS_REQUIRED(mutex_); 377 378 VersionedComputationHandle GetVersionedHandleInternal() const 379 EXCLUSIVE_LOCKS_REQUIRED(mutex_); 380 381 // Name of the computation. 382 string name_; 383 384 mutable tensorflow::mutex mutex_; 385 386 // State of the computation as a record of all operation-building requests. 387 SessionComputation session_computation_ GUARDED_BY(mutex_); 388 389 // Mapping from parameter number to operation request containing the 390 // respective ParameterRequest. 391 std::map<int64, OperationRequest*> parameters_ GUARDED_BY(mutex_); 392 393 // The next ComputationDataHandle value to assign. Handle values are assigned 394 // sequentially. 395 int64 next_handle_value_ GUARDED_BY(mutex_); 396 397 // If handle_to_return_.has_handle() then an Execution of this Computation 398 // will compute the value represented by handle_to_return_, otherwise it will 399 // compute the value of (next_handle_value_ - 1). 400 ComputationDataHandle handle_to_return_ GUARDED_BY(mutex_); 401 402 // Memoized ProgramShape and its version. A shared_ptr is used because 403 // references to this object are returned by ComputeProgramShape. 404 mutable int64 program_shape_version_ GUARDED_BY(mutex_) = 0; 405 mutable std::shared_ptr<const ProgramShape> program_shape_ GUARDED_BY(mutex_); 406 407 TF_DISALLOW_COPY_AND_ASSIGN(UserComputation); 408 }; 409 410 } // namespace xla 411 412 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_ 413