     16 #include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
     18 #include <algorithm>
     19 #include <list>
     20 #include <map>
     21 #include <memory>
     22 #include <set>
     23 #include <unordered_map>
     24 #include <utility>
     25 #include <vector>
     27 #include "tensorflow/core/framework/node_def_builder.h"
     28 #include "tensorflow/core/framework/tensor_shape.pb.h"  // NOLINT
     29 #include "tensorflow/core/framework/types.h"
     30 #include "tensorflow/core/graph/algorithm.h"
     31 #include "tensorflow/core/graph/graph.h"
     32 #include "tensorflow/core/graph/graph_constructor.h"
     33 #include "tensorflow/core/lib/core/errors.h"
     34 #include "tensorflow/core/lib/core/status.h"
     35 #include "tensorflow/core/lib/strings/strcat.h"
     36 #include "tensorflow/core/platform/logging.h"
     37 #include "tensorflow/core/platform/tensor_coding.h"
     38 #include "tensorflow/core/platform/types.h"
     40 #if GOOGLE_CUDA
     42 #include "tensorflow/contrib/tensorrt/log/trt_logger.h"
     43 #include "tensorrt/include/NvInfer.h"
     45 //  Check if the types are equal. Cast to int first so that failure log message
     46 //  would work!
     47 #define CHECK_EQ_TYPE(val1, val2) CHECK_EQ((int)val1, (int)val2)
     49 namespace tensorflow {
     50 namespace tensorrt {
     51 namespace convert {
     53 namespace {
     55 inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype,
     56                                        nvinfer1::DataType* trt_dtype) {
     57   switch (tf_dtype) {
     58     case tensorflow::DataType::DT_FLOAT:
     59       *trt_dtype = nvinfer1::DataType::kFLOAT;
     60       break;
     61     case tensorflow::DataType::DT_INT8:
     62       *trt_dtype = nvinfer1::DataType::kINT8;
     63       break;
     64     case tensorflow::DataType::DT_HALF:
     65       *trt_dtype = nvinfer1::DataType::kHALF;
     66       break;
     67     default:
     68       return tensorflow::errors::InvalidArgument("Unsupported data type");
     69   }
     70   return tensorflow::Status::OK();
     71 }
     73 inline nvinfer1::Dims GetTensorShape(const tensorflow::Tensor& tensor) {
     74   nvinfer1::Dims dims;
     75   dims.nbDims = tensor.dims();
     76   for (int i = 0; i < dims.nbDims; i++) {
     77     dims.d[i] = tensor.dim_size(i);
     78   }
     79   return dims;
     80 }
     82 inline int64_t GetShapeSize(nvinfer1::Dims shape) {
     83   // Returns total number of elements in shape
     84   int64_t count = 1;
     85   for (int d = 0; d < shape.nbDims; ++d) {
     86     count *= shape.d[d];
     87   }
     88   return count;
     89 }
     91 static std::vector<std::pair<int, int>> CreateSamePadding(
     92     const nvinfer1::DimsHW& stride, const nvinfer1::DimsHW& kernel,
     93     const std::vector<int64_t>& input_dims) {
     94   std::vector<std::pair<int, int>> padding(input_dims.size());
     95   CHECK_EQ((size_t)stride.nbDims, input_dims.size());  // TODO(jie): N+C? NC+?
     97   for (size_t i = 0; i < input_dims.size(); ++i) {
     98     // Formula to calculate the padding
     99     int p = ((input_dims[i] - 1) / stride.d[i]) * stride.d[i] + kernel.d[i] -
    100             input_dims[i];
    101     p = (p > 0) ? p : 0;
    103     // Right precedence padding, like in TensorFlow
    104     int left = p / 2;
    105     int right = p - left;
    107     VLOG(2) << "PADDING_" << i << " pre: " << left << ", post: " << right
    108             << "paras: " << input_dims[i] << ", " << stride.d[i] << ", "
    109             << "kernel: " << kernel.d[i];
    110     padding[i] = {left, right};
    111   }
    112   return padding;
    113 }
    115 class TRT_ShapedWeights {
    116  public:
    117   TRT_ShapedWeights(tensorflow::DataType type, const void* values,
    118                     nvinfer1::Dims shape)
    119       : shape_(shape), type_(type), values_(values), empty_weight_flag_(false) {
    120     // Note: this->shape.type[] is not used
    121   }
    123   explicit TRT_ShapedWeights(tensorflow::DataType type)
    124       : shape_(), type_(type), values_(nullptr), empty_weight_flag_(true) {}
    126   TRT_ShapedWeights(const TRT_ShapedWeights& rhs)
    127       : shape_(rhs.shape_),
    128         type_(rhs.type_),
    129         values_(rhs.values_),
    130         empty_weight_flag_(rhs.empty_weight_flag_) {}
    132   int64_t count() const {
    133     int64_t c = 1;
    134     for (int i = 0; i < shape_.nbDims; i++) c *= shape_.d[i];
    135     return c;
    136   }
    138   nvinfer1::Weights GetWeightsForTRT() const {
    139     nvinfer1::DataType trt_type(nvinfer1::DataType::kFLOAT);
    140     TF_CHECK_OK(ConvertDType(type_, &trt_type));
    141     if (empty_weight_flag_) return nvinfer1::Weights{trt_type, nullptr, 0};
    143     // Note: this->shape.type[] is not used
    144     return nvinfer1::Weights{trt_type, GetValues(), GetShapeSize(shape_)};
    145   }
    147   const void* GetValues() const { return values_; }
    149   void SetValues(const void* values) { values_ = values; }
    151   size_t size_bytes() const {
    152     int type_size = tensorflow::DataTypeSize(this->type_);
    153     return this->count() * type_size;
    154   }
    156   // Default converter
    157   operator nvinfer1::Weights() const { return GetWeightsForTRT(); }
    159   nvinfer1::Dims shape_;
    160   tensorflow::DataType type_;
    162  private:
    163   const void* values_;
    164   bool empty_weight_flag_;
    165 };
    167 class TRT_TensorOrWeights {
    168  public:
    169   explicit TRT_TensorOrWeights(nvinfer1::ITensor* tensor)
    170       : tensor_(tensor), weights_(DT_FLOAT), variant_(TRT_NODE_TENSOR) {}
    171   explicit TRT_TensorOrWeights(const TRT_ShapedWeights& weights)
    172       : tensor_(nullptr), weights_(weights), variant_(TRT_NODE_WEIGHTS) {}
    173   TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs)
    174       : tensor_(rhs.tensor_), weights_(rhs.weights_), variant_(rhs.variant_) {}
    175   ~TRT_TensorOrWeights() {}
    177   bool is_tensor() const { return variant_ == TRT_NODE_TENSOR; }
    178   bool is_weights() const { return variant_ == TRT_NODE_WEIGHTS; }
    180   nvinfer1::ITensor* tensor() {
    181     CHECK_EQ(is_tensor(), true);
    182     return tensor_;
    183   }
    184   const nvinfer1::ITensor* tensor() const {
    185     CHECK_EQ(is_tensor(), true);
    186     return tensor_;
    187   }
    188   TRT_ShapedWeights& weights() {
    189     CHECK_EQ(is_weights(), true);
    190     return weights_;
    191   }
    192   const TRT_ShapedWeights& weights() const {
    193     CHECK_EQ(is_weights(), true);
    194     return weights_;
    195   }
    196   nvinfer1::Dims shape() const {
    197     if (is_tensor()) {
    198       return tensor()->getDimensions();
    199     } else {
    200       return weights().shape_;
    201     }
    202   }
    204  private:
    205   nvinfer1::ITensor* tensor_;
    206   TRT_ShapedWeights weights_;
    207   enum { TRT_NODE_TENSOR, TRT_NODE_WEIGHTS } variant_;
    208 };
    210 class TFAttrs {
    211  public:
    212   explicit TFAttrs(const tensorflow::NodeDef& tf_node) {
    213     for (const auto& attr : tf_node.attr()) {
    214       attrs_.insert({attr.first, &attr.second});
    215     }
    216   }
    217   bool count(string key) const { return attrs_.count(key); }
    218   tensorflow::AttrValue const* at(string key) const {
    219     if (!attrs_.count(key)) {
    220       LOG(FATAL) << "Attribute not found: " << key;
    221     }
    222     return attrs_.at(key);
    223   }
    224   template <typename T>
    225   T get(string key) const;
    226   template <typename T>
    227   T get(string key, const T& default_value) const {
    228     return attrs_.count(key) ? this->get<T>(key) : default_value;
    229   }
    231  private:
    232   typedef std::map<string, tensorflow::AttrValue const*> AttrMap;
    233   AttrMap attrs_;
    234 };
    236 template <>
    237 string TFAttrs::get<string>(string key) const {
    238   return this->at(key)->s();
    239 }
    241 template <>
    242 std::vector<int> TFAttrs::get<std::vector<int>>(string key) const {
    243   auto attr = this->at(key)->list().i();
    244   return std::vector<int>(attr.begin(), attr.end());
    245 }
    247 template <>
    248 nvinfer1::Dims TFAttrs::get<nvinfer1::Dims>(string key) const {
    249   auto values = this->get<std::vector<int>>(key);
    250   nvinfer1::Dims dims;
    251   dims.nbDims = values.size();
    252   std::copy(values.begin(), values.end(), dims.d);
    253   // Note: No dimension type information is included
    254   return dims;
    255 }
    257 template <>
    258 nvinfer1::DataType TFAttrs::get<nvinfer1::DataType>(string key) const {
    259   nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT);
    260   TF_CHECK_OK(ConvertDType(this->at(key)->type(), &trt_dtype));
    261   return trt_dtype;
    262 }
    264 template <>
    265 tensorflow::DataType TFAttrs::get<tensorflow::DataType>(string key) const {
    266   return this->at(key)->type();
    267 }
    269 template <typename T>
    270 void Reorder4(nvinfer1::DimsNCHW shape, const T* idata,
    271               nvinfer1::DimsNCHW istrides, T* odata,
    272               nvinfer1::DimsNCHW ostrides) {
    273   for (int n = 0; n < shape.n(); ++n) {
    274     for (int c = 0; c < shape.c(); ++c) {
    275       for (int h = 0; h < shape.h(); ++h) {
    276         for (int w = 0; w < shape.w(); ++w) {
    277           odata[n * ostrides.n() + c * ostrides.c() + h * ostrides.h() +
    278                 w * ostrides.w()] = idata[n * istrides.n() + c * istrides.c() +
    279                                           h * istrides.h() + w * istrides.w()];
    280         }
    281       }
    282     }
    283   }
    284 }
    286 void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights,
    287                        TRT_ShapedWeights* oweights) {
    288   CHECK_EQ(iweights.type_, oweights->type_);
    289   CHECK_EQ(iweights.size_bytes(), oweights->size_bytes());
    290   int r = iweights.shape_.d[0];
    291   int s = iweights.shape_.d[1];
    292   int c = iweights.shape_.d[2];
    293   int k = iweights.shape_.d[3];
    294   oweights->shape_.d[0] = k;
    295   oweights->shape_.d[1] = c;
    296   oweights->shape_.d[2] = r;
    297   oweights->shape_.d[3] = s;
    298   nvinfer1::DimsNCHW istrides = {1, k, s * k * c, c * k};
    299   nvinfer1::DimsNCHW ostrides = {c * r * s, r * s, s, 1};
    300   switch (iweights.type_) {
    301     case tensorflow::DataType::DT_FLOAT:
    302       Reorder4({k, c, r, s}, static_cast<float const*>(iweights.GetValues()),
    303                istrides,
    304                static_cast<float*>(const_cast<void*>(oweights->GetValues())),
    305                ostrides);
    306       break;
    307     default:
    308       LOG(FATAL) << "!!!!!!!!!!!!!!!!!!!!!!!!broke!!!!!!!!!!!!";
    309   }
    310 }
    312 struct InferDeleter {
    313   template <typename T>
    314   void operator()(T* obj) const {
    315     if (obj) {
    316       obj->destroy();
    317     }
    318   }
    319 };
    321 template <typename T>
    322 inline std::shared_ptr<T> infer_object(T* obj) {
    323   return std::shared_ptr<T>(obj, InferDeleter());
    324 }
// Forward declaration: OpConverter (below) takes a Converter& argument.
    327 class Converter;
// Signature shared by all per-op converters. A converter receives the
// Converter, the NodeDef being converted and the already-converted values of
// its inputs (in NodeDef input order). It appends the node's outputs to the
// output vector and returns a Status. Converter::convert_node names output 0
// after the node and output i > 0 as "<node>:<i>".
using OpConverter =
    std::function<tensorflow::Status(Converter&, const tensorflow::NodeDef&,
                                     std::vector<TRT_TensorOrWeights> const&,
                                     std::vector<TRT_TensorOrWeights>*)>;
    334 class Converter {
    335   std::unordered_map<string, TRT_TensorOrWeights> trt_tensors_;
    336   std::unordered_map<string, OpConverter> op_registry_;
    337   nvinfer1::INetworkDefinition* trt_network_;
    338   std::list<std::vector<uint8_t>> temp_bufs_;
    340   void register_op_converters();
    342   std::vector<TRT_TensorOrWeights> get_inputs(
    343       const tensorflow::NodeDef& node_def) {
    344     std::vector<TRT_TensorOrWeights> inputs;
    345     for (const auto& input_name : node_def.input()) {
    346       VLOG(2) << "Retrieve input: " << input_name;
    347       inputs.push_back(trt_tensors_.at(input_name));
    348     }
    349     return inputs;
    350   }
    352  public:
    353   explicit Converter(nvinfer1::INetworkDefinition* trt_network)
    354       : trt_network_(trt_network) {
    355     this->register_op_converters();
    356   }
    358   TRT_ShapedWeights get_temp_weights(tensorflow::DataType type,
    359                                      nvinfer1::Dims shape) {
    360     TRT_ShapedWeights weights(type, nullptr, shape);
    361     // TODO(jie): check weights size_bytes. 0 means type error
    362     temp_bufs_.push_back(std::vector<uint8_t>(weights.size_bytes()));
    363     weights.SetValues(temp_bufs_.back().data());
    364     return weights;
    365   }
    367   TRT_ShapedWeights get_temp_weights_like(const TRT_ShapedWeights& weights) {
    368     return this->get_temp_weights(weights.type_, weights.shape_);
    369   }
    371   tensorflow::Status convert_node(const tensorflow::NodeDef& node_def) {
    372     std::vector<TRT_TensorOrWeights> inputs = this->get_inputs(node_def);
    373     string op = node_def.op();
    374     if (!op_registry_.count(op)) {
    375       return tensorflow::errors::Unimplemented(
    376           "No converter registered for op: " + op);
    377     }
    378     OpConverter op_converter = op_registry_.at(op);
    379     std::vector<TRT_TensorOrWeights> outputs;
    380     TF_RETURN_IF_ERROR(op_converter(*this, node_def, inputs, &outputs));
    381     for (size_t i = 0; i < outputs.size(); ++i) {
    382       TRT_TensorOrWeights output = outputs.at(i);
    383       // TODO(jie): tf protobuf seems to be omitting the :0 suffix
    384       string output_name = node_def.name();
    385       if (i != 0) output_name = output_name + ":" + std::to_string(i);
    386       if (output.is_tensor()) {
    387         output.tensor()->setName(output_name.c_str());
    388       }
    389       VLOG(2) << "Write out tensor: " << output_name;
    390       if (!trt_tensors_.insert({output_name, output}).second) {
    391         return tensorflow::errors::AlreadyExists(
    392             "Output tensor already exists for op: " + op);
    393       }
    394     }
    395     return tensorflow::Status::OK();
    396   }
    398   nvinfer1::INetworkDefinition* network() { return trt_network_; }
    400   TRT_TensorOrWeights get_tensor(string name) {
    401     if (!trt_tensors_.count(name)) {
    402       return TRT_TensorOrWeights(nullptr);
    403     }
    404     return trt_tensors_.at(name);
    405   }
    407   bool insert_input_tensor(string name, nvinfer1::ITensor* tensor) {
    408     return trt_tensors_.insert({name, TRT_TensorOrWeights(tensor)}).second;
    409   }
    411   nvinfer1::ITensor* TransposeTensor(nvinfer1::ITensor* input_tensor,
    412                                      std::vector<int> order) {
    413     auto dims = input_tensor->getDimensions();
    415     // TODO(jie): change the return to status and properly exit
    416     if (order.size() - 1 != size_t(dims.nbDims))
    417       LOG(ERROR) << "Dimension does not match, fail gracefully";
    419     nvinfer1::IShuffleLayer* layer = this->network()->addShuffle(*input_tensor);
    420     nvinfer1::Permutation permutation;
    421     for (int32_t i = 0; i < dims.nbDims; ++i) {
    422       permutation.order[i] = order[i + 1] - 1;
    423     }
    424     layer->setFirstTranspose(permutation);
    426     nvinfer1::Dims reshape_dims;
    427     reshape_dims.nbDims = dims.nbDims;
    428     for (int32_t i = 0; i < reshape_dims.nbDims; ++i) {
    429       reshape_dims.d[i] = 0;
    430       reshape_dims.type[i] = dims.type[i];
    431     }
    432     layer->setReshapeDimensions(reshape_dims);
    433     return layer->getOutput(0);
    434   }
    435 };
    437 // ****************************************************************************
    438 // Constant folding functions
    439 // TODO(jie): once optimizer kicks in, we should have done constant folding
    440 // there.
    441 //*****************************************************************************/
    442 struct LambdaFactory {
    443   enum class OP_CATEGORY : int { RSQRT = 0, NEG, ADD, MUL, SUB };
    444   OP_CATEGORY op;
    446   template <typename T>
    447   std::function<T(T)> unary() {
    448     switch (op) {
    449       case OP_CATEGORY::RSQRT: {
    450         VLOG(2) << "RSQRT GETS DONE";
    451         return [](T t) -> T { return 1.0 / std::sqrt(t); };
    452       }
    453       case OP_CATEGORY::NEG:
    454         return [](T t) -> T { return -t; };
    455       default:
    456         VLOG(2) << "Not supported op for unary: " << static_cast<int>(op);
    457         return nullptr;
    458     }
    459   }
    461   template <typename T>
    462   std::function<T(T, T)> binary() {
    463     switch (op) {
    464       case OP_CATEGORY::ADD:
    465         return [](T l, T r) -> T { return l + r; };
    466       case OP_CATEGORY::SUB:
    467         return [](T l, T r) -> T { return l - r; };
    468       case OP_CATEGORY::MUL:
    469         return [](T l, T r) -> T { return l * r; };
    470       default:
    471         LOG(WARNING) << "Not supported op for binary: " << static_cast<int>(op);
    472     }
    473     return [](T l, T r) -> T {
    474       LOG(FATAL) << "Unsupported op type ";
    475       return l;
    476     };
    477   }
    479   template <typename T>
    480   std::function<T(T)> broadcast_r(T val) {
    481     VLOG(2) << "LAMBDA VAL : " << val;
    482     switch (op) {
    483       case OP_CATEGORY::ADD:
    484         return [val](T l) -> T {
    485           VLOG(2) << "LAMBDA VAL : " << val;
    486           return l + val;
    487         };
    488       // Return [val](T l)-> T {return l+val;};
    489       case OP_CATEGORY::SUB:
    490         return [val](T l) -> T {
    491           VLOG(2) << "LAMBDA VAL : " << val;
    492           return l - val;
    493         };
    494       case OP_CATEGORY::MUL:
    495         return [val](T l) -> T {
    496           VLOG(2) << "LAMBDA VAL : " << val;
    497           return l * val;
    498         };
    499       default:
    500         LOG(WARNING) << "Not supported op for binary: " << static_cast<int>(op);
    501     }
    502     return [val](T l) -> T {
    503       LOG(FATAL) << "Unsupported op type ";
    504       return l;
    505     };
    506   }
    508   template <typename T>
    509   std::function<T(T)> broadcast_l(T val) {
    510     VLOG(2) << "LAMBDA VAL : " << val;
    511     switch (op) {
    512       case OP_CATEGORY::ADD:
    513         return [val](T l) -> T {
    514           VLOG(2) << "LAMBDA VAL : " << val;
    515           return val + l;
    516         };
    517       case OP_CATEGORY::SUB:
    518         return [val](T l) -> T {
    519           VLOG(2) << "LAMBDA VAL : " << val;
    520           return val - l;
    521         };
    522       case OP_CATEGORY::MUL:
    523         return [val](T l) -> T {
    524           VLOG(2) << "LAMBDA VAL : " << val;
    525           return val * l;
    526         };
    527       default:
    528         LOG(ERROR) << "Not supported op for binary: " << static_cast<int>(op);
    529     }
    530     return [val](T l) -> T {
    531       LOG(FATAL) << "Unsupported op type ";
    532       return l;
    533     };
    534   }
    535 };
    537 tensorflow::Status UnaryCompute(const TRT_ShapedWeights& iweights,
    538                                 TRT_ShapedWeights* oweights,
    539                                 LambdaFactory unary_op) {
    540   CHECK_EQ(iweights.type_, oweights->type_);
    541   switch (iweights.type_) {
    542     case tensorflow::DataType::DT_FLOAT: {
    543       auto inp = static_cast<float const*>(iweights.GetValues());
    544       auto oup = static_cast<float*>(const_cast<void*>(oweights->GetValues()));
    545       std::transform(inp, inp + iweights.count(), oup, unary_op.unary<float>());
    546       break;
    547     }
    548     default:
    549       return tensorflow::errors::Unimplemented(
    550           "Data type not supported: " +
    551           tensorflow::DataTypeString(iweights.type_));
    552   }
    553   return tensorflow::Status::OK();
    554 }
    556 tensorflow::Status BinaryCompute(const TRT_ShapedWeights& iweights_l,
    557                                  const TRT_ShapedWeights& iweights_r,
    558                                  TRT_ShapedWeights* oweights,
    559                                  LambdaFactory binary_op) {
    560   // Assume iweights_l.type == iweight_r.type
    561   CHECK_EQ(iweights_l.type_, oweights->type_);
    562   CHECK_EQ(iweights_r.type_, oweights->type_);
    563   VLOG(2) << "SANITY CHECK!";
    565   switch (iweights_l.type_) {
    566     case tensorflow::DataType::DT_FLOAT: {
    567       auto inp_l = static_cast<const float*>(iweights_l.GetValues());
    568       auto inp_r = static_cast<const float*>(iweights_r.GetValues());
    569       auto oup = static_cast<float*>(const_cast<void*>(oweights->GetValues()));
    571       if (iweights_l.count() != iweights_r.count()) {
    572         // We only supports broadcast of RankZero
    573         if (iweights_l.count() == 1) {
    574           VLOG(2) << "I bet it is not working!" << (*inp_l);
    575           std::transform(inp_r, inp_r + iweights_r.count(), oup,
    576                          binary_op.broadcast_l<float>(*inp_l));
    577         } else if (iweights_r.count() == 1) {
    578           VLOG(2) << "I bet it is not working!" << (*inp_r);
    579           std::transform(inp_l, inp_l + iweights_l.count(), oup,
    580                          binary_op.broadcast_r<float>(*inp_r));
    581         } else {
    582           return tensorflow::errors::Unimplemented(
    583               "Binary op with non-rankZero broadcast not supported");
    584         }
    585       } else {
    586         std::transform(inp_l, inp_l + iweights_l.count(), inp_r, oup,
    587                        binary_op.binary<float>());
    588       }
    589       break;
    590     }
    591     default:
    592       return tensorflow::errors::Unimplemented(
    593           "Data type not supported: " +
    594           tensorflow::DataTypeString(iweights_l.type_));
    595   }
    597   return tensorflow::Status::OK();
    598 }
    600 tensorflow::Status ConstantFoldUnary(
    601     Converter& ctx, const tensorflow::NodeDef& node_def,
    602     std::vector<TRT_TensorOrWeights> const& inputs,
    603     std::vector<TRT_TensorOrWeights>* outputs) {
    604   TRT_ShapedWeights weights_input = inputs.at(0).weights();
    606   // Allocate output weights
    607   TRT_ShapedWeights weights_output = ctx.get_temp_weights_like(weights_input);
    609   // FIXME assume type matches input weights
    610   // Get trt type & shape
    611   // Maybe this part has to be moved into the block of rsqrt later
    612   // Check type consistency
    613   CHECK_EQ(weights_input.type_,
    614            TFAttrs(node_def).get<tensorflow::DataType>("T"));
    616   // Maybe I should do a switch
    617   LambdaFactory unary_op;
    618   if (node_def.op() == "Rsqrt") {
    619     // Compute rsqrt
    620     unary_op.op = LambdaFactory::OP_CATEGORY::RSQRT;
    621     auto ret = UnaryCompute(weights_input, &weights_output, unary_op);
    622     // PAss the output
    623     if (ret == tensorflow::Status::OK()) {
    624       outputs->push_back(TRT_TensorOrWeights(weights_output));
    625     }
    626     return ret;
    627   } else {
    628     return tensorflow::errors::Unimplemented("Binary op not supported: " +
    629                                              node_def.op());
    630   }
    631 }
// TODO(jie,ben): broadcast is needed but not yet implemented.
// Let's get the simple stuff working first. Maybe we should fall back to TF's
// approach for constant folding.
    636 tensorflow::Status ConstantFoldBinary(
    637     Converter& ctx, const tensorflow::NodeDef& node_def,
    638     std::vector<TRT_TensorOrWeights> const& inputs,
    639     std::vector<TRT_TensorOrWeights>* outputs) {
    640   TRT_ShapedWeights weights_input_l = inputs.at(0).weights();
    641   TRT_ShapedWeights weights_input_r = inputs.at(1).weights();
    643   // Check type consistency
    644   CHECK_EQ(weights_input_l.type_, weights_input_r.type_);
    646   if (weights_input_l.shape_.nbDims != weights_input_r.shape_.nbDims)
    647     return tensorflow::errors::Unimplemented(
    648         "Binary op implicit broadcast not supported: " + node_def.op());
    650   // TODO(jie): constant fold should really fall back to TF.
    651   int nb_dims = weights_input_l.shape_.nbDims;
    652   nvinfer1::Dims output_shape;
    653   output_shape.nbDims = nb_dims;
    654   VLOG(2) << "nb_dims: " << nb_dims
    655           << ", the other: " << weights_input_r.shape_.nbDims;
    656   for (int i = 0; i < nb_dims; i++) {
    657     if (weights_input_l.shape_.d[i] == weights_input_r.shape_.d[i]) {
    658       output_shape.d[i] = weights_input_l.shape_.d[i];
    659     } else if (weights_input_l.shape_.d[i] == 1 ||
    660                weights_input_r.shape_.d[i] == 1) {
    661       output_shape.d[i] =
    662           std::max(weights_input_l.shape_.d[i], weights_input_r.shape_.d[i]);
    663     } else {
    664       return tensorflow::errors::Unimplemented(
    665           "Binary op with incompatible shape at, " + node_def.op());
    666     }
    667     VLOG(2) << "left: " << weights_input_l.shape_.d[i]
    668             << "right: " << weights_input_r.shape_.d[i]
    669             << "output: " << output_shape.d[i];
    670   }
    672   // FIXME assume type matches input weights
    673   // Get trt type & shape
    674   TFAttrs attrs(node_def);
    675   // Maybe this part has to be moved into the block of rsqrt later
    676   tensorflow::DataType dtype = attrs.get<tensorflow::DataType>("T");
    678   // Allocate output weights
    679   TRT_ShapedWeights weights_output = ctx.get_temp_weights(dtype, output_shape);
    681   // Maybe I should do a switch
    682   LambdaFactory binary_op;
    683   if (node_def.op() == "Sub") {
    684     binary_op.op = LambdaFactory::OP_CATEGORY::SUB;
    685   } else if (node_def.op() == "Mul") {
    686     binary_op.op = LambdaFactory::OP_CATEGORY::MUL;
    687   } else if (node_def.op() == "Add") {
    688     binary_op.op = LambdaFactory::OP_CATEGORY::ADD;
    689   } else {
    690     return tensorflow::errors::Unimplemented("Binary op not supported: " +
    691                                              node_def.op());
    692   }
    693   auto ret = BinaryCompute(weights_input_l, weights_input_r, &weights_output,
    694                            binary_op);
    696   // Pass the output
    697   if (ret == tensorflow::Status::OK()) {
    698     outputs->push_back(TRT_TensorOrWeights(weights_output));
    699   }
    701   return ret;
    702 }
// TODO(jie): general broadcasting is not implemented yet. Only uniform
// (single value), per-channel and element-wise weights are handled.
    706 tensorflow::Status BinaryTensorOpWeight(
    707     Converter& ctx, const tensorflow::NodeDef& node_def,
    708     const nvinfer1::ITensor* tensor, TRT_ShapedWeights weights,
    709     std::vector<TRT_TensorOrWeights>* outputs) {
    710   // FIXME assume type matches input weights
    711   // Get trt type & shape
    712   // Maybe this part has to be moved into the block of rsqrt later
    714   // Check type consistency
    715   auto dtype = TFAttrs(node_def).get<nvinfer1::DataType>("T");
    716   CHECK_EQ_TYPE(tensor->getType(), dtype);  // Cast to int for error messages
    717   nvinfer1::DataType ttype;
    718   TF_CHECK_OK(ConvertDType(weights.type_, &ttype));
    719   CHECK_EQ_TYPE(ttype, dtype);  // Cast to int for error message
    721   // Check scale mode
    722   auto dims_w = weights.shape_;
    723   auto dims_t = tensor->getDimensions();
    725   // Default to channel-wise
    726   auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
    728   if (weights.count() == 1) {
    729     VLOG(2) << "UNIFORM";
    730     scale_mode = nvinfer1::ScaleMode::kUNIFORM;
    731   } else {
    732     // No broadcasting on Batch dimension;
    733     assert(dims_w.d[0] == 1);
    735     // Broadcasting on Channel dimension only allowed in kUNIFORM
    736     assert(dims_w.d[1] == dims_t.d[0]);
    737     assert(dims_w.nbDims == dims_t.nbDims);
    739     // Default is element;
    740     for (int i = 2; i < dims_w.nbDims; i++) {
    741       if (dims_w.d[i] != dims_t.d[i - 1]) {
    742         scale_mode = nvinfer1::ScaleMode::kCHANNEL;
    743         break;
    744       }
    745     }
    746     if (scale_mode == nvinfer1::ScaleMode::kELEMENTWISE) {
    747       scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
    748       for (int i = 2; i < dims_w.nbDims; i++) {
    749         if (dims_w.d[i] != 1)
    750           return tensorflow::errors::InvalidArgument(
    751               "Weight shape not compatible at, " + node_def.name());
    752       }
    753     }
    754   }
    756   // Prepare weights
    757   TRT_ShapedWeights shift_weights(weights.type_);
    758   TRT_ShapedWeights scale_weights(weights.type_);
    759   TRT_ShapedWeights power_weights(weights.type_);
    761   // Maybe I should do a switch
    762   if (node_def.op() == "Sub") {
    763     TRT_ShapedWeights neg_weights = ctx.get_temp_weights_like(weights);
    764     LambdaFactory unary_op;
    765     unary_op.op = LambdaFactory::OP_CATEGORY::NEG;
    766     TF_RETURN_IF_ERROR(UnaryCompute(weights, &neg_weights, unary_op));
    767     shift_weights = neg_weights;
    768   } else if (node_def.op() == "Mul") {
    769     scale_weights = weights;
    770   } else if (node_def.op() == "Add") {
    771     shift_weights = weights;
    772   } else {
    773     return tensorflow::errors::Unimplemented("Binary op not supported: " +
    774                                              node_def.op());
    775   }
    777   nvinfer1::IScaleLayer* layer = ctx.network()->addScale(
    778       *const_cast<nvinfer1::ITensor*>(tensor), scale_mode, shift_weights,
    779       scale_weights, power_weights);
    781   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
    783   // Pass the output
    784   outputs->push_back(TRT_TensorOrWeights(output_tensor));
    785   return tensorflow::Status::OK();
    786 }
    788 tensorflow::Status BinaryTensorOpTensor(
    789     Converter& ctx, const tensorflow::NodeDef& node_def,
    790     const nvinfer1::ITensor* tensor_l, const nvinfer1::ITensor* tensor_r,
    791     std::vector<TRT_TensorOrWeights>* outputs) {
    792   static const std::unordered_map<string, nvinfer1::ElementWiseOperation> ops{
    793       {"Add", nvinfer1::ElementWiseOperation::kSUM},
    794       {"Mul", nvinfer1::ElementWiseOperation::kPROD},
    795       // {"max", nvinfer1::ElementWiseOperation::kMAX},
    796       // {"min", nvinfer1::ElementWiseOperation::kMIN},
    797       {"Sub", nvinfer1::ElementWiseOperation::kSUB},
    798       {"Div", nvinfer1::ElementWiseOperation::kDIV},
    799   };
    801   // FIXME assume type matches input weights
    802   // Get trt type & shape
    803   TFAttrs attrs(node_def);
    804   // Maybe this part has to be moved into the block of rsqrt later
    805   nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("T");
    807   // Check type consistency
    808   CHECK_EQ_TYPE(tensor_l->getType(), dtype);
    809   CHECK_EQ_TYPE(tensor_r->getType(), dtype);
    810   auto op_pair = ops.find(node_def.op());
    811   if (op_pair == ops.end())
    812     return tensorflow::errors::Unimplemented("binary op: " + node_def.op() +
    813                                              " not supported at: " +
    814                                              node_def.name());
    816   nvinfer1::IElementWiseLayer* layer = ctx.network()->addElementWise(
    817       *const_cast<nvinfer1::ITensor*>(tensor_l),
    818       *const_cast<nvinfer1::ITensor*>(tensor_r), op_pair->second);
    820   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
    822   // Pass the output
    823   outputs->push_back(TRT_TensorOrWeights(output_tensor));
    824   return tensorflow::Status::OK();
    825 }
    827 tensorflow::Status ConvertPlaceholder(
    828     Converter& ctx, const tensorflow::NodeDef& node_def,
    829     std::vector<TRT_TensorOrWeights> const& inputs,
    830     std::vector<TRT_TensorOrWeights>* outputs) {
    831   VLOG(2) << "Placeholder should have been replace already";
    832   return tensorflow::errors::Unimplemented(", cannot convert Placeholder op");
    833   // OK this make sense since we are supposed to replace it with input
    834   TFAttrs attrs(node_def);
    835   nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("dtype");
    836   nvinfer1::Dims dims = attrs.get<nvinfer1::Dims>("shape");
    838   dims.nbDims--;
    839   for (int i = 0; i < dims.nbDims; i++) dims.d[i] = dims.d[i + 1];
    841   nvinfer1::ITensor* output =
    842       ctx.network()->addInput(node_def.name().c_str(), dtype, dims);
    843   if (!output) {
    844     return tensorflow::errors::InvalidArgument("Failed to create Input layer");
    845   }
    846   outputs->push_back(TRT_TensorOrWeights(output));
    847   return tensorflow::Status::OK();
    848 }
    850 tensorflow::Status ConvertConv2D(Converter& ctx,
    851                                  const tensorflow::NodeDef& node_def,
    852                                  const std::vector<TRT_TensorOrWeights>& inputs,
    853                                  std::vector<TRT_TensorOrWeights>* outputs) {
    854   nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
    855   // TODO(jie): handle NHWC/NCHW transpose;
    856   TRT_ShapedWeights weights_rsck = inputs.at(1).weights();
    857   TRT_ShapedWeights weights = ctx.get_temp_weights_like(weights_rsck);
    858   ReorderRSCKToKCRS(weights_rsck, &weights);
    859   TRT_ShapedWeights biases(weights.type_);
    860   int noutput = weights.shape_.d[0];
    861   nvinfer1::DimsHW kernel_size;
    862   kernel_size.h() = weights.shape_.d[2];
    863   kernel_size.w() = weights.shape_.d[3];
    864   TFAttrs attrs(node_def);
    866   int h_index = 2;
    867   int w_index = 3;
    868   auto data_format = attrs.get<string>("data_format");
    869   if (data_format == "NHWC") {
    870     tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
    871                                  {0, 3, 1, 2});
    872     h_index = 1;
    873     w_index = 2;
    874     // TODO(jie): transpose it
    875   }
    877   // TODO(jie): stride. (NHWC/NCHW)
    878   auto tf_stride = attrs.get<std::vector<int>>("strides");
    879   nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
    881   auto tensor_dim = tensor->getDimensions();
    882   std::vector<std::pair<int, int>> padding;
    883   // TODO(jie): padding.
    884   if (attrs.get<string>("padding") == "SAME") {
    885     // This is NCHW tensor with no batch dimension.
    886     //  1 -> h
    887     //  2 -> w
    888     padding = CreateSamePadding(
    889         stride, kernel_size,
    890         {static_cast<int>(tensor_dim.d[1]), static_cast<int>(tensor_dim.d[2])});
    891   } else {
    892     padding = {{0, 0}, {0, 0}};
    893   }
    895   if (padding[0].first != padding[0].second ||
    896       padding[1].first != padding[1].second) {
    897     // TODO(jie): handle asymmetric padding
    898     VLOG(2) << "Padding!!!: " << padding[0].first << padding[0].second
    899             << padding[1].first << padding[1].second;
    901     auto dim_before = tensor->getDimensions();
    902     VLOG(2) << "TENSOR before: " << dim_before.d[0] << ", " << dim_before.d[1]
    903             << dim_before.d[2] << ", " << dim_before.d[3];
    904     auto pad_layer = ctx.network()->addPadding(
    905         *const_cast<nvinfer1::ITensor*>(tensor),
    906         nvinfer1::DimsHW(padding[0].first, padding[1].first),
    907         nvinfer1::DimsHW(padding[0].second, padding[1].second));
    908     padding = {{0, 0}, {0, 0}};
    909     tensor = pad_layer->getOutput(0);
    910     auto dim_after = tensor->getDimensions();
    911     VLOG(2) << "TENSOR after: " << dim_after.d[0] << ", " << dim_after.d[1]
    912             << dim_after.d[2] << ", " << dim_after.d[3];
    913   }
    915   nvinfer1::IConvolutionLayer* layer =
    916       ctx.network()->addConvolution(*const_cast<nvinfer1::ITensor*>(tensor),
    917                                     noutput, kernel_size, weights, biases);
    919   layer->setStride(stride);
    920   layer->setPadding({padding[0].first, padding[1].first});
    921   layer->setName(node_def.name().c_str());
    922   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
    924   auto dim_after = output_tensor->getDimensions();
    925   VLOG(2) << "TENSOR out: " << dim_after.d[0] << ", " << dim_after.d[1]
    926           << dim_after.d[2] << ", " << dim_after.d[3];
    928   if (data_format == "NHWC") {
    929     // TODO(jie): transpose it back!
    930     output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1});
    931   } else {
    932     VLOG(2) << "NCHW !!!!";
    933   }
    934   outputs->push_back(TRT_TensorOrWeights(output_tensor));
    935   return tensorflow::Status::OK();
    936 }
    938 tensorflow::Status ConvertPool(Converter& ctx,
    939                                const tensorflow::NodeDef& node_def,
    940                                std::vector<TRT_TensorOrWeights> const& inputs,
    941                                std::vector<TRT_TensorOrWeights>* outputs) {
    942   nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
    943   TFAttrs attrs(node_def);
    945   int h_index = 2;
    946   int w_index = 3;
    947   auto data_format = attrs.get<string>("data_format");
    948   if (data_format == "NHWC") {
    949     h_index = 1;
    950     w_index = 2;
    951     tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
    952                                  {0, 3, 1, 2});
    953   } else {
    954     VLOG(2) << "NCHW !!!!";
    955   }
    956   nvinfer1::PoolingType type;
    957   // TODO(jie): support other pooling type
    958   if (node_def.op() == "MaxPool")
    959     type = nvinfer1::PoolingType::kMAX;
    960   else
    961     return tensorflow::errors::Unimplemented("Only supports Max pool");
    963   // TODO(jie): NCHW
    964   auto tf_stride = attrs.get<std::vector<int>>("strides");
    965   nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
    967   auto tf_kernel = attrs.get<std::vector<int>>("ksize");
    968   nvinfer1::DimsHW ksize(tf_kernel[h_index], tf_kernel[w_index]);
    970   auto tensor_dim = tensor->getDimensions();
    971   std::vector<std::pair<int, int>> padding;
    972   // TODO(jie): padding.
    973   if (attrs.get<string>("padding") == "SAME") {
    974     // This is NCHW tensor with no batch dimension.
    975     //  1 -> h
    976     //  2 -> w
    977     padding = CreateSamePadding(
    978         stride, ksize,
    979         {static_cast<int>(tensor_dim.d[1]), static_cast<int>(tensor_dim.d[2])});
    980   } else if (attrs.get<string>("padding") == "VALID") {
    981     // No padding for valid padding here
    982     VLOG(2) << "No padding added for VALID padding in pool" << node_def.name();
    983     padding = {{0, 0}, {0, 0}};
    984   } else {
    985     return tensorflow::errors::Unimplemented(
    986         "Current MaxPool cannot support padding other than SAME");
    987   }
    989   if (padding[0].first != padding[0].second ||
    990       padding[1].first != padding[1].second) {
    991     // TODO(jie): handle asymmetric padding
    992     VLOG(2) << "Padding!!!: " << padding[0].first << padding[0].second
    993             << padding[1].first << padding[1].second;
    994     auto pad_layer = ctx.network()->addPadding(
    995         *const_cast<nvinfer1::ITensor*>(tensor),
    996         nvinfer1::DimsHW(padding[0].first, padding[1].first),
    997         nvinfer1::DimsHW(padding[0].second, padding[1].second));
    998     padding = {{0, 0}, {0, 0}};
    999     tensor = pad_layer->getOutput(0);
   1000   }
   1002   nvinfer1::IPoolingLayer* layer = ctx.network()->addPooling(
   1003       *const_cast<nvinfer1::ITensor*>(tensor), type, ksize);
   1005   layer->setStride(stride);
   1006   layer->setPadding({padding[0].first, padding[1].first});
   1007   layer->setName(node_def.name().c_str());
   1008   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
   1010   if (data_format == "NHWC") {
   1011     // TODO(jie): transpose it back!
   1012     output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1});
   1013   } else {
   1014     VLOG(2) << "NCHW !!!!";
   1015   }
   1016   outputs->push_back(TRT_TensorOrWeights(output_tensor));
   1017   return tensorflow::Status::OK();
   1018 }
   1020 tensorflow::Status ConvertActivation(
   1021     Converter& ctx, const tensorflow::NodeDef& node_def,
   1022     std::vector<TRT_TensorOrWeights> const& inputs,
   1023     std::vector<TRT_TensorOrWeights>* outputs) {
   1024   nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
   1025   nvinfer1::IActivationLayer* layer = ctx.network()->addActivation(
   1026       *const_cast<nvinfer1::ITensor*>(tensor), nvinfer1::ActivationType::kRELU);
   1027   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
   1028   outputs->push_back(TRT_TensorOrWeights(output_tensor));
   1029   return tensorflow::Status::OK();
   1030 }
   1032 tensorflow::Status ConvertScale(Converter& ctx,
   1033                                 const tensorflow::NodeDef& node_def,
   1034                                 std::vector<TRT_TensorOrWeights> const& inputs,
   1035                                 std::vector<TRT_TensorOrWeights>* outputs) {
   1036   if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
   1037       !inputs.at(1).is_weights())
   1038     return tensorflow::errors::Unimplemented(
   1039         "Only supports tensor op weight for now, at " + node_def.name());
   1040   // Implement tensor binaryOp weight [channel wise] for now;
   1041   nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
   1043   // TODO(jie): handle NHWC/NCHW transpose;
   1044   TRT_ShapedWeights weights = inputs.at(1).weights();
   1045   TRT_ShapedWeights empty_weights(weights.type_);
   1047   TFAttrs attrs(node_def);
   1049   // Transpose NHWC
   1050   auto data_format = attrs.get<string>("data_format");
   1051   if (data_format == "NHWC") {
   1052     tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
   1053                                  {0, 3, 1, 2});
   1054     // TODO(jie): transpose it
   1055   } else {
   1056     VLOG(2) << "NCHW !!!!";
   1057   }
   1058   nvinfer1::IScaleLayer* layer = ctx.network()->addScale(
   1059       *const_cast<nvinfer1::ITensor*>(tensor), nvinfer1::ScaleMode::kCHANNEL,
   1060       weights, empty_weights, empty_weights);
   1062   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
   1063   if (data_format == "NHWC") {
   1064     // TODO(jie): transpose it back!
   1065     output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1});
   1066   } else {
   1067     VLOG(2) << "NCHW !!!!";
   1068   }
   1069   outputs->push_back(TRT_TensorOrWeights(output_tensor));
   1070   return tensorflow::Status::OK();
   1071 }
   1073 tensorflow::Status ConvertConst(Converter& ctx,
   1074                                 const tensorflow::NodeDef& node_def,
   1075                                 std::vector<TRT_TensorOrWeights> const& inputs,
   1076                                 std::vector<TRT_TensorOrWeights>* outputs) {
   1077   const auto& weights_tensor = node_def.attr().at("value").tensor();
   1079   // Get trt type & shape
   1080   TFAttrs attrs(node_def);
   1081   const tensorflow::DataType dtype = attrs.get<tensorflow::DataType>("dtype");
   1083   // Create shaped weights as output
   1084   tensorflow::Tensor tensor;
   1085   if (!tensor.FromProto(weights_tensor))
   1086     return tensorflow::errors::Internal("Cannot parse weight tensor proto: " +
   1087                                         node_def.name());
   1089   TRT_ShapedWeights weights(dtype);
   1090   if (!weights_tensor.float_val().empty()) {
   1091     VLOG(2) << "SCALAR!!!" << node_def.name();
   1092     nvinfer1::Dims scalar_shape;
   1093     if (tensor.dims() > 0) {
   1094       VLOG(2) << "Dimensions: " << tensor.dims();
   1095       weights = TRT_ShapedWeights(dtype, weights_tensor.float_val().data(),
   1096                                   GetTensorShape(tensor));
   1097     } else {
   1098       VLOG(2) << "Dimensions: " << tensor.dims();
   1099       scalar_shape.nbDims = 1;
   1100       scalar_shape.d[0] = 1;
   1101       scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL;
   1102       for (int i = 1; i < nvinfer1::Dims::MAX_DIMS; i++) {
   1103         scalar_shape.d[i] = 0;
   1104         scalar_shape.type[i] = nvinfer1::DimensionType::kSPATIAL;
   1105       }
   1106       weights = TRT_ShapedWeights(dtype, weights_tensor.float_val().data(),
   1107                                   scalar_shape);
   1108     }
   1109   } else if (!weights_tensor.tensor_content().empty()) {
   1110     VLOG(2) << "TENSOR!!!" << node_def.name();
   1111     const auto& content = weights_tensor.tensor_content();
   1113     weights = ctx.get_temp_weights(dtype, GetTensorShape(tensor));
   1114     if (content.size() > 0) {
   1115       const int dtype_size = tensorflow::DataTypeSize(dtype);
   1116       CHECK_EQ(0, content.size() % dtype_size)
   1117           << "Tensor content size (" << content.size()
   1118           << ") is not a multiple of " << dtype_size;
   1119       port::CopyToArray(
   1120           content, static_cast<char*>(const_cast<void*>(weights.GetValues())));
   1121     }
   1122   } else {
   1123     return tensorflow::errors::Unimplemented(
   1124         "Not supported constant type, at " + node_def.name());
   1125   }
   1126   // Pass the output
   1127   outputs->push_back(TRT_TensorOrWeights(weights));
   1128   return tensorflow::Status::OK();
   1129 }
   1131 tensorflow::Status ConvertIdentity(
   1132     Converter& ctx, const tensorflow::NodeDef& node_def,
   1133     std::vector<TRT_TensorOrWeights> const& inputs,
   1134     std::vector<TRT_TensorOrWeights>* outputs) {
   1135   outputs->push_back(inputs.at(0));
   1136   return tensorflow::Status::OK();
   1137 }
   1139 tensorflow::Status ConvertBinary(Converter& ctx,
   1140                                  const tensorflow::NodeDef& node_def,
   1141                                  std::vector<TRT_TensorOrWeights> const& inputs,
   1142                                  std::vector<TRT_TensorOrWeights>* outputs) {
   1143   if (inputs.size() != 2)
   1144     return tensorflow::errors::FailedPrecondition(
   1145         "Binary ops require two tensor input, at " + node_def.name());
   1147   if (inputs.at(0).is_weights() && inputs.at(1).is_weights())
   1148     return ConstantFoldBinary(ctx, node_def, inputs, outputs);
   1150   if (inputs.at(0).is_tensor() && inputs.at(1).is_weights())
   1151     return BinaryTensorOpWeight(ctx, node_def, inputs.at(0).tensor(),
   1152                                 inputs.at(1).weights(), outputs);
   1154   if (inputs.at(0).is_weights() && inputs.at(1).is_tensor())
   1155     return BinaryTensorOpWeight(ctx, node_def, inputs.at(1).tensor(),
   1156                                 inputs.at(0).weights(), outputs);
   1158   if (inputs.at(0).is_tensor() && inputs.at(1).is_tensor())
   1159     return BinaryTensorOpTensor(ctx, node_def, inputs.at(0).tensor(),
   1160                                 inputs.at(1).tensor(), outputs);
   1162   return tensorflow::errors::Unknown("Binary op input error, at " +
   1163                                      node_def.name());
   1164 }
   1166 tensorflow::Status ConvertUnary(Converter& ctx,
   1167                                 const tensorflow::NodeDef& node_def,
   1168                                 std::vector<TRT_TensorOrWeights> const& inputs,
   1169                                 std::vector<TRT_TensorOrWeights>* outputs) {
   1170   if (inputs.size() != 1)
   1171     return tensorflow::errors::FailedPrecondition(
   1172         "Unary ops require single tensor input, at " + node_def.name());
   1174   if (inputs.at(0).is_weights())
   1175     return ConstantFoldUnary(ctx, node_def, inputs, outputs);
   1176   else if (inputs.at(0).is_tensor())
   1177     return tensorflow::errors::Unimplemented(
   1178         "Unary op for tensor not supported, at " + node_def.name());
   1180   return tensorflow::errors::Unknown("Binary op input error, at " +
   1181                                      node_def.name());
   1182 }
   1184 tensorflow::Status ConvertReduce(Converter& ctx,
   1185                                  const tensorflow::NodeDef& node_def,
   1186                                  std::vector<TRT_TensorOrWeights> const& inputs,
   1187                                  std::vector<TRT_TensorOrWeights>* outputs) {
   1188   if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
   1189       !inputs.at(1).is_weights())
   1190     return tensorflow::errors::InvalidArgument(
   1191         "Input expects tensor and weights, at" + node_def.name());
   1193   // Implement tensor binaryOp weight [channel wise] for now;
   1194   nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
   1195   auto dims = tensor->getDimensions();
   1196   // Restore implicit batch dimension
   1197   int nb_dims = dims.nbDims + 1;
   1199   TRT_ShapedWeights index_list = inputs.at(1).weights();
   1201   TFAttrs attrs(node_def);
   1202   // TODO(jie): handle data type.
   1203   // Index type here is done through TF type, so I can leverage their
   1204   // EnumToDataType for my cast
   1205   auto index_type = attrs.get<tensorflow::DataType>("Tidx");
   1207   // Only expect to handle INT32 as attributes for now
   1208   if (index_type != tensorflow::DataType::DT_INT32)
   1209     return tensorflow::errors::Unimplemented("Tidx supports only DT_INT32");
   1210   auto index_list_data =
   1211       static_cast<int*>(const_cast<void*>(index_list.GetValues()));
   1213   // Hack warning: have to fall back to pool layer since reduce is not in public
   1214   // TRT yet.
   1215   if (nb_dims != 4)
   1216     return tensorflow::errors::InvalidArgument(
   1217         "TRT only support reduce on 4 dimensional tensors, at" +
   1218         node_def.name());
   1219   if (index_list.count() > 2)
   1220     return tensorflow::errors::InvalidArgument(
   1221         "TRT cannot support reduce on more than 2 dimensions, at" +
   1222         node_def.name());
   1224   std::set<int> idx_set;
   1225   // We cannot operate on Channel. permutation flag used to transpose tensor
   1226   int permuted_index = -1;
   1227   for (int i = 0; i < index_list.count(); i++) {
   1228     if (index_list_data[i] == 0)
   1229       return tensorflow::errors::InvalidArgument("TRT cannot reduce at 0, at" +
   1230                                                  node_def.name());
   1231     if (index_list_data[i] == 1) permuted_index = 1;
   1232     idx_set.emplace(index_list_data[i]);
   1233   }
   1235   std::vector<int> permutation_order(nb_dims);
   1236   nvinfer1::DimsHW pool_kernel;
   1237   if (permuted_index == 1) {
   1238     for (int i = 2; i < nb_dims; i++) {
   1239       if (idx_set.count(i)) {
   1240         permuted_index = i;
   1241         break;
   1242       }
   1243     }
   1244     for (int i = 0; i < nb_dims; i++) permutation_order[i] = i;
   1246     permutation_order[permuted_index] = 1;
   1247     permutation_order[1] = permuted_index;
   1249     // Apply permutation before extracting dimension for pool_kernel
   1250     tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
   1251                                  permutation_order);
   1252   }
   1254   // Apply permutation before extracting dimension for pool_kernel
   1255   pool_kernel.d[0] = (idx_set.count(2) || permuted_index == 2) ? dims.d[1] : 1;
   1256   pool_kernel.d[1] = (idx_set.count(3) || permuted_index == 3) ? dims.d[2] : 1;
   1258   nvinfer1::ITensor* output_tensor;
   1260   if (node_def.op() == "Mean") {
   1261     nvinfer1::IPoolingLayer* layer =
   1262         ctx.network()->addPooling(*const_cast<nvinfer1::ITensor*>(tensor),
   1263                                   nvinfer1::PoolingType::kAVERAGE, pool_kernel);
   1264     output_tensor = layer->getOutput(0);
   1265   } else {
   1266     return tensorflow::errors::Unimplemented(
   1267         "Op not supported " + node_def.op() + " , at " + node_def.name());
   1268   }
   1269   if (permuted_index != -1) {
   1270     // Apply permutation before extracting dimension for pool_kernel
   1271     output_tensor = ctx.TransposeTensor(
   1272         const_cast<nvinfer1::ITensor*>(output_tensor), permutation_order);
   1273   }
   1274   return tensorflow::Status::OK();
   1275 }
   1277 tensorflow::Status ConvertPad(Converter& ctx,
   1278                               const tensorflow::NodeDef& node_def,
   1279                               std::vector<TRT_TensorOrWeights> const& inputs,
   1280                               std::vector<TRT_TensorOrWeights>* outputs) {
   1281   if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
   1282       !inputs.at(1).is_weights())
   1283     return tensorflow::errors::InvalidArgument(
   1284         "Input expects tensor and weights, at" + node_def.name());
   1286   // Implement tensor binaryOp weight [channel wise] for now;
   1287   nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
   1288   auto dims = tensor->getDimensions();
   1289   // Restore implicit batch dimension
   1290   int nb_dims = dims.nbDims + 1;
   1292   TRT_ShapedWeights pads = inputs.at(1).weights();
   1294   TFAttrs attrs(node_def);
   1295   // Padding type here is done through TF type
   1296   //   so I can leverage their EnumToDataType for my cast
   1297   auto padding_type = attrs.get<tensorflow::DataType>("Tpaddings");
   1298   // TODO(jie): handle data type conversion for TRT?
   1300   if (pads.shape_.d[0] != nb_dims || pads.shape_.d[1] != 2)
   1301     return tensorflow::errors::InvalidArgument(
   1302         "Pad only supports explicit padding on 4 dimensional tensor, at " +
   1303         node_def.name());
   1305   // Only expect to handle INT32 as attributes for now
   1306   if (padding_type != tensorflow::DataType::DT_INT32)
   1307     return tensorflow::errors::Unimplemented(
   1308         "Tpaddings supports only DT_INT32");
   1309   auto pad_data = static_cast<int*>(const_cast<void*>(pads.GetValues()));
   1311   std::vector<int32_t> pad_index;
   1312   for (int i = 0; i < nb_dims; i++) {
   1313     if (pad_data[2 * i] != 0 || pad_data[2 * i + 1] != 0)
   1314       pad_index.push_back(i);
   1315   }
   1317   // No padding at all, we should exit
   1318   if (pad_index.size() == 0) {
   1319     outputs->push_back(inputs.at(0));
   1320     return tensorflow::Status::OK();
   1321   }
   1323   // Only supports padding on less than 2 axis GIE-2579
   1324   if (pad_index.size() > 2)
   1325     return tensorflow::errors::InvalidArgument(
   1326         "Padding layer does not support padding on > 2");
   1328   // Padding on batch dimension is not supported
   1329   if (pad_index[0] == 0)
   1330     return tensorflow::errors::InvalidArgument(
   1331         "Padding layer does not support padding on batch dimension");
   1333   // Not doing the legit thing here. ignoring padding on dim 1 and 3;
   1334   // TODO(jie): implement pad as uff parser
   1335   if (pad_index.size() == 2 && pad_index[0] == 0 && pad_index[1] == 3)
   1336     return tensorflow::errors::Unimplemented(
   1337         "Padding layer does not support padding on dimension 1 and 3 yet");
   1339   bool legit_pad = true;
   1340   nvinfer1::DimsHW pre_padding(0, 0);
   1341   nvinfer1::DimsHW post_padding(0, 0);
   1343   std::vector<int32_t> permuted_pad_index(pad_index);
   1344   if (pad_index[0] == 1) {
   1345     legit_pad = false;
   1346     tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
   1347                                  {0, 3, 2, 1});
   1348     permuted_pad_index[0] = 3;
   1349   }
   1351   for (size_t i = 0; i < pad_index.size(); i++) {
   1352     int index = pad_index[i];
   1353     if (permuted_pad_index[i] == 2) {
   1354       pre_padding.h() = pad_data[index * 2];
   1355       post_padding.h() = pad_data[index * 2 + 1];
   1356     } else if (permuted_pad_index[i] == 3) {
   1357       pre_padding.w() = pad_data[index * 2];
   1358       post_padding.w() = pad_data[index * 2 + 1];
   1359     }
   1360   }
   1362   nvinfer1::IPaddingLayer* layer = ctx.network()->addPadding(
   1363       *const_cast<nvinfer1::ITensor*>(tensor), pre_padding, post_padding);
   1364   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
   1366   if (!legit_pad)
   1367     output_tensor = ctx.TransposeTensor(
   1368         const_cast<nvinfer1::ITensor*>(output_tensor), {0, 3, 2, 1});
   1370   outputs->push_back(TRT_TensorOrWeights(output_tensor));
   1371   return tensorflow::Status::OK();
   1372 }
   1374 void Converter::register_op_converters() {
   1375   // vgg_16 slim implementation
   1376   op_registry_["Placeholder"] = ConvertPlaceholder;
   1377   op_registry_["Conv2D"] = ConvertConv2D;
   1378   op_registry_["Relu"] = ConvertActivation;
   1379   op_registry_["MaxPool"] = ConvertPool;
   1380   // This could be really handled as ConvertBinary
   1381   op_registry_["BiasAdd"] = ConvertScale;
   1382   op_registry_["Const"] = ConvertConst;
   1383   // op_registry_["MatMul"] = ConvertFullyConnected;  // Not used in vgg
   1384   // TODO(ben,jie): this is a temp hack.
   1385   op_registry_["Identity"] = ConvertIdentity;  // Identity should be removed
   1386   // op_registry_["AvgPool"] = ConvertPool;
   1388   // resnet_50_v1 slim implementation
   1389   op_registry_["Add"] = ConvertBinary;
   1390   op_registry_["Mul"] = ConvertBinary;
   1391   op_registry_["Sub"] = ConvertBinary;
   1392   op_registry_["Rsqrt"] = ConvertUnary;
   1393   op_registry_["Mean"] = ConvertReduce;
   1394   op_registry_["Pad"] = ConvertPad;
   1395   // TODO(ben,jie): Add more ops
   1396 }
   1398 }  // namespace
   1400 tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
   1401     const tensorflow::Graph& graph, const std::set<int>& subgraph_node_ids,
   1402     const std::vector<std::pair<int, int>>& input_inds,
   1403     const std::vector<std::pair<int, int>>& output_inds, size_t max_batch_size,
   1404     size_t max_workspace_size_bytes,
   1405     const tensorflow::grappler::GraphProperties& graph_properties,
   1406     tensorflow::NodeDef* trt_node) {
   1407   // Visit nodes in reverse topological order and construct the TRT network.
   1409   // Toposort
   1410   std::vector<tensorflow::Node*> order_vec;
   1411   tensorflow::GetPostOrder(graph, &order_vec);
   1412   // Select just the subgraph
   1413   std::list<tensorflow::Node*> order;
   1414   for (tensorflow::Node* node : order_vec) {
   1415     if (subgraph_node_ids.count(node->id())) {
   1416       // We want topological order to contstruct the
   1417       // network layer by layer
   1418       order.push_front(node);
   1419     }
   1420   }
   1421   // Topological order is needed to build TRT network
   1423   tensorflow::tensorrt::Logger trt_logger;
   1425   auto trt_builder = infer_object(nvinfer1::createInferBuilder(trt_logger));
   1426   if (!trt_builder) {
   1427     return tensorflow::errors::Internal(
   1428         "Failed to create TensorRT builder object");
   1429   }
   1431   auto trt_network = infer_object(trt_builder->createNetwork());
   1432   if (!trt_network) {
   1433     return tensorflow::errors::Internal(
   1434         "Failed to create TensorRT network object");
   1435   }
   1437   // Build the network
   1438   Converter converter(trt_network.get());
   1440   std::vector<string> input_names;
   1441   std::vector<tensorflow::DataType> input_dtypes;
   1442   for (std::pair<int, int> const& input : input_inds) {
   1443     int node_id = input.first;
   1444     int output_idx = input.second;
   1445     tensorflow::Node* node = graph.FindNodeId(node_id);
   1446     auto node_name = node->name();
   1447     input_names.push_back(node_name);  // Insert original node name without port
   1448     // TODO(jie): alternative :)
   1449     if (!graph_properties.HasOutputProperties(node_name))
   1450       return tensorflow::errors::Internal("Failed to find input node: " +
   1451                                           node_name);
   1453     auto op_info_vec = graph_properties.GetOutputProperties(node_name);
   1454     if (static_cast<int>(op_info_vec.size()) < output_idx)
   1455       return tensorflow::errors::Internal(
   1456           "Accessing output index of: " + std::to_string(output_idx) +
   1457           ", at node: " + node_name + " with output entry from shape_map: " +
   1458           std::to_string(op_info_vec.size()));
   1460     auto op_info = op_info_vec.at(output_idx);
   1462     tensorflow::DataType tf_dtype = op_info.dtype();
   1463     input_dtypes.push_back(tf_dtype);
   1465     nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT);
   1466     TF_CHECK_OK(ConvertDType(tf_dtype, &dtype));
   1468     VLOG(2) << "Accessing output index of: " << std::to_string(output_idx)
   1469             << ", at node: " << node_name
   1470             << " with output entry from shape_map: "
   1471             << std::to_string(op_info_vec.size());
   1473     // TODO(ben,jie): update TRT input format/dimension
   1474     nvinfer1::DimsCHW input_dim_psuedo_chw;
   1475     for (int i = 0; i < 3; i++) input_dim_psuedo_chw.d[i] = 1;
   1477     for (int i = 1; i < op_info.shape().dim_size(); i++) {
   1478       VLOG(2) << "dimension: " << i
   1479               << " , size: " << op_info.shape().dim(i).size();
   1480       input_dim_psuedo_chw.d[i - 1] = op_info.shape().dim(i).size();
   1481     }
   1483     // TODO(ben,jie): proper way to restore input tensor name?
   1484     auto input_tensor_name = node_name;
   1485     if (output_idx != 0)
   1486       input_tensor_name = node_name + ":" + std::to_string(output_idx);
   1488     nvinfer1::ITensor* input_tensor = converter.network()->addInput(
   1489         input_tensor_name.c_str(), dtype, input_dim_psuedo_chw);
   1491     if (!input_tensor)
   1492       return tensorflow::errors::InvalidArgument(
   1493           "Failed to create Input layer");
   1494     VLOG(2) << "Input tensor name :" << input_tensor_name;
   1496     if (!converter.insert_input_tensor(input_tensor_name, input_tensor))
   1497       return tensorflow::errors::AlreadyExists(
   1498           "Output tensor already exists for op: " + input_tensor_name);
   1499   }
   1501   VLOG(2) << "Finished sorting";
   1503   for (const tensorflow::Node* node : order) {
   1504     const tensorflow::NodeDef& node_def = node->def();
   1505     VLOG(2) << "Converting node: " << node_def.name() << " , " << node_def.op();
   1506     TF_RETURN_IF_ERROR(converter.convert_node(node_def));
   1507   }
   1509   VLOG(2) << "Finished conversion";
   1511   // Gather output metadata
   1512   std::vector<string> output_names;
   1513   std::vector<tensorflow::DataType> output_dtypes;
   1514   for (std::pair<int, int> const& output : output_inds) {
   1515     int node_id = output.first;
   1516     int output_idx = output.second;
   1517     tensorflow::Node* node = graph.FindNodeId(node_id);
   1518     string op_name = node->name();
   1519     string tensor_name = op_name;
   1520     if (output_idx != 0)
   1521       tensor_name = tensor_name + ":" + std::to_string(output_idx);
   1522     VLOG(2) << "Output tensor name: " << tensor_name;
   1523     output_names.push_back(tensor_name);
   1524     auto tensor_or_weights = converter.get_tensor(tensor_name);
   1525     if (!tensor_or_weights.is_tensor()) {
   1526       return tensorflow::errors::InvalidArgument(
   1527           "Output node is weights not tensor");
   1528     }
   1529     nvinfer1::ITensor* tensor = tensor_or_weights.tensor();
   1530     if (!tensor) {
   1531       return tensorflow::errors::NotFound("Output tensor not found: " +
   1532                                           tensor_name);
   1533     }
   1534     converter.network()->markOutput(*tensor);
   1535     tensorflow::DataType tf_dtype = node->output_type(output_idx);
   1536     output_dtypes.push_back(tf_dtype);
   1537     nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT;
   1538     TF_RETURN_IF_ERROR(ConvertDType(tf_dtype, &trt_dtype));
   1539     tensor->setType(trt_dtype);
   1540   }
   1542   VLOG(2) << "Finished output";
   1543   // TODO(jie): static_id is not thread safe.
   1544   static int static_id = 0;
   1546   // Build the engine
   1547   trt_builder->setMaxBatchSize(max_batch_size);
   1548   trt_builder->setMaxWorkspaceSize(max_workspace_size_bytes);
   1549   VLOG(0) << "Starting build engine " << static_id;
   1550   // TODO(ben,jie): half2 and int8 mode support
   1551   string engine_plan_string;
   1552   {
   1553     auto trt_engine =
   1554         infer_object(trt_builder->buildCudaEngine(*converter.network()));
   1555     VLOG(0) << "Built network";
   1556     auto engine_plan = infer_object(trt_engine->serialize());
   1557     VLOG(0) << "Serialized engine";
   1558     const char* engine_plan_data =
   1559         static_cast<const char*>(engine_plan->data());
   1560     engine_plan_string =
   1561         string(engine_plan_data, engine_plan_data + engine_plan->size());
   1562   }
   1564   VLOG(0) << "Finished engine";
   1566   // Build the TRT op
   1567   // TODO(sami,ben,jie): proper naming!
   1568   tensorflow::NodeDefBuilder op_builder(
   1569       tensorflow::strings::StrCat("my_trt_op", static_id++), "TRTEngineOp");
   1570   std::vector<tensorflow::NodeDefBuilder::NodeOut> income_edges;
   1571   for (size_t i = 0; i < input_names.size(); ++i) {
   1572     int output_idx = input_inds.at(i).second;
   1573     // We wired up the input here already, it is redundant to do it again in
   1574     // ConvertSubGraphToTensorRT(convert_graph.cc)
   1575     auto incoming_edge = tensorflow::NodeDefBuilder::NodeOut(
   1576         input_names.at(i), output_idx, input_dtypes.at(i));
   1577     income_edges.push_back(incoming_edge);
   1578   }
   1579   tensorflow::gtl::ArraySlice<tensorflow::NodeDefBuilder::NodeOut> input_list(
   1580       income_edges);
   1581   op_builder.Input(input_list);
   1583   VLOG(0) << "Finished op preparation";
   1585   auto status = op_builder.Attr("serialized_engine", engine_plan_string)
   1586                     .Attr("input_nodes", input_names)
   1587                     .Attr("output_nodes", output_names)
   1588                     .Attr("OutT", output_dtypes)
   1589                     .Finalize(trt_node);
   1591   VLOG(0) << status.ToString() << " finished op building";
   1593   return tensorflow::Status::OK();
   1594 }
   1596 }  // namespace convert
   1597 }  // namespace tensorrt
   1598 }  // namespace tensorflow
   1600 #endif  // GOOGLE_TENSORRT
   1601 #endif  // GOOGLE_CUDA