Home | History | Annotate | Download | only in service
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_
     17 #define TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_
     18 
     19 #include <list>
     20 #include <map>
     21 #include <memory>
     22 #include <set>
     23 #include <string>
     24 
     25 #include "tensorflow/compiler/xla/service/hlo_module.h"
     26 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
     27 #include "tensorflow/compiler/xla/service/session.pb.h"
     28 #include "tensorflow/compiler/xla/service/user_computation.h"
     29 #include "tensorflow/compiler/xla/service/versioned_computation_handle.h"
     30 #include "tensorflow/compiler/xla/statusor.h"
     31 #include "tensorflow/compiler/xla/types.h"
     32 #include "tensorflow/compiler/xla/xla_data.pb.h"
     33 #include "tensorflow/core/platform/macros.h"
     34 #include "tensorflow/core/platform/mutex.h"
     35 #include "tensorflow/core/platform/thread_annotations.h"
     36 #include "tensorflow/core/platform/types.h"
     37 
     38 namespace xla {
     39 
     40 // Tracks computations for the XLA service. Each tracked computation is a
     41 // UserComputation owned by this tracker and registered under a
     42 // ComputationHandle, from which it can later be resolved.
     43 //
     44 // This class can also serialize (SnapshotComputation) and deserialize
     45 // (LoadSessionModule) the computations it tracks; to serialize a computation
     46 // properly, every computation it refers to must be serialized as well.
        //
        // Thread-safety: all tracked state is guarded by computation_mutex_.
        // NOTE(review): presumably every public method acquires that mutex (Resolve
        // is documented to do so below); confirm in computation_tracker.cc.
     47 class ComputationTracker {
     48  public:
     49   ComputationTracker();
     50 
     51   // Creates a new UserComputation for the given computation_name, registers it
     52   // with this tracker, and returns the ComputationHandle it was registered
     53   // under (presumably obtained from AllocateHandle()).
     54   // NOTE(review): the original doc stated the precondition "user_computation
          // is not already present in the map", but there is no user_computation
          // parameter here; confirm whether any precondition actually applies.
     55   ComputationHandle NewComputation(const string& computation_name);
     56 
     57   // Restores a serialized SessionModule (such as one produced by
     58   // SnapshotComputation) into this tracker and allocates new computation
          // handle(s) for it. Returns the handle of the restored computation
          // (presumably the module's entry computation -- confirm in the .cc), or a
          // non-OK status on failure.
     59   StatusOr<ComputationHandle> LoadSessionModule(
     60       const SessionModule& session_module);
     61 
     62   // Snapshots the computation referenced by `computation` at its latest
     63   // version. The returned SessionModule has it as the entry computation; every
     64   // computation it refers to is entrained as an "embedded" (non-entry) one.
     65   StatusOr<std::unique_ptr<SessionModule>> SnapshotComputation(
     66       const ComputationHandle& computation);
     67 
     68   // Resolves `computation` to the UserComputation registered under it. The
     69   // returned pointer remains owned by this tracker; callers must not free it.
          // Acquires computation_mutex_. NOTE(review): presumably returns a non-OK
          // status when the handle is not registered; confirm in the .cc.
     70   StatusOr<UserComputation*> Resolve(
     71       const ComputationHandle& computation) const;
     72 
     73   // Builds an HLO module using the specified computation as the entry. The
     74   // module will include the entry computation as well as all computations which
     75   // are called directly or indirectly from the entry computation via operations
     76   // like "map". config is the HLO module configuration to use for the
     77   // constructed module.
     78   // If include_unreachable_instructions is true (the default), then
     79   // instructions which are not reachable from the root are also lowered into
     80   // HloInstructions, including unreachable parameters. This ensures the entry
     81   // HloComputation has the same ProgramShape as the entry UserComputation.
     82   StatusOr<std::unique_ptr<HloModule>> BuildHloModule(
     83       const VersionedComputationHandle& entry_handle,
     84       const HloModuleConfig& config,
     85       bool include_unreachable_instructions = true) const;
     86 
          // Returns a textual description of the tracked computations (presumably
          // for debugging/logging). NOTE(review): presumably acquires
          // computation_mutex_ and delegates to ToStringInternal(); confirm.
     87   string ToString() const;
     88 
     89  private:
     90   // Bumps the next_computation_ number and returns the allocated number wrapped
     91   // in a ComputationHandle. The caller must hold computation_mutex_.
     92   ComputationHandle AllocateHandle()
     93       EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_);
     94 
     95   // Loads a session computation into a new UserComputation, registers it with
     96   // this tracker, and returns the handle it was registered under. old_to_new is
     97   // used to remap references (keyed by old handle value) to other computations
     98   // present in session_computation.
     99   //
    100   // old_to_new must not be null. It is updated with the mapping from
    101   // session_computation's old handle to the returned handle value.
          // The caller must hold computation_mutex_.
    102   StatusOr<ComputationHandle> LoadSessionComputation(
    103       const SessionComputation& session_computation,
    104       std::map<int64, ComputationHandle>* old_to_new)
    105       EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_);
    106 
    107   // Internal implementation of Resolve(); requires, but does not acquire,
    108   // computation_mutex_.
    109   StatusOr<UserComputation*> ResolveInternal(
    110       const ComputationHandle& computation) const
    111       EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_);
    112 
    113   // Builds a post-order sort of a computation ("entry") and all of its embedded
    114   // computations, including all transitively embedded computations. An embedded
    115   // computation (the callee) always appears in the sort before the
    116   // computation which calls it (the caller). Necessarily, the
    117   // entry computation is the last element in the sort. visited and
    118   // post_order should be empty when calling; post_order contains the post-order
    119   // sort when the function returns. The caller must hold computation_mutex_.
    120   void ComputeComputationPostOrder(
    121       const VersionedComputationHandle& versioned_handle,
    122       std::set<VersionedComputationHandle>* visited,
    123       std::list<VersionedComputationHandle>* post_order) const
    124       EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_);
    125 
          // Counterpart of ToString() that requires the caller to already hold
          // computation_mutex_ (presumably what ToString() delegates to).
    126   string ToStringInternal() const EXCLUSIVE_LOCKS_REQUIRED(computation_mutex_);
    127 
    128   // Guards next_computation_ and the computation mapping. Marked mutable so
    129   // that Resolve can remain const: Resolve doesn't modify the tracker in any
    130   // way, but it has to lock the mutex for safety.
    131   mutable tensorflow::mutex computation_mutex_;
    132 
    133   // The next sequence number to assign to a computation. It is guarded by the
    134   // same mutex as the mapping because the two are mutated together.
    135   int64 next_computation_ GUARDED_BY(computation_mutex_);
    136 
    137   // Mapping from ComputationHandle value to the corresponding registered
    138   // UserComputation, which this tracker owns via std::unique_ptr.
    139   std::map<int64, std::unique_ptr<UserComputation>> opaque_to_computation_
    140       GUARDED_BY(computation_mutex_);
    141 
    142   TF_DISALLOW_COPY_AND_ASSIGN(ComputationTracker);
    143 };
    144 
    145 }  // namespace xla
    146 
    147 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_TRACKER_H_
    148