/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMISSION_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMISSION_UTILS_H_

#include <utility>

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"

// TODO(jlebar): Move functions related to cublas/cudnn to a separate file; they
// don't belong in "ir_emission_utils".

namespace xla {
namespace gpu {

// The number of threads per warp on NVIDIA GPUs.
constexpr int64 kWarpSize = 32;

// Returns true if `hlo` will be implemented as a call to BLAS gemm.
bool ImplementedAsGemm(const HloInstruction& hlo);

// A call to cuDNN for batch normalization is represented as a CustomCall HLO
// with a call target equal to one of these strings.
//
// The operands to and outputs of these calls are the same as those of the
// corresponding HLOs, except:
//
//  - epsilon and feature_index are proper operands, at the end of the operands
//    list.  They must be HLO constants.
//  - The cuDNN forward training call returns inv_stddev =
//    1/sqrt(variance + epsilon) in place of plain variance.
//  - Similarly, BatchNormGrad accepts inv_stddev in place of the variance
//    operand.
extern const char* const kCudnnBatchNormForwardInferenceCallTarget;
extern const char* const kCudnnBatchNormForwardTrainingCallTarget;
extern const char* const kCudnnBatchNormBackwardCallTarget;

// Returns true if `hlo` will be implemented as a call to a cuDNN batch
// normalization routine.
//
// This returns true if `hlo` is a CustomCall HLO with a call target equal to
// one of the kCudnnBatchNormFoo constants above, but returns *false* for HLOs
// with one of the kBatchNorm opcodes, because these are lowered either to a
// sequence of generic HLOs or to a cuDNN CustomCall.
bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo);
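
// As a rough illustration, a forward-training batch norm CustomCall might be
// built along the following lines.  This is only a sketch: `computation`,
// `operand`, `scale`, `offset`, `tuple_shape`, the constant values, and the
// exact HloInstruction/Literal factory calls are assumptions made up for the
// example, not something this header defines.
//
//   // epsilon and feature_index are appended to the usual BatchNormTraining
//   // operands and must be constants.
//   HloInstruction* epsilon = computation->AddInstruction(
//       HloInstruction::CreateConstant(Literal::CreateR0<float>(1e-3f)));
//   HloInstruction* feature_index = computation->AddInstruction(
//       HloInstruction::CreateConstant(Literal::CreateR0<int64>(3)));
//   // The forward-training call returns (output, batch_mean, inv_stddev);
//   // note inv_stddev rather than variance.
//   HloInstruction* bn = computation->AddInstruction(
//       HloInstruction::CreateCustomCall(
//           tuple_shape, {operand, scale, offset, epsilon, feature_index},
//           kCudnnBatchNormForwardTrainingCallTarget));
//
// IsCustomCallToDnnBatchNorm(*bn) is then true; it is false for a plain
// kBatchNormTraining instruction.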

// A call to cuDNN for convolution (forward, backward filter, or backward input)
// is represented as a CustomCall HLO with a call target equal to one of these
// strings.
//
// These CustomCalls have window() and convolution_dimension_numbers() set like
// regular convolution ops.  They have the same LHS and RHS operands, plus two
// additional constant operands: an int64 operand for the cudnn algorithm and
// a bool operand for whether tensor_ops is enabled. A value of -1 for the cudnn
// algorithm means that the implementation is free to choose the best algorithm
// it can.
//
// These calls output a tuple (conv_result, scratch_memory), where conv_result
// is the actual result of the convolution, and scratch_memory is temporary
// memory used by cudnn.  Callers shouldn't inspect scratch_memory, as its value
// is not well-defined.
//
// CudnnConvolutionRewriter lowers kConvolution HLOs to these custom calls.
// When it does so, it chooses algorithm -1 and 0 bytes of scratch space.  Later
// on in the pipeline, CudnnConvolutionAlgorithmChooser chooses an explicit
// algorithm for each conv and sets the amount of scratch space needed.
//
// (Representing the scratch memory as an output may seem strange at first, but
// it's quite sensible, from a certain point of view.  The scratch buffer is a
// location in memory that the conv can write into, but which it can't legally
// read from, at least until it's written something first.  But that's exactly
// the definition of an output buffer.)
extern const char* const kCudnnConvForwardCallTarget;
extern const char* const kCudnnConvBackwardInputCallTarget;
extern const char* const kCudnnConvBackwardFilterCallTarget;

// Returns true if `hlo` will be implemented as a call to a cuDNN convolution
// routine.
//
// This returns true if `hlo` is a CustomCall HLO with a call target equal to
// one of the kCudnnConvFoo constants above, but returns *false* for HLOs with a
// kConvolution opcode.
bool IsCustomCallToDnnConvolution(const HloInstruction& hlo);

// Creates a CustomCall for a cudnn forward/backward-input/backward-filter conv.
// Note that these CustomCalls return a tuple (conv_result, scratch_memory).  If
// you want just the conv result, you'll need to get-tuple-element the value
// returned by these functions.
//
// The created cudnn call will use the default cudnn algorithm and no scratch
// space.
HloInstruction* CreateCudnnConvForward(
    const Shape& shape, HloInstruction* input, HloInstruction* kernel,
    const Window& window, const ConvolutionDimensionNumbers& dnums);
HloInstruction* CreateCudnnConvBackwardInput(
    const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter,
    const Window& window, const ConvolutionDimensionNumbers& dnums);
HloInstruction* CreateCudnnConvBackwardFilter(
    const Shape& shape, HloInstruction* input, HloInstruction* output,
    const Window& window, const ConvolutionDimensionNumbers& dnums);
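
// For example, a forward convolution and its result might be obtained roughly
// as follows.  This is a sketch only: `computation`, `input`, `kernel`,
// `window`, `dnums`, and `conv_result_shape` are assumptions made up for the
// example, as is the question of whether the returned call still needs to be
// added to `computation`.
//
//   // The custom call's shape is a tuple of (conv_result, scratch_memory);
//   // the freshly rewritten conv starts with 0 bytes of scratch.
//   Shape call_shape = ShapeUtil::MakeTupleShape(
//       {conv_result_shape, ShapeUtil::MakeShape(U8, {0})});
//   HloInstruction* custom_call =
//       CreateCudnnConvForward(call_shape, input, kernel, window, dnums);
//   // Callers that want just the convolution result extract tuple element 0;
//   // element 1 is cudnn scratch memory and shouldn't be inspected.
//   HloInstruction* conv_result = computation->AddInstruction(
//       HloInstruction::CreateGetTupleElement(conv_result_shape, custom_call,
//                                             0));
//
// IsCustomCallToDnnConvolution(*custom_call) returns true for the instruction
// created above, and false for an ordinary kConvolution HLO.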

// Returns true if `hlo` will be implemented as a library call, e.g. cuBLAS gemm
// or cuDNN convolution.
bool ImplementedAsLibraryCall(const HloInstruction& hlo);

// Returns true if `reduce` is a reduction of an array to a vector (or to a
// scalar), i.e. a reduction that the GPU backend can emit using its
// specialized reduction-to-vector code path.
bool IsReductionToVector(const HloInstruction& reduce);

// Emits a call to "vprintf" with the given format and arguments.
llvm::Value* EmitPrintf(tensorflow::StringPiece fmt,
                        tensorflow::gtl::ArraySlice<llvm::Value*> arguments,
                        llvm::IRBuilder<>* builder);
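
// For instance, a debugging emitter might print a couple of i32 values like
// this (a sketch; `thread_id` and `x` are assumed to be previously emitted
// llvm::Value*s and `b` an llvm::IRBuilder<> positioned inside the kernel):
//
//   EmitPrintf("thread %d: x = %d\n", {thread_id, x}, &b);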

// Emits code to shuffle data between threads of a warp. This has the same
// semantics as the PTX "shfl.down" instruction [0] but works for values of any
// size. The last operand of the emitted "shfl" is `kWarpSize - 1`.
//
// [0]
// http://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-shfl
llvm::Value* EmitShuffleDown(llvm::Value* value, llvm::Value* offset,
                             llvm::IRBuilder<>* builder);
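
// A typical use is a warp-level tree reduction, sketched below.  The names
// `partial` (an f32 llvm::Value* holding this thread's partial result) and
// `b` (an llvm::IRBuilder<>) are assumptions made up for the example.
//
//   // After the loop, lane 0 of each warp holds the sum of the warp's
//   // `partial` values.
//   for (int distance = kWarpSize / 2; distance > 0; distance /= 2) {
//     partial = b.CreateFAdd(
//         partial, EmitShuffleDown(partial, b.getInt32(distance), &b));
//   }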

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMISSION_UTILS_H_