Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_
     17 #define TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_
     18 
     19 #include <functional>
     20 #include <map>
     21 #include <memory>
     22 #include <string>
     23 #include <vector>
     24 
     25 #include "tensorflow/compiler/xla/service/hlo_computation.h"
     26 #include "tensorflow/compiler/xla/service/session.pb.h"
     27 #include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
     28 #include "tensorflow/compiler/xla/statusor.h"
     29 #include "tensorflow/compiler/xla/types.h"
     30 #include "tensorflow/compiler/xla/xla.pb.h"
     31 #include "tensorflow/compiler/xla/xla_data.pb.h"
     32 #include "tensorflow/core/platform/macros.h"
     33 #include "tensorflow/core/platform/mutex.h"
     34 #include "tensorflow/core/platform/thread_annotations.h"
     35 #include "tensorflow/core/platform/types.h"
     36 
     37 namespace xla {
     38 
     39 // A UserComputation is the built-up computation that users create via the
     40 // XLA Service interface. The service adds instructions through the Add*
     41 // methods below; the state is stored as a SessionComputation proto which holds
     42 // a record of all operation-building requests received by the XLA service.
     43 //
     44 // UserComputations are lowered to HloComputations (see BuildHloComputation)
     45 // which are passed to the high level compiler interface.
     46 //
     47 // The computation state is declared GUARDED_BY(mutex_), and the class is
     48 // neither copyable nor assignable (TF_DISALLOW_COPY_AND_ASSIGN).
     49 class UserComputation {
     50  public:
     51   // Factory used when restoring a computation from serialized session
     52   // computation (computation snapshot) data. Remaps any references to
     53   // computation handles via the old_to_new mapping.
     54   //
     55   // An error will occur if the old_to_new mapping cannot resolve a reference to
     56   // a computation that is present in session_computation.
     57   static StatusOr<std::unique_ptr<UserComputation>> MakeWithRemapping(
     58       const SessionComputation& session_computation,
     59       const ComputationHandle& handle,
     60       const std::map<int64, ComputationHandle>& old_to_new);
     61 
     62   // Creates an empty computation with the given name and computation handle.
     63   explicit UserComputation(const string& name, const ComputationHandle& handle);
     64 
     65   // Enqueues a parameter-retrieving instruction onto this user computation.
     66   // Returns an error status if the parameter number is already registered with
     67   // different values.
     68   StatusOr<ComputationDataHandle> AddParameterInstruction(
     69       const ParameterRequest& parameter_request);
     70 
     71   // Enqueues a pad instruction onto this user computation.
     72   StatusOr<ComputationDataHandle> AddPadInstruction(
     73       const PadRequest& pad_request);
     74 
     75   // Enqueues a tracing instruction onto this user computation. No data handle
     76   // is produced; returns an error status if the operand cannot be resolved.
     77   Status AddTraceInstruction(const TraceRequest& trace_request);
     78 
     79   // Enqueues a random number generation instruction onto this user computation.
     80   StatusOr<ComputationDataHandle> AddRngInstruction(
     81       const RngRequest& rng_request);
     82 
     83   // Enqueues a unary instruction onto this user computation.
     84   // Returns an error status if the operand index is out of bounds.
     85   StatusOr<ComputationDataHandle> AddUnaryInstruction(
     86       const UnaryOpRequest& unary_request);
     87 
     88   // Enqueues a batch norm training instruction onto this user computation.
     89   StatusOr<ComputationDataHandle> AddBatchNormTrainingInstruction(
     90       const BatchNormTrainingRequest& batch_norm_training_request);
     91 
     92   // Enqueues a batch norm inference instruction onto this user computation.
     93   StatusOr<ComputationDataHandle> AddBatchNormInferenceInstruction(
     94       const BatchNormInferenceRequest& batch_norm_inference_request);
     95 
     96   // Enqueues a batch norm grad instruction onto this user computation.
     97   StatusOr<ComputationDataHandle> AddBatchNormGradInstruction(
     98       const BatchNormGradRequest& batch_norm_grad_request);
     99 
    100   // Enqueues a binary instruction onto this user computation.
    101   // Returns an error status if the operand indices are out of bounds.
    102   StatusOr<ComputationDataHandle> AddBinaryInstruction(
    103       const BinaryOpRequest& binary_request);
    104 
    105   // Enqueues a ternary instruction onto this user computation.
    106   // Returns an error status if the operand indices are out of bounds.
    107   StatusOr<ComputationDataHandle> AddTernaryInstruction(
    108       const TernaryOpRequest& ternary_request);
    109 
    110   // Enqueues a variadic instruction onto this user computation.
    111   // Returns an error status if the operand indices are out of bounds.
    112   StatusOr<ComputationDataHandle> AddVariadicInstruction(
    113       const VariadicOpRequest& variadic_request);
    114 
    115   // Enqueues a constant instruction onto this user computation.
    116   StatusOr<ComputationDataHandle> AddConstantInstruction(
    117       const ConstantRequest& constant_request);
    118 
    119   // Enqueues a get tuple element instruction onto this user computation.
    120   StatusOr<ComputationDataHandle> AddGetTupleElementInstruction(
    121       const GetTupleElementRequest& get_tuple_element_request);
    122 
    123   // Enqueues a map instruction onto this user computation.
    124   StatusOr<ComputationDataHandle> AddMapInstruction(
    125       const MapRequest& map_request,
    126       const UserComputation& to_apply_computation);
    127 
    128   // Enqueues a reduce-precision instruction onto this user computation.
    129   StatusOr<ComputationDataHandle> AddReducePrecisionInstruction(
    130       const ReducePrecisionRequest& reduce_precision_request);
    131 
    132   // Enqueues a convolution instruction onto this user computation.
    133   StatusOr<ComputationDataHandle> AddConvolveInstruction(
    134       const ConvolveRequest& convolve_request);
    135 
    136   // Enqueues an FFT instruction onto this user computation.
    137   StatusOr<ComputationDataHandle> AddFftInstruction(
    138       const FftRequest& fft_request);
    139 
    140   // Enqueues a cross replica sum instruction onto this user computation.
    141   StatusOr<ComputationDataHandle> AddCrossReplicaSumInstruction(
    142       const CrossReplicaSumRequest& cross_replica_sum_request);
    143 
    144   // Enqueues an infeed instruction onto this user computation.
    145   StatusOr<ComputationDataHandle> AddInfeedInstruction(
    146       const InfeedRequest& infeed_request);
    147 
    148   // Enqueues an outfeed instruction onto this user computation.
    149   StatusOr<ComputationDataHandle> AddOutfeedInstruction(
    150       const OutfeedRequest& outfeed_request);
    151 
    152   // Enqueues a host compute instruction onto this user computation.
    153   StatusOr<ComputationDataHandle> AddHostComputeInstruction(
    154       const HostComputeRequest& host_compute_request);
    155 
    156   // Enqueues a call instruction onto this user computation.
    157   StatusOr<ComputationDataHandle> AddCallInstruction(
    158       const CallRequest& call_request,
    159       const UserComputation& to_apply_computation);
    160 
    161   // Enqueues a custom call instruction onto this user computation.
    162   StatusOr<ComputationDataHandle> AddCustomCallInstruction(
    163       const CustomCallRequest& custom_call_request);
    164 
    165   // Enqueues a dot instruction onto this user computation.
    166   StatusOr<ComputationDataHandle> AddDotInstruction(
    167       const DotRequest& dot_request);
    168 
    169   // Enqueues a broadcast instruction onto this user computation.
    170   StatusOr<ComputationDataHandle> AddBroadcastInstruction(
    171       const BroadcastRequest& broadcast_request);
    172 
    173   // Enqueues a reshape instruction onto this user computation.
    174   StatusOr<ComputationDataHandle> AddReshapeInstruction(
    175       const ReshapeRequest& reshape_request);
    176 
    177   // Enqueues a transpose instruction onto this user computation.
    178   StatusOr<ComputationDataHandle> AddTransposeInstruction(
    179       const TransposeRequest& transpose_request);
    180 
    181   // Enqueues a slice instruction onto this user computation.
    182   StatusOr<ComputationDataHandle> AddSliceInstruction(
    183       const SliceRequest& slice_request);
    184 
    185   // Enqueues a dynamic slice instruction onto this user computation.
    186   StatusOr<ComputationDataHandle> AddDynamicSliceInstruction(
    187       const DynamicSliceRequest& dynamic_slice_request);
    188 
    189   // Enqueues a dynamic update slice instruction onto this user computation.
    190   StatusOr<ComputationDataHandle> AddDynamicUpdateSliceInstruction(
    191       const DynamicUpdateSliceRequest& dynamic_update_slice_request);
    192 
    193   // Enqueues a concatenate instruction onto this user computation.
    194   StatusOr<ComputationDataHandle> AddConcatenateInstruction(
    195       const ConcatenateRequest& concatenate_request);
    196 
    197   // Enqueues a convert instruction onto this user computation.
    198   StatusOr<ComputationDataHandle> AddConvertInstruction(
    199       const ConvertRequest& convert_request);
    200 
    201   // Enqueues a bitcast-convert instruction (described by a ConvertRequest).
    202   StatusOr<ComputationDataHandle> AddBitcastConvertInstruction(
    203       const ConvertRequest& convert_request);
    204 
    205   // Enqueues a reduce instruction onto this user computation.
    206   StatusOr<ComputationDataHandle> AddReduceInstruction(
    207       const ReduceRequest& reduce_request,
    208       const UserComputation& to_apply_computation);
    209 
    210   // Enqueues a windowed reduce instruction onto this user computation.
    211   StatusOr<ComputationDataHandle> AddReduceWindowInstruction(
    212       const ReduceWindowRequest& reduce_window_request,
    213       const UserComputation& to_apply_computation);
    214 
    215   // Enqueues a select-and-scatter instruction onto this user
    216   // computation.
    217   StatusOr<ComputationDataHandle> AddSelectAndScatterInstruction(
    218       const SelectAndScatterRequest& select_and_scatter_request,
    219       const UserComputation& select_computation,
    220       const UserComputation& scatter_computation);
    221 
    222   // Enqueues a reverse instruction onto this user computation.
    223   StatusOr<ComputationDataHandle> AddReverseInstruction(
    224       const ReverseRequest& reverse_request);
    225 
    226   // Enqueues a while instruction onto this user computation.
    227   StatusOr<ComputationDataHandle> AddWhileInstruction(
    228       const WhileRequest& while_request,
    229       const UserComputation& condition_computation,
    230       const UserComputation& body_computation);
    231 
    232   // Enqueues a conditional instruction onto this user computation.
    233   StatusOr<ComputationDataHandle> AddConditionalInstruction(
    234       const ConditionalRequest& conditional_request,
    235       const UserComputation& true_computation,
    236       const UserComputation& false_computation);
    237 
    238   // Enqueues a Send instruction onto this user computation (Status only).
    239   Status AddSendInstruction(const SendRequest& send_request);
    240 
    241   // Enqueues a Recv instruction onto this user computation.
    242   StatusOr<ComputationDataHandle> AddRecvInstruction(
    243       const RecvRequest& recv_request);
    244 
    245   // Enqueues a Gather instruction onto this user computation.
    246   StatusOr<ComputationDataHandle> AddGatherInstruction(
    247       const GatherRequest& gather_request);
    248 
    249   // Returns the user-provided name of this user computation, which is provided
    250   // via the XLA computation-building API.
    251   const string& name() const { return name_; }
    252 
    253   // Subsequent executions of this computation will compute the value
    254   // represented by handle, rather than the last expression enqueued
    255   // on the computation.
    256   Status SetReturnValue(const ComputationDataHandle& handle);
    257 
    258   // Return a versioned handle for this computation.
    259   VersionedComputationHandle GetVersionedHandle() const;
    260 
    261   // Return a versioned handle for this computation with a version equal to the
    262   // point at which given operation was added to the computation.
    263   VersionedComputationHandle GetVersionedHandleAtOperation(
    264       const ComputationDataHandle& operation) const;
    265 
    266   // Return a version value representing the current state of the
    267   // computation.
    268   VersionedComputationHandle::Version version() const;
    269 
    270   // Computes and returns the program shape for the user computation -- gathers
    271   // parameters and result type into a single proto. A shared_ptr is used
    272   // because the returned pointer refers to an internally cached value which may
    273   // be discarded by the UserComputation object. This avoids unnecessary copies.
    274   //
    275   // If the parameter space is not dense (i.e. there are holes in the parameter
    276   // numbers provided) then an error status is returned.
    277   StatusOr<std::shared_ptr<const ProgramShape>> ComputeProgramShape(
    278       VersionedComputationHandle::Version version) const;
    279 
    280   // Returns true if the given data handle does not depend on any parameter with
    281   // index higher than num_parameters. That is, the value can be computed at
    282   // compile time if we know the first num_parameters arguments.
    283   StatusOr<bool> IsConstant(const ComputationDataHandle& handle,
    284                             int64 num_parameters);
    285 
    286   // Returns the output shape of the operation indicated by the given handle.
    287   StatusOr<Shape> GetShape(const ComputationDataHandle& handle);
    288 
    289   // Sets metadata on the Hlo instruction referenced by the given handle.
    290   Status SetOpMetadata(const ComputationDataHandle& handle,
    291                        const OpMetadata& metadata);
    292 
    293   // Sets the device assignment on the Hlo instruction referenced by 'handle'.
    294   Status SetOpSharding(const ComputationDataHandle& handle,
    295                        const OpSharding& sharding);
    296 
    297   // Builds a HLO computation from the UserComputation. "hlo_resolver" is a
    298   // function which returns a pointer to the HloComputation corresponding to the
    299   // given ComputationHandle at the given version. The resolver is used for
    300   // operations, such as map, which call other computations and need a pointer
    301   // to the called HloComputation to construct the respective HLO instructions.
    302   // If include_unreachable_instructions is true (the default), then
    303   // instructions which are not reachable from the root are lowered into
    304   // HloInstructions.
    305   using HloComputationResolver =
    306       std::function<HloComputation*(const VersionedComputationHandle& handle)>;
    307   StatusOr<std::unique_ptr<HloComputation>> BuildHloComputation(
    308       VersionedComputationHandle::Version version,
    309       HloComputationResolver hlo_resolver, const DebugOptions& debug_options,
    310       bool include_unreachable_instructions = true) const;
    311 
    312   // Return a vector containing the embedded computations used by this
    313   // UserComputation. Only embedded computations which are called directly by
    314   // this UserComputation are included. That is, the transitive closure of
    315   // embedded computations is not included.
    316   std::vector<VersionedComputationHandle> GetEmbeddedComputations(
    317       VersionedComputationHandle::Version version) const;
    318 
    319   // Returns the number of OperationRequest objects in this UserComputation at
    320   // the given version. The 'version' of a computation is identical to its
    321   // number of OperationRequests, so the argument is returned as-is.
    322   int64 request_count(VersionedComputationHandle::Version version) const {
    323     return version;
    324   }
    325 
    326   // Returns a copy of the internal session state for this computation -- this
    327   // is useful for serializing the guts of a user computation, though references
    328   // to other handles (e.g. referred-to computations) must be handled with care
    329   // in the serialization / de-serialization process.
    330   SessionComputation CloneSessionComputation(
    331       VersionedComputationHandle::Version version) const;
    332 
    333   // Warning: typically we don't want to look up computation data handles until
    334   // the computation is finished being built, for consistency purposes. We
    335   // expose this routine for error reporting purposes so that we can provide
    336   // more meaningful error messages from the XLA service layer.
    337   //
    338   // Returns the operation request that the handle comes from.
    339   StatusOr<const OperationRequest*> LookUpRequestForErrorReporting(
    340       const ComputationDataHandle& handle) const;
    341 
    342   // Retrieves the parameter metadata for the given parameter number.
    343   //
    344   // If the parameter number is invalid for this computation, nullopt is
    345   // returned. When the return value has_value(), nullptr will never be
    346   // the held value.
    347   tensorflow::gtl::optional<const OpMetadata*> ParameterMetadata(
    348       int parameter_number) const;
    349 
    350  private:
    351   // Warning: dangerous mutating operation that doesn't respect versioning.
    352   // This is only used at initialization time when constructing from a
    353   // SessionComputation a la MakeWithRemapping.
    354   //
    355   // Remaps references to old computations (with handle values in the keys of
    356   // old_to_new) to the computation handle given in the values. This is useful
    357   // when loading computations from snapshots, to finish initialization, before
    358   // the user computation is released into the wild.
    359   Status RemapEmbeddedComputations(
    360       const std::map<int64, ComputationHandle>& old_to_new)
    361       EXCLUSIVE_LOCKS_REQUIRED(mutex_);
    362 
    363   // Returns the OperationRequest for the given handle. Requires mutex_ held.
    364   StatusOr<const OperationRequest*> LookUpRequest(
    365       const ComputationDataHandle& handle) const
    366       EXCLUSIVE_LOCKS_REQUIRED(mutex_);
    367 
    368   // Creates a new ComputationDataHandle with the next available handle value.
    369   ComputationDataHandle CreateComputationDataHandle()
    370       EXCLUSIVE_LOCKS_REQUIRED(mutex_);
    371 
    372   // Checks whether the parameter numbers of the parameter operations are
    373   // contiguous starting from zero. Returns appropriate error status if not.
    374   Status CheckParametersAreContiguous(
    375       VersionedComputationHandle::Version version) const
    376       EXCLUSIVE_LOCKS_REQUIRED(mutex_);
    377 
    378   VersionedComputationHandle GetVersionedHandleInternal() const
    379       EXCLUSIVE_LOCKS_REQUIRED(mutex_);  // Caller must already hold mutex_.
    380 
    381   // Name of the computation; this is what name() returns.
    382   string name_;
    383 
    384   mutable tensorflow::mutex mutex_;  // Guards the GUARDED_BY members below.
    385 
    386   // State of the computation as a record of all operation-building requests.
    387   SessionComputation session_computation_ GUARDED_BY(mutex_);
    388 
    389   // Mapping from parameter number to operation request containing the
    390   // respective ParameterRequest.
    391   std::map<int64, OperationRequest*> parameters_ GUARDED_BY(mutex_);
    392 
    393   // The next ComputationDataHandle value to assign. Handle values are assigned
    394   // sequentially.
    395   int64 next_handle_value_ GUARDED_BY(mutex_);
    396 
    397   // If handle_to_return_.has_handle() then an Execution of this Computation
    398   // will compute the value represented by handle_to_return_, otherwise it will
    399   // compute the value of (next_handle_value_ - 1).
    400   ComputationDataHandle handle_to_return_ GUARDED_BY(mutex_);
    401 
    402   // Memoized ProgramShape and its version. A shared_ptr is used because
    403   // references to this object are returned by ComputeProgramShape.
    404   mutable int64 program_shape_version_ GUARDED_BY(mutex_) = 0;
    405   mutable std::shared_ptr<const ProgramShape> program_shape_ GUARDED_BY(mutex_);
    406 
    407   TF_DISALLOW_COPY_AND_ASSIGN(UserComputation);
    408 };
    409 
    410 }  // namespace xla
    411 
    412 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_USER_COMPUTATION_H_
    413