Home | History | Annotate | Download | only in tf2xla
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_TF2XLA_GRAPH_COMPILER_H_
     17 #define TENSORFLOW_COMPILER_TF2XLA_GRAPH_COMPILER_H_
     18 
     19 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
     20 #include "tensorflow/compiler/tf2xla/xla_context.h"
     21 #include "tensorflow/compiler/xla/client/local_client.h"
     22 #include "tensorflow/core/common_runtime/device.h"
     23 #include "tensorflow/core/common_runtime/device_mgr.h"
     24 #include "tensorflow/core/common_runtime/function.h"
     25 #include "tensorflow/core/framework/function.h"
     26 #include "tensorflow/core/framework/op_kernel.h"
     27 #include "tensorflow/core/lib/core/status.h"
     28 #include "tensorflow/core/platform/env.h"
     29 #include "tensorflow/core/platform/mutex.h"
     30 #include "tensorflow/core/platform/notification.h"
     31 #include "tensorflow/core/platform/thread_annotations.h"
     32 #include "tensorflow/core/public/version.h"
     33 
     34 namespace tensorflow {
     35 
     36 // GraphCompiler compiles the graph in topological order in the current
     37 // thread. It also resolves the nondeterminism in the graph by enforcing a
     38 // total order on all inputs to a node. This abstraction helps us create the
     39 // same XLA computation given two structurally equivalent TensorFlow graphs.
     40 // If a function call is visited during the graph traversal, it is then
     41 // compiled through the xla_context into a computation and a `Call` operation
     42 // is inserted to call into that computation.
     43 //
     44 // Note: GraphCompiler was created to remove our dependency to TF Executor in
     45 // the history. There are still some todos so that we can completely decouple
     46 // from Executor.
     47 //
     48 // TODO(yunxing): Remove usage of XlaCompilationDevice.
     49 //
     50 // TODO(yunxing): Remove the hack that wraps XlaExpression within a tensor now
     51 // that we don't use TF Executor to pass around a tensor.
     52 //
     53 // TODO(yunxing): Make XlaOpkernel not a subclass of OpKernel so that it can
     54 // handle a XlaExpression directly instead of a Tensor. This may require our own
     55 // op registration infrastructure instead of FunctionLibraryRuntime.
     56 class GraphCompiler {
     57  public:
     58   GraphCompiler(XlaContext* xla_context, XlaCompilationDevice* device,
     59                 Graph* graph, FunctionLibraryRuntime* flib,
     60                 ScopedStepContainer* step_container)
     61       : xla_context_(xla_context),
     62         device_(device),
     63         graph_(graph),
     64         flib_(flib),
     65         step_container_(step_container) {}
     66 
     67   // Compiles the graph. The results are written in `xla_context` that is passed
     68   // into the compiler.
     69   Status Compile();
     70 
     71  private:
     72   // Partially sets params. This partially set params can be reused
     73   // across multiple nodes visit.
     74   void PartiallySetupParams(OpKernelContext::Params* params);
     75 
     76   // Tests if a node is a functional node. A functional node represents a
     77   // defined computation and should be compiled using `compiler_`.
     78   bool IsFunctional(Node* n);
     79 
     80   // Compiles a functional node and writes result to OpkernelContext. A
     81   // functional node represents a defined computation and should be compiled
     82   // using `compiler_`.
     83   Status CompileFunctionalNode(Node* n, OpKernelContext* op_context);
     84 
     85   XlaContext* xla_context_;
     86   XlaCompilationDevice* device_;
     87   Graph* graph_;
     88   FunctionLibraryRuntime* flib_;
     89   ScopedStepContainer* step_container_;
     90   // A buffer to hold tensor inputs to a node, this is reused across the graph
     91   // traversal.
     92   gtl::InlinedVector<TensorValue, 4> tensor_inputs_;
     93 };
     94 
     95 }  // namespace tensorflow
     96 
     97 #endif  // TENSORFLOW_COMPILER_TF2XLA_GRAPH_COMPILER_H_
     98