Home | History | Annotate | Download | only in convert
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"
     17 
     18 #include <algorithm>
     19 #include <list>
     20 #include <map>
     21 #include <memory>
     22 #include <set>
     23 #include <unordered_map>
     24 #include <utility>
     25 #include <vector>
     26 
     27 #include "tensorflow/core/framework/node_def_builder.h"
     28 #include "tensorflow/core/framework/tensor_shape.pb.h"  // NOLINT
     29 #include "tensorflow/core/framework/types.h"
     30 #include "tensorflow/core/graph/algorithm.h"
     31 #include "tensorflow/core/graph/graph.h"
     32 #include "tensorflow/core/graph/graph_constructor.h"
     33 #include "tensorflow/core/lib/core/errors.h"
     34 #include "tensorflow/core/lib/core/status.h"
     35 #include "tensorflow/core/lib/strings/strcat.h"
     36 #include "tensorflow/core/platform/logging.h"
     37 #include "tensorflow/core/platform/tensor_coding.h"
     38 #include "tensorflow/core/platform/types.h"
     39 
     40 #if GOOGLE_CUDA
     41 #if GOOGLE_TENSORRT
     42 #include "tensorflow/contrib/tensorrt/log/trt_logger.h"
     43 #include "tensorrt/include/NvInfer.h"
     44 
     45 //  Check if the types are equal. Cast to int first so that failure log message
     46 //  would work!
     47 #define CHECK_EQ_TYPE(val1, val2) CHECK_EQ((int)val1, (int)val2)
     48 
     49 namespace tensorflow {
     50 namespace tensorrt {
     51 namespace convert {
     52 
     53 namespace {
     54 
     55 inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype,
     56                                        nvinfer1::DataType* trt_dtype) {
     57   switch (tf_dtype) {
     58     case tensorflow::DataType::DT_FLOAT:
     59       *trt_dtype = nvinfer1::DataType::kFLOAT;
     60       break;
     61     case tensorflow::DataType::DT_INT8:
     62       *trt_dtype = nvinfer1::DataType::kINT8;
     63       break;
     64     case tensorflow::DataType::DT_HALF:
     65       *trt_dtype = nvinfer1::DataType::kHALF;
     66       break;
     67     default:
     68       return tensorflow::errors::InvalidArgument("Unsupported data type");
     69   }
     70   return tensorflow::Status::OK();
     71 }
     72 
     73 inline nvinfer1::Dims GetTensorShape(const tensorflow::Tensor& tensor) {
     74   nvinfer1::Dims dims;
     75   dims.nbDims = tensor.dims();
     76   for (int i = 0; i < dims.nbDims; i++) {
     77     dims.d[i] = tensor.dim_size(i);
     78   }
     79   return dims;
     80 }
     81 
     82 inline int64_t GetShapeSize(nvinfer1::Dims shape) {
     83   // Returns total number of elements in shape
     84   int64_t count = 1;
     85   for (int d = 0; d < shape.nbDims; ++d) {
     86     count *= shape.d[d];
     87   }
     88   return count;
     89 }
     90 
     91 static std::vector<std::pair<int, int>> CreateSamePadding(
     92     const nvinfer1::DimsHW& stride, const nvinfer1::DimsHW& kernel,
     93     const std::vector<int64_t>& input_dims) {
     94   std::vector<std::pair<int, int>> padding(input_dims.size());
     95   CHECK_EQ((size_t)stride.nbDims, input_dims.size());  // TODO(jie): N+C? NC+?
     96 
     97   for (size_t i = 0; i < input_dims.size(); ++i) {
     98     // Formula to calculate the padding
     99     int p = ((input_dims[i] - 1) / stride.d[i]) * stride.d[i] + kernel.d[i] -
    100             input_dims[i];
    101     p = (p > 0) ? p : 0;
    102 
    103     // Right precedence padding, like in TensorFlow
    104     int left = p / 2;
    105     int right = p - left;
    106 
    107     VLOG(2) << "PADDING_" << i << " pre: " << left << ", post: " << right
    108             << "paras: " << input_dims[i] << ", " << stride.d[i] << ", "
    109             << "kernel: " << kernel.d[i];
    110     padding[i] = {left, right};
    111   }
    112   return padding;
    113 }
    114 
    115 class TRT_ShapedWeights {
    116  public:
    117   TRT_ShapedWeights(tensorflow::DataType type, const void* values,
    118                     nvinfer1::Dims shape)
    119       : shape_(shape), type_(type), values_(values), empty_weight_flag_(false) {
    120     // Note: this->shape.type[] is not used
    121   }
    122 
    123   explicit TRT_ShapedWeights(tensorflow::DataType type)
    124       : shape_(), type_(type), values_(nullptr), empty_weight_flag_(true) {}
    125 
    126   TRT_ShapedWeights(const TRT_ShapedWeights& rhs)
    127       : shape_(rhs.shape_),
    128         type_(rhs.type_),
    129         values_(rhs.values_),
    130         empty_weight_flag_(rhs.empty_weight_flag_) {}
    131 
    132   int64_t count() const {
    133     int64_t c = 1;
    134     for (int i = 0; i < shape_.nbDims; i++) c *= shape_.d[i];
    135     return c;
    136   }
    137 
    138   nvinfer1::Weights GetWeightsForTRT() const {
    139     nvinfer1::DataType trt_type(nvinfer1::DataType::kFLOAT);
    140     TF_CHECK_OK(ConvertDType(type_, &trt_type));
    141     if (empty_weight_flag_) return nvinfer1::Weights{trt_type, nullptr, 0};
    142 
    143     // Note: this->shape.type[] is not used
    144     return nvinfer1::Weights{trt_type, GetValues(), GetShapeSize(shape_)};
    145   }
    146 
    147   const void* GetValues() const { return values_; }
    148 
    149   void SetValues(const void* values) { values_ = values; }
    150 
    151   size_t size_bytes() const {
    152     int type_size = tensorflow::DataTypeSize(this->type_);
    153     return this->count() * type_size;
    154   }
    155 
    156   // Default converter
    157   operator nvinfer1::Weights() const { return GetWeightsForTRT(); }
    158 
    159   nvinfer1::Dims shape_;
    160   tensorflow::DataType type_;
    161 
    162  private:
    163   const void* values_;
    164   bool empty_weight_flag_;
    165 };
    166 
    167 class TRT_TensorOrWeights {
    168  public:
    169   explicit TRT_TensorOrWeights(nvinfer1::ITensor* tensor)
    170       : tensor_(tensor), weights_(DT_FLOAT), variant_(TRT_NODE_TENSOR) {}
    171   explicit TRT_TensorOrWeights(const TRT_ShapedWeights& weights)
    172       : tensor_(nullptr), weights_(weights), variant_(TRT_NODE_WEIGHTS) {}
    173   TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs)
    174       : tensor_(rhs.tensor_), weights_(rhs.weights_), variant_(rhs.variant_) {}
    175   ~TRT_TensorOrWeights() {}
    176 
    177   bool is_tensor() const { return variant_ == TRT_NODE_TENSOR; }
    178   bool is_weights() const { return variant_ == TRT_NODE_WEIGHTS; }
    179 
    180   nvinfer1::ITensor* tensor() {
    181     CHECK_EQ(is_tensor(), true);
    182     return tensor_;
    183   }
    184   const nvinfer1::ITensor* tensor() const {
    185     CHECK_EQ(is_tensor(), true);
    186     return tensor_;
    187   }
    188   TRT_ShapedWeights& weights() {
    189     CHECK_EQ(is_weights(), true);
    190     return weights_;
    191   }
    192   const TRT_ShapedWeights& weights() const {
    193     CHECK_EQ(is_weights(), true);
    194     return weights_;
    195   }
    196   nvinfer1::Dims shape() const {
    197     if (is_tensor()) {
    198       return tensor()->getDimensions();
    199     } else {
    200       return weights().shape_;
    201     }
    202   }
    203 
    204  private:
    205   nvinfer1::ITensor* tensor_;
    206   TRT_ShapedWeights weights_;
    207   enum { TRT_NODE_TENSOR, TRT_NODE_WEIGHTS } variant_;
    208 };
    209 
    210 class TFAttrs {
    211  public:
    212   explicit TFAttrs(const tensorflow::NodeDef& tf_node) {
    213     for (const auto& attr : tf_node.attr()) {
    214       attrs_.insert({attr.first, &attr.second});
    215     }
    216   }
    217   bool count(string key) const { return attrs_.count(key); }
    218   tensorflow::AttrValue const* at(string key) const {
    219     if (!attrs_.count(key)) {
    220       LOG(FATAL) << "Attribute not found: " << key;
    221     }
    222     return attrs_.at(key);
    223   }
    224   template <typename T>
    225   T get(string key) const;
    226   template <typename T>
    227   T get(string key, const T& default_value) const {
    228     return attrs_.count(key) ? this->get<T>(key) : default_value;
    229   }
    230 
    231  private:
    232   typedef std::map<string, tensorflow::AttrValue const*> AttrMap;
    233   AttrMap attrs_;
    234 };
    235 
    236 template <>
    237 string TFAttrs::get<string>(string key) const {
    238   return this->at(key)->s();
    239 }
    240 
    241 template <>
    242 std::vector<int> TFAttrs::get<std::vector<int>>(string key) const {
    243   auto attr = this->at(key)->list().i();
    244   return std::vector<int>(attr.begin(), attr.end());
    245 }
    246 
    247 template <>
    248 nvinfer1::Dims TFAttrs::get<nvinfer1::Dims>(string key) const {
    249   auto values = this->get<std::vector<int>>(key);
    250   nvinfer1::Dims dims;
    251   dims.nbDims = values.size();
    252   std::copy(values.begin(), values.end(), dims.d);
    253   // Note: No dimension type information is included
    254   return dims;
    255 }
    256 
    257 template <>
    258 nvinfer1::DataType TFAttrs::get<nvinfer1::DataType>(string key) const {
    259   nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT);
    260   TF_CHECK_OK(ConvertDType(this->at(key)->type(), &trt_dtype));
    261   return trt_dtype;
    262 }
    263 
    264 template <>
    265 tensorflow::DataType TFAttrs::get<tensorflow::DataType>(string key) const {
    266   return this->at(key)->type();
    267 }
    268 
    269 template <typename T>
    270 void Reorder4(nvinfer1::DimsNCHW shape, const T* idata,
    271               nvinfer1::DimsNCHW istrides, T* odata,
    272               nvinfer1::DimsNCHW ostrides) {
    273   for (int n = 0; n < shape.n(); ++n) {
    274     for (int c = 0; c < shape.c(); ++c) {
    275       for (int h = 0; h < shape.h(); ++h) {
    276         for (int w = 0; w < shape.w(); ++w) {
    277           odata[n * ostrides.n() + c * ostrides.c() + h * ostrides.h() +
    278                 w * ostrides.w()] = idata[n * istrides.n() + c * istrides.c() +
    279                                           h * istrides.h() + w * istrides.w()];
    280         }
    281       }
    282     }
    283   }
    284 }
    285 
    286 void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights,
    287                        TRT_ShapedWeights* oweights) {
    288   CHECK_EQ(iweights.type_, oweights->type_);
    289   CHECK_EQ(iweights.size_bytes(), oweights->size_bytes());
    290   int r = iweights.shape_.d[0];
    291   int s = iweights.shape_.d[1];
    292   int c = iweights.shape_.d[2];
    293   int k = iweights.shape_.d[3];
    294   oweights->shape_.d[0] = k;
    295   oweights->shape_.d[1] = c;
    296   oweights->shape_.d[2] = r;
    297   oweights->shape_.d[3] = s;
    298   nvinfer1::DimsNCHW istrides = {1, k, s * k * c, c * k};
    299   nvinfer1::DimsNCHW ostrides = {c * r * s, r * s, s, 1};
    300   switch (iweights.type_) {
    301     case tensorflow::DataType::DT_FLOAT:
    302       Reorder4({k, c, r, s}, static_cast<float const*>(iweights.GetValues()),
    303                istrides,
    304                static_cast<float*>(const_cast<void*>(oweights->GetValues())),
    305                ostrides);
    306       break;
    307     default:
    308       LOG(FATAL) << "!!!!!!!!!!!!!!!!!!!!!!!!broke!!!!!!!!!!!!";
    309   }
    310 }
    311 
    312 struct InferDeleter {
    313   template <typename T>
    314   void operator()(T* obj) const {
    315     if (obj) {
    316       obj->destroy();
    317     }
    318   }
    319 };
    320 
    321 template <typename T>
    322 inline std::shared_ptr<T> infer_object(T* obj) {
    323   return std::shared_ptr<T>(obj, InferDeleter());
    324 }
    325 
    326 // Logger for GIE info/warning/errors
    327 class Converter;
    328 
    329 using OpConverter =
    330     std::function<tensorflow::Status(Converter&, const tensorflow::NodeDef&,
    331                                      std::vector<TRT_TensorOrWeights> const&,
    332                                      std::vector<TRT_TensorOrWeights>*)>;
    333 
    334 class Converter {
    335   std::unordered_map<string, TRT_TensorOrWeights> trt_tensors_;
    336   std::unordered_map<string, OpConverter> op_registry_;
    337   nvinfer1::INetworkDefinition* trt_network_;
    338   std::list<std::vector<uint8_t>> temp_bufs_;
    339 
    340   void register_op_converters();
    341 
    342   std::vector<TRT_TensorOrWeights> get_inputs(
    343       const tensorflow::NodeDef& node_def) {
    344     std::vector<TRT_TensorOrWeights> inputs;
    345     for (const auto& input_name : node_def.input()) {
    346       VLOG(2) << "Retrieve input: " << input_name;
    347       inputs.push_back(trt_tensors_.at(input_name));
    348     }
    349     return inputs;
    350   }
    351 
    352  public:
    353   explicit Converter(nvinfer1::INetworkDefinition* trt_network)
    354       : trt_network_(trt_network) {
    355     this->register_op_converters();
    356   }
    357 
    358   TRT_ShapedWeights get_temp_weights(tensorflow::DataType type,
    359                                      nvinfer1::Dims shape) {
    360     TRT_ShapedWeights weights(type, nullptr, shape);
    361     // TODO(jie): check weights size_bytes. 0 means type error
    362     temp_bufs_.push_back(std::vector<uint8_t>(weights.size_bytes()));
    363     weights.SetValues(temp_bufs_.back().data());
    364     return weights;
    365   }
    366 
    367   TRT_ShapedWeights get_temp_weights_like(const TRT_ShapedWeights& weights) {
    368     return this->get_temp_weights(weights.type_, weights.shape_);
    369   }
    370 
    371   tensorflow::Status convert_node(const tensorflow::NodeDef& node_def) {
    372     std::vector<TRT_TensorOrWeights> inputs = this->get_inputs(node_def);
    373     string op = node_def.op();
    374     if (!op_registry_.count(op)) {
    375       return tensorflow::errors::Unimplemented(
    376           "No converter registered for op: " + op);
    377     }
    378     OpConverter op_converter = op_registry_.at(op);
    379     std::vector<TRT_TensorOrWeights> outputs;
    380     TF_RETURN_IF_ERROR(op_converter(*this, node_def, inputs, &outputs));
    381     for (size_t i = 0; i < outputs.size(); ++i) {
    382       TRT_TensorOrWeights output = outputs.at(i);
    383       // TODO(jie): tf protobuf seems to be omitting the :0 suffix
    384       string output_name = node_def.name();
    385       if (i != 0) output_name = output_name + ":" + std::to_string(i);
    386       if (output.is_tensor()) {
    387         output.tensor()->setName(output_name.c_str());
    388       }
    389       VLOG(2) << "Write out tensor: " << output_name;
    390       if (!trt_tensors_.insert({output_name, output}).second) {
    391         return tensorflow::errors::AlreadyExists(
    392             "Output tensor already exists for op: " + op);
    393       }
    394     }
    395     return tensorflow::Status::OK();
    396   }
    397 
    398   nvinfer1::INetworkDefinition* network() { return trt_network_; }
    399 
    400   TRT_TensorOrWeights get_tensor(string name) {
    401     if (!trt_tensors_.count(name)) {
    402       return TRT_TensorOrWeights(nullptr);
    403     }
    404     return trt_tensors_.at(name);
    405   }
    406 
    407   bool insert_input_tensor(string name, nvinfer1::ITensor* tensor) {
    408     return trt_tensors_.insert({name, TRT_TensorOrWeights(tensor)}).second;
    409   }
    410 
    411   nvinfer1::ITensor* TransposeTensor(nvinfer1::ITensor* input_tensor,
    412                                      std::vector<int> order) {
    413     auto dims = input_tensor->getDimensions();
    414 
    415     // TODO(jie): change the return to status and properly exit
    416     if (order.size() - 1 != size_t(dims.nbDims))
    417       LOG(ERROR) << "Dimension does not match, fail gracefully";
    418 
    419     nvinfer1::IShuffleLayer* layer = this->network()->addShuffle(*input_tensor);
    420     nvinfer1::Permutation permutation;
    421     for (int32_t i = 0; i < dims.nbDims; ++i) {
    422       permutation.order[i] = order[i + 1] - 1;
    423     }
    424     layer->setFirstTranspose(permutation);
    425 
    426     nvinfer1::Dims reshape_dims;
    427     reshape_dims.nbDims = dims.nbDims;
    428     for (int32_t i = 0; i < reshape_dims.nbDims; ++i) {
    429       reshape_dims.d[i] = 0;
    430       reshape_dims.type[i] = dims.type[i];
    431     }
    432     layer->setReshapeDimensions(reshape_dims);
    433     return layer->getOutput(0);
    434   }
    435 };
    436 
    437 // ****************************************************************************
    438 // Constant folding functions
    439 // TODO(jie): once optimizer kicks in, we should have done constant folding
    440 // there.
    441 //*****************************************************************************/
    442 struct LambdaFactory {
    443   enum class OP_CATEGORY : int { RSQRT = 0, NEG, ADD, MUL, SUB };
    444   OP_CATEGORY op;
    445 
    446   template <typename T>
    447   std::function<T(T)> unary() {
    448     switch (op) {
    449       case OP_CATEGORY::RSQRT: {
    450         VLOG(2) << "RSQRT GETS DONE";
    451         return [](T t) -> T { return 1.0 / std::sqrt(t); };
    452       }
    453       case OP_CATEGORY::NEG:
    454         return [](T t) -> T { return -t; };
    455       default:
    456         VLOG(2) << "Not supported op for unary: " << static_cast<int>(op);
    457         return nullptr;
    458     }
    459   }
    460 
    461   template <typename T>
    462   std::function<T(T, T)> binary() {
    463     switch (op) {
    464       case OP_CATEGORY::ADD:
    465         return [](T l, T r) -> T { return l + r; };
    466       case OP_CATEGORY::SUB:
    467         return [](T l, T r) -> T { return l - r; };
    468       case OP_CATEGORY::MUL:
    469         return [](T l, T r) -> T { return l * r; };
    470       default:
    471         LOG(WARNING) << "Not supported op for binary: " << static_cast<int>(op);
    472     }
    473     return [](T l, T r) -> T {
    474       LOG(FATAL) << "Unsupported op type ";
    475       return l;
    476     };
    477   }
    478 
    479   template <typename T>
    480   std::function<T(T)> broadcast_r(T val) {
    481     VLOG(2) << "LAMBDA VAL : " << val;
    482     switch (op) {
    483       case OP_CATEGORY::ADD:
    484         return [val](T l) -> T {
    485           VLOG(2) << "LAMBDA VAL : " << val;
    486           return l + val;
    487         };
    488       // Return [val](T l)-> T {return l+val;};
    489       case OP_CATEGORY::SUB:
    490         return [val](T l) -> T {
    491           VLOG(2) << "LAMBDA VAL : " << val;
    492           return l - val;
    493         };
    494       case OP_CATEGORY::MUL:
    495         return [val](T l) -> T {
    496           VLOG(2) << "LAMBDA VAL : " << val;
    497           return l * val;
    498         };
    499       default:
    500         LOG(WARNING) << "Not supported op for binary: " << static_cast<int>(op);
    501     }
    502     return [val](T l) -> T {
    503       LOG(FATAL) << "Unsupported op type ";
    504       return l;
    505     };
    506   }
    507 
    508   template <typename T>
    509   std::function<T(T)> broadcast_l(T val) {
    510     VLOG(2) << "LAMBDA VAL : " << val;
    511     switch (op) {
    512       case OP_CATEGORY::ADD:
    513         return [val](T l) -> T {
    514           VLOG(2) << "LAMBDA VAL : " << val;
    515           return val + l;
    516         };
    517       case OP_CATEGORY::SUB:
    518         return [val](T l) -> T {
    519           VLOG(2) << "LAMBDA VAL : " << val;
    520           return val - l;
    521         };
    522       case OP_CATEGORY::MUL:
    523         return [val](T l) -> T {
    524           VLOG(2) << "LAMBDA VAL : " << val;
    525           return val * l;
    526         };
    527       default:
    528         LOG(ERROR) << "Not supported op for binary: " << static_cast<int>(op);
    529     }
    530     return [val](T l) -> T {
    531       LOG(FATAL) << "Unsupported op type ";
    532       return l;
    533     };
    534   }
    535 };
    536 
    537 tensorflow::Status UnaryCompute(const TRT_ShapedWeights& iweights,
    538                                 TRT_ShapedWeights* oweights,
    539                                 LambdaFactory unary_op) {
    540   CHECK_EQ(iweights.type_, oweights->type_);
    541   switch (iweights.type_) {
    542     case tensorflow::DataType::DT_FLOAT: {
    543       auto inp = static_cast<float const*>(iweights.GetValues());
    544       auto oup = static_cast<float*>(const_cast<void*>(oweights->GetValues()));
    545       std::transform(inp, inp + iweights.count(), oup, unary_op.unary<float>());
    546       break;
    547     }
    548     default:
    549       return tensorflow::errors::Unimplemented(
    550           "Data type not supported: " +
    551           tensorflow::DataTypeString(iweights.type_));
    552   }
    553   return tensorflow::Status::OK();
    554 }
    555 
    556 tensorflow::Status BinaryCompute(const TRT_ShapedWeights& iweights_l,
    557                                  const TRT_ShapedWeights& iweights_r,
    558                                  TRT_ShapedWeights* oweights,
    559                                  LambdaFactory binary_op) {
    560   // Assume iweights_l.type == iweight_r.type
    561   CHECK_EQ(iweights_l.type_, oweights->type_);
    562   CHECK_EQ(iweights_r.type_, oweights->type_);
    563   VLOG(2) << "SANITY CHECK!";
    564 
    565   switch (iweights_l.type_) {
    566     case tensorflow::DataType::DT_FLOAT: {
    567       auto inp_l = static_cast<const float*>(iweights_l.GetValues());
    568       auto inp_r = static_cast<const float*>(iweights_r.GetValues());
    569       auto oup = static_cast<float*>(const_cast<void*>(oweights->GetValues()));
    570 
    571       if (iweights_l.count() != iweights_r.count()) {
    572         // We only supports broadcast of RankZero
    573         if (iweights_l.count() == 1) {
    574           VLOG(2) << "I bet it is not working!" << (*inp_l);
    575           std::transform(inp_r, inp_r + iweights_r.count(), oup,
    576                          binary_op.broadcast_l<float>(*inp_l));
    577         } else if (iweights_r.count() == 1) {
    578           VLOG(2) << "I bet it is not working!" << (*inp_r);
    579           std::transform(inp_l, inp_l + iweights_l.count(), oup,
    580                          binary_op.broadcast_r<float>(*inp_r));
    581         } else {
    582           return tensorflow::errors::Unimplemented(
    583               "Binary op with non-rankZero broadcast not supported");
    584         }
    585       } else {
    586         std::transform(inp_l, inp_l + iweights_l.count(), inp_r, oup,
    587                        binary_op.binary<float>());
    588       }
    589       break;
    590     }
    591     default:
    592       return tensorflow::errors::Unimplemented(
    593           "Data type not supported: " +
    594           tensorflow::DataTypeString(iweights_l.type_));
    595   }
    596 
    597   return tensorflow::Status::OK();
    598 }
    599 
    600 tensorflow::Status ConstantFoldUnary(
    601     Converter& ctx, const tensorflow::NodeDef& node_def,
    602     std::vector<TRT_TensorOrWeights> const& inputs,
    603     std::vector<TRT_TensorOrWeights>* outputs) {
    604   TRT_ShapedWeights weights_input = inputs.at(0).weights();
    605 
    606   // Allocate output weights
    607   TRT_ShapedWeights weights_output = ctx.get_temp_weights_like(weights_input);
    608 
    609   // FIXME assume type matches input weights
    610   // Get trt type & shape
    611   // Maybe this part has to be moved into the block of rsqrt later
    612   // Check type consistency
    613   CHECK_EQ(weights_input.type_,
    614            TFAttrs(node_def).get<tensorflow::DataType>("T"));
    615 
    616   // Maybe I should do a switch
    617   LambdaFactory unary_op;
    618   if (node_def.op() == "Rsqrt") {
    619     // Compute rsqrt
    620     unary_op.op = LambdaFactory::OP_CATEGORY::RSQRT;
    621     auto ret = UnaryCompute(weights_input, &weights_output, unary_op);
    622     // PAss the output
    623     if (ret == tensorflow::Status::OK()) {
    624       outputs->push_back(TRT_TensorOrWeights(weights_output));
    625     }
    626     return ret;
    627   } else {
    628     return tensorflow::errors::Unimplemented("Binary op not supported: " +
    629                                              node_def.op());
    630   }
    631 }
    632 
    633 // TODO(jie,ben) broadcast is needed yet not implemented
    634 // Let's get the simple stuff working first. Maybe we should fall bakc to TF
    635 //   approach for constant folding
    636 tensorflow::Status ConstantFoldBinary(
    637     Converter& ctx, const tensorflow::NodeDef& node_def,
    638     std::vector<TRT_TensorOrWeights> const& inputs,
    639     std::vector<TRT_TensorOrWeights>* outputs) {
    640   TRT_ShapedWeights weights_input_l = inputs.at(0).weights();
    641   TRT_ShapedWeights weights_input_r = inputs.at(1).weights();
    642 
    643   // Check type consistency
    644   CHECK_EQ(weights_input_l.type_, weights_input_r.type_);
    645 
    646   if (weights_input_l.shape_.nbDims != weights_input_r.shape_.nbDims)
    647     return tensorflow::errors::Unimplemented(
    648         "Binary op implicit broadcast not supported: " + node_def.op());
    649 
    650   // TODO(jie): constant fold should really fall back to TF.
    651   int nb_dims = weights_input_l.shape_.nbDims;
    652   nvinfer1::Dims output_shape;
    653   output_shape.nbDims = nb_dims;
    654   VLOG(2) << "nb_dims: " << nb_dims
    655           << ", the other: " << weights_input_r.shape_.nbDims;
    656   for (int i = 0; i < nb_dims; i++) {
    657     if (weights_input_l.shape_.d[i] == weights_input_r.shape_.d[i]) {
    658       output_shape.d[i] = weights_input_l.shape_.d[i];
    659     } else if (weights_input_l.shape_.d[i] == 1 ||
    660                weights_input_r.shape_.d[i] == 1) {
    661       output_shape.d[i] =
    662           std::max(weights_input_l.shape_.d[i], weights_input_r.shape_.d[i]);
    663     } else {
    664       return tensorflow::errors::Unimplemented(
    665           "Binary op with incompatible shape at, " + node_def.op());
    666     }
    667     VLOG(2) << "left: " << weights_input_l.shape_.d[i]
    668             << "right: " << weights_input_r.shape_.d[i]
    669             << "output: " << output_shape.d[i];
    670   }
    671 
    672   // FIXME assume type matches input weights
    673   // Get trt type & shape
    674   TFAttrs attrs(node_def);
    675   // Maybe this part has to be moved into the block of rsqrt later
    676   tensorflow::DataType dtype = attrs.get<tensorflow::DataType>("T");
    677 
    678   // Allocate output weights
    679   TRT_ShapedWeights weights_output = ctx.get_temp_weights(dtype, output_shape);
    680 
    681   // Maybe I should do a switch
    682   LambdaFactory binary_op;
    683   if (node_def.op() == "Sub") {
    684     binary_op.op = LambdaFactory::OP_CATEGORY::SUB;
    685   } else if (node_def.op() == "Mul") {
    686     binary_op.op = LambdaFactory::OP_CATEGORY::MUL;
    687   } else if (node_def.op() == "Add") {
    688     binary_op.op = LambdaFactory::OP_CATEGORY::ADD;
    689   } else {
    690     return tensorflow::errors::Unimplemented("Binary op not supported: " +
    691                                              node_def.op());
    692   }
    693   auto ret = BinaryCompute(weights_input_l, weights_input_r, &weights_output,
    694                            binary_op);
    695 
    696   // Pass the output
    697   if (ret == tensorflow::Status::OK()) {
    698     outputs->push_back(TRT_TensorOrWeights(weights_output));
    699   }
    700 
    701   return ret;
    702 }
    703 
    704 // TODO(jie): broadcast is needed yet not implemented.
    705 // Only implemented channel wise for the time being
    706 tensorflow::Status BinaryTensorOpWeight(
    707     Converter& ctx, const tensorflow::NodeDef& node_def,
    708     const nvinfer1::ITensor* tensor, TRT_ShapedWeights weights,
    709     std::vector<TRT_TensorOrWeights>* outputs) {
    710   // FIXME assume type matches input weights
    711   // Get trt type & shape
    712   // Maybe this part has to be moved into the block of rsqrt later
    713 
    714   // Check type consistency
    715   auto dtype = TFAttrs(node_def).get<nvinfer1::DataType>("T");
    716   CHECK_EQ_TYPE(tensor->getType(), dtype);  // Cast to int for error messages
    717   nvinfer1::DataType ttype;
    718   TF_CHECK_OK(ConvertDType(weights.type_, &ttype));
    719   CHECK_EQ_TYPE(ttype, dtype);  // Cast to int for error message
    720 
    721   // Check scale mode
    722   auto dims_w = weights.shape_;
    723   auto dims_t = tensor->getDimensions();
    724 
    725   // Default to channel-wise
    726   auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
    727 
    728   if (weights.count() == 1) {
    729     VLOG(2) << "UNIFORM";
    730     scale_mode = nvinfer1::ScaleMode::kUNIFORM;
    731   } else {
    732     // No broadcasting on Batch dimension;
    733     assert(dims_w.d[0] == 1);
    734 
    735     // Broadcasting on Channel dimension only allowed in kUNIFORM
    736     assert(dims_w.d[1] == dims_t.d[0]);
    737     assert(dims_w.nbDims == dims_t.nbDims);
    738 
    739     // Default is element;
    740     for (int i = 2; i < dims_w.nbDims; i++) {
    741       if (dims_w.d[i] != dims_t.d[i - 1]) {
    742         scale_mode = nvinfer1::ScaleMode::kCHANNEL;
    743         break;
    744       }
    745     }
    746     if (scale_mode == nvinfer1::ScaleMode::kELEMENTWISE) {
    747       scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;
    748       for (int i = 2; i < dims_w.nbDims; i++) {
    749         if (dims_w.d[i] != 1)
    750           return tensorflow::errors::InvalidArgument(
    751               "Weight shape not compatible at, " + node_def.name());
    752       }
    753     }
    754   }
    755 
    756   // Prepare weights
    757   TRT_ShapedWeights shift_weights(weights.type_);
    758   TRT_ShapedWeights scale_weights(weights.type_);
    759   TRT_ShapedWeights power_weights(weights.type_);
    760 
    761   // Maybe I should do a switch
    762   if (node_def.op() == "Sub") {
    763     TRT_ShapedWeights neg_weights = ctx.get_temp_weights_like(weights);
    764     LambdaFactory unary_op;
    765     unary_op.op = LambdaFactory::OP_CATEGORY::NEG;
    766     TF_RETURN_IF_ERROR(UnaryCompute(weights, &neg_weights, unary_op));
    767     shift_weights = neg_weights;
    768   } else if (node_def.op() == "Mul") {
    769     scale_weights = weights;
    770   } else if (node_def.op() == "Add") {
    771     shift_weights = weights;
    772   } else {
    773     return tensorflow::errors::Unimplemented("Binary op not supported: " +
    774                                              node_def.op());
    775   }
    776 
    777   nvinfer1::IScaleLayer* layer = ctx.network()->addScale(
    778       *const_cast<nvinfer1::ITensor*>(tensor), scale_mode, shift_weights,
    779       scale_weights, power_weights);
    780 
    781   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
    782 
    783   // Pass the output
    784   outputs->push_back(TRT_TensorOrWeights(output_tensor));
    785   return tensorflow::Status::OK();
    786 }
    787 
    788 tensorflow::Status BinaryTensorOpTensor(
    789     Converter& ctx, const tensorflow::NodeDef& node_def,
    790     const nvinfer1::ITensor* tensor_l, const nvinfer1::ITensor* tensor_r,
    791     std::vector<TRT_TensorOrWeights>* outputs) {
    792   static const std::unordered_map<string, nvinfer1::ElementWiseOperation> ops{
    793       {"Add", nvinfer1::ElementWiseOperation::kSUM},
    794       {"Mul", nvinfer1::ElementWiseOperation::kPROD},
    795       // {"max", nvinfer1::ElementWiseOperation::kMAX},
    796       // {"min", nvinfer1::ElementWiseOperation::kMIN},
    797       {"Sub", nvinfer1::ElementWiseOperation::kSUB},
    798       {"Div", nvinfer1::ElementWiseOperation::kDIV},
    799   };
    800 
    801   // FIXME assume type matches input weights
    802   // Get trt type & shape
    803   TFAttrs attrs(node_def);
    804   // Maybe this part has to be moved into the block of rsqrt later
    805   nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("T");
    806 
    807   // Check type consistency
    808   CHECK_EQ_TYPE(tensor_l->getType(), dtype);
    809   CHECK_EQ_TYPE(tensor_r->getType(), dtype);
    810   auto op_pair = ops.find(node_def.op());
    811   if (op_pair == ops.end())
    812     return tensorflow::errors::Unimplemented("binary op: " + node_def.op() +
    813                                              " not supported at: " +
    814                                              node_def.name());
    815 
    816   nvinfer1::IElementWiseLayer* layer = ctx.network()->addElementWise(
    817       *const_cast<nvinfer1::ITensor*>(tensor_l),
    818       *const_cast<nvinfer1::ITensor*>(tensor_r), op_pair->second);
    819 
    820   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
    821 
    822   // Pass the output
    823   outputs->push_back(TRT_TensorOrWeights(output_tensor));
    824   return tensorflow::Status::OK();
    825 }
    826 
    827 tensorflow::Status ConvertPlaceholder(
    828     Converter& ctx, const tensorflow::NodeDef& node_def,
    829     std::vector<TRT_TensorOrWeights> const& inputs,
    830     std::vector<TRT_TensorOrWeights>* outputs) {
    831   VLOG(2) << "Placeholder should have been replace already";
    832   return tensorflow::errors::Unimplemented(", cannot convert Placeholder op");
    833   // OK this make sense since we are supposed to replace it with input
    834   TFAttrs attrs(node_def);
    835   nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("dtype");
    836   nvinfer1::Dims dims = attrs.get<nvinfer1::Dims>("shape");
    837 
    838   dims.nbDims--;
    839   for (int i = 0; i < dims.nbDims; i++) dims.d[i] = dims.d[i + 1];
    840 
    841   nvinfer1::ITensor* output =
    842       ctx.network()->addInput(node_def.name().c_str(), dtype, dims);
    843   if (!output) {
    844     return tensorflow::errors::InvalidArgument("Failed to create Input layer");
    845   }
    846   outputs->push_back(TRT_TensorOrWeights(output));
    847   return tensorflow::Status::OK();
    848 }
    849 
    850 tensorflow::Status ConvertConv2D(Converter& ctx,
    851                                  const tensorflow::NodeDef& node_def,
    852                                  const std::vector<TRT_TensorOrWeights>& inputs,
    853                                  std::vector<TRT_TensorOrWeights>* outputs) {
    854   nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
    855   // TODO(jie): handle NHWC/NCHW transpose;
    856   TRT_ShapedWeights weights_rsck = inputs.at(1).weights();
    857   TRT_ShapedWeights weights = ctx.get_temp_weights_like(weights_rsck);
    858   ReorderRSCKToKCRS(weights_rsck, &weights);
    859   TRT_ShapedWeights biases(weights.type_);
    860   int noutput = weights.shape_.d[0];
    861   nvinfer1::DimsHW kernel_size;
    862   kernel_size.h() = weights.shape_.d[2];
    863   kernel_size.w() = weights.shape_.d[3];
    864   TFAttrs attrs(node_def);
    865 
    866   int h_index = 2;
    867   int w_index = 3;
    868   auto data_format = attrs.get<string>("data_format");
    869   if (data_format == "NHWC") {
    870     tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
    871                                  {0, 3, 1, 2});
    872     h_index = 1;
    873     w_index = 2;
    874     // TODO(jie): transpose it
    875   }
    876 
    877   // TODO(jie): stride. (NHWC/NCHW)
    878   auto tf_stride = attrs.get<std::vector<int>>("strides");
    879   nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
    880 
    881   auto tensor_dim = tensor->getDimensions();
    882   std::vector<std::pair<int, int>> padding;
    883   // TODO(jie): padding.
    884   if (attrs.get<string>("padding") == "SAME") {
    885     // This is NCHW tensor with no batch dimension.
    886     //  1 -> h
    887     //  2 -> w
    888     padding = CreateSamePadding(
    889         stride, kernel_size,
    890         {static_cast<int>(tensor_dim.d[1]), static_cast<int>(tensor_dim.d[2])});
    891   } else {
    892     padding = {{0, 0}, {0, 0}};
    893   }
    894 
    895   if (padding[0].first != padding[0].second ||
    896       padding[1].first != padding[1].second) {
    897     // TODO(jie): handle asymmetric padding
    898     VLOG(2) << "Padding!!!: " << padding[0].first << padding[0].second
    899             << padding[1].first << padding[1].second;
    900 
    901     auto dim_before = tensor->getDimensions();
    902     VLOG(2) << "TENSOR before: " << dim_before.d[0] << ", " << dim_before.d[1]
    903             << dim_before.d[2] << ", " << dim_before.d[3];
    904     auto pad_layer = ctx.network()->addPadding(
    905         *const_cast<nvinfer1::ITensor*>(tensor),
    906         nvinfer1::DimsHW(padding[0].first, padding[1].first),
    907         nvinfer1::DimsHW(padding[0].second, padding[1].second));
    908     padding = {{0, 0}, {0, 0}};
    909     tensor = pad_layer->getOutput(0);
    910     auto dim_after = tensor->getDimensions();
    911     VLOG(2) << "TENSOR after: " << dim_after.d[0] << ", " << dim_after.d[1]
    912             << dim_after.d[2] << ", " << dim_after.d[3];
    913   }
    914 
    915   nvinfer1::IConvolutionLayer* layer =
    916       ctx.network()->addConvolution(*const_cast<nvinfer1::ITensor*>(tensor),
    917                                     noutput, kernel_size, weights, biases);
    918 
    919   layer->setStride(stride);
    920   layer->setPadding({padding[0].first, padding[1].first});
    921   layer->setName(node_def.name().c_str());
    922   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
    923 
    924   auto dim_after = output_tensor->getDimensions();
    925   VLOG(2) << "TENSOR out: " << dim_after.d[0] << ", " << dim_after.d[1]
    926           << dim_after.d[2] << ", " << dim_after.d[3];
    927 
    928   if (data_format == "NHWC") {
    929     // TODO(jie): transpose it back!
    930     output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1});
    931   } else {
    932     VLOG(2) << "NCHW !!!!";
    933   }
    934   outputs->push_back(TRT_TensorOrWeights(output_tensor));
    935   return tensorflow::Status::OK();
    936 }
    937 
    938 tensorflow::Status ConvertPool(Converter& ctx,
    939                                const tensorflow::NodeDef& node_def,
    940                                std::vector<TRT_TensorOrWeights> const& inputs,
    941                                std::vector<TRT_TensorOrWeights>* outputs) {
    942   nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
    943   TFAttrs attrs(node_def);
    944 
    945   int h_index = 2;
    946   int w_index = 3;
    947   auto data_format = attrs.get<string>("data_format");
    948   if (data_format == "NHWC") {
    949     h_index = 1;
    950     w_index = 2;
    951     tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
    952                                  {0, 3, 1, 2});
    953   } else {
    954     VLOG(2) << "NCHW !!!!";
    955   }
    956   nvinfer1::PoolingType type;
    957   // TODO(jie): support other pooling type
    958   if (node_def.op() == "MaxPool")
    959     type = nvinfer1::PoolingType::kMAX;
    960   else
    961     return tensorflow::errors::Unimplemented("Only supports Max pool");
    962 
    963   // TODO(jie): NCHW
    964   auto tf_stride = attrs.get<std::vector<int>>("strides");
    965   nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
    966 
    967   auto tf_kernel = attrs.get<std::vector<int>>("ksize");
    968   nvinfer1::DimsHW ksize(tf_kernel[h_index], tf_kernel[w_index]);
    969 
    970   auto tensor_dim = tensor->getDimensions();
    971   std::vector<std::pair<int, int>> padding;
    972   // TODO(jie): padding.
    973   if (attrs.get<string>("padding") == "SAME") {
    974     // This is NCHW tensor with no batch dimension.
    975     //  1 -> h
    976     //  2 -> w
    977     padding = CreateSamePadding(
    978         stride, ksize,
    979         {static_cast<int>(tensor_dim.d[1]), static_cast<int>(tensor_dim.d[2])});
    980   } else if (attrs.get<string>("padding") == "VALID") {
    981     // No padding for valid padding here
    982     VLOG(2) << "No padding added for VALID padding in pool" << node_def.name();
    983     padding = {{0, 0}, {0, 0}};
    984   } else {
    985     return tensorflow::errors::Unimplemented(
    986         "Current MaxPool cannot support padding other than SAME");
    987   }
    988 
    989   if (padding[0].first != padding[0].second ||
    990       padding[1].first != padding[1].second) {
    991     // TODO(jie): handle asymmetric padding
    992     VLOG(2) << "Padding!!!: " << padding[0].first << padding[0].second
    993             << padding[1].first << padding[1].second;
    994     auto pad_layer = ctx.network()->addPadding(
    995         *const_cast<nvinfer1::ITensor*>(tensor),
    996         nvinfer1::DimsHW(padding[0].first, padding[1].first),
    997         nvinfer1::DimsHW(padding[0].second, padding[1].second));
    998     padding = {{0, 0}, {0, 0}};
    999     tensor = pad_layer->getOutput(0);
   1000   }
   1001 
   1002   nvinfer1::IPoolingLayer* layer = ctx.network()->addPooling(
   1003       *const_cast<nvinfer1::ITensor*>(tensor), type, ksize);
   1004 
   1005   layer->setStride(stride);
   1006   layer->setPadding({padding[0].first, padding[1].first});
   1007   layer->setName(node_def.name().c_str());
   1008   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
   1009 
   1010   if (data_format == "NHWC") {
   1011     // TODO(jie): transpose it back!
   1012     output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1});
   1013   } else {
   1014     VLOG(2) << "NCHW !!!!";
   1015   }
   1016   outputs->push_back(TRT_TensorOrWeights(output_tensor));
   1017   return tensorflow::Status::OK();
   1018 }
   1019 
   1020 tensorflow::Status ConvertActivation(
   1021     Converter& ctx, const tensorflow::NodeDef& node_def,
   1022     std::vector<TRT_TensorOrWeights> const& inputs,
   1023     std::vector<TRT_TensorOrWeights>* outputs) {
   1024   nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
   1025   nvinfer1::IActivationLayer* layer = ctx.network()->addActivation(
   1026       *const_cast<nvinfer1::ITensor*>(tensor), nvinfer1::ActivationType::kRELU);
   1027   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
   1028   outputs->push_back(TRT_TensorOrWeights(output_tensor));
   1029   return tensorflow::Status::OK();
   1030 }
   1031 
   1032 tensorflow::Status ConvertScale(Converter& ctx,
   1033                                 const tensorflow::NodeDef& node_def,
   1034                                 std::vector<TRT_TensorOrWeights> const& inputs,
   1035                                 std::vector<TRT_TensorOrWeights>* outputs) {
   1036   if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
   1037       !inputs.at(1).is_weights())
   1038     return tensorflow::errors::Unimplemented(
   1039         "Only supports tensor op weight for now, at " + node_def.name());
   1040   // Implement tensor binaryOp weight [channel wise] for now;
   1041   nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
   1042 
   1043   // TODO(jie): handle NHWC/NCHW transpose;
   1044   TRT_ShapedWeights weights = inputs.at(1).weights();
   1045   TRT_ShapedWeights empty_weights(weights.type_);
   1046 
   1047   TFAttrs attrs(node_def);
   1048 
   1049   // Transpose NHWC
   1050   auto data_format = attrs.get<string>("data_format");
   1051   if (data_format == "NHWC") {
   1052     tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
   1053                                  {0, 3, 1, 2});
   1054     // TODO(jie): transpose it
   1055   } else {
   1056     VLOG(2) << "NCHW !!!!";
   1057   }
   1058   nvinfer1::IScaleLayer* layer = ctx.network()->addScale(
   1059       *const_cast<nvinfer1::ITensor*>(tensor), nvinfer1::ScaleMode::kCHANNEL,
   1060       weights, empty_weights, empty_weights);
   1061 
   1062   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
   1063   if (data_format == "NHWC") {
   1064     // TODO(jie): transpose it back!
   1065     output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1});
   1066   } else {
   1067     VLOG(2) << "NCHW !!!!";
   1068   }
   1069   outputs->push_back(TRT_TensorOrWeights(output_tensor));
   1070   return tensorflow::Status::OK();
   1071 }
   1072 
   1073 tensorflow::Status ConvertConst(Converter& ctx,
   1074                                 const tensorflow::NodeDef& node_def,
   1075                                 std::vector<TRT_TensorOrWeights> const& inputs,
   1076                                 std::vector<TRT_TensorOrWeights>* outputs) {
   1077   const auto& weights_tensor = node_def.attr().at("value").tensor();
   1078 
   1079   // Get trt type & shape
   1080   TFAttrs attrs(node_def);
   1081   const tensorflow::DataType dtype = attrs.get<tensorflow::DataType>("dtype");
   1082 
   1083   // Create shaped weights as output
   1084   tensorflow::Tensor tensor;
   1085   if (!tensor.FromProto(weights_tensor))
   1086     return tensorflow::errors::Internal("Cannot parse weight tensor proto: " +
   1087                                         node_def.name());
   1088 
   1089   TRT_ShapedWeights weights(dtype);
   1090   if (!weights_tensor.float_val().empty()) {
   1091     VLOG(2) << "SCALAR!!!" << node_def.name();
   1092     nvinfer1::Dims scalar_shape;
   1093     if (tensor.dims() > 0) {
   1094       VLOG(2) << "Dimensions: " << tensor.dims();
   1095       weights = TRT_ShapedWeights(dtype, weights_tensor.float_val().data(),
   1096                                   GetTensorShape(tensor));
   1097     } else {
   1098       VLOG(2) << "Dimensions: " << tensor.dims();
   1099       scalar_shape.nbDims = 1;
   1100       scalar_shape.d[0] = 1;
   1101       scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL;
   1102       for (int i = 1; i < nvinfer1::Dims::MAX_DIMS; i++) {
   1103         scalar_shape.d[i] = 0;
   1104         scalar_shape.type[i] = nvinfer1::DimensionType::kSPATIAL;
   1105       }
   1106       weights = TRT_ShapedWeights(dtype, weights_tensor.float_val().data(),
   1107                                   scalar_shape);
   1108     }
   1109   } else if (!weights_tensor.tensor_content().empty()) {
   1110     VLOG(2) << "TENSOR!!!" << node_def.name();
   1111     const auto& content = weights_tensor.tensor_content();
   1112 
   1113     weights = ctx.get_temp_weights(dtype, GetTensorShape(tensor));
   1114     if (content.size() > 0) {
   1115       const int dtype_size = tensorflow::DataTypeSize(dtype);
   1116       CHECK_EQ(0, content.size() % dtype_size)
   1117           << "Tensor content size (" << content.size()
   1118           << ") is not a multiple of " << dtype_size;
   1119       port::CopyToArray(
   1120           content, static_cast<char*>(const_cast<void*>(weights.GetValues())));
   1121     }
   1122   } else {
   1123     return tensorflow::errors::Unimplemented(
   1124         "Not supported constant type, at " + node_def.name());
   1125   }
   1126   // Pass the output
   1127   outputs->push_back(TRT_TensorOrWeights(weights));
   1128   return tensorflow::Status::OK();
   1129 }
   1130 
   1131 tensorflow::Status ConvertIdentity(
   1132     Converter& ctx, const tensorflow::NodeDef& node_def,
   1133     std::vector<TRT_TensorOrWeights> const& inputs,
   1134     std::vector<TRT_TensorOrWeights>* outputs) {
   1135   outputs->push_back(inputs.at(0));
   1136   return tensorflow::Status::OK();
   1137 }
   1138 
   1139 tensorflow::Status ConvertBinary(Converter& ctx,
   1140                                  const tensorflow::NodeDef& node_def,
   1141                                  std::vector<TRT_TensorOrWeights> const& inputs,
   1142                                  std::vector<TRT_TensorOrWeights>* outputs) {
   1143   if (inputs.size() != 2)
   1144     return tensorflow::errors::FailedPrecondition(
   1145         "Binary ops require two tensor input, at " + node_def.name());
   1146 
   1147   if (inputs.at(0).is_weights() && inputs.at(1).is_weights())
   1148     return ConstantFoldBinary(ctx, node_def, inputs, outputs);
   1149 
   1150   if (inputs.at(0).is_tensor() && inputs.at(1).is_weights())
   1151     return BinaryTensorOpWeight(ctx, node_def, inputs.at(0).tensor(),
   1152                                 inputs.at(1).weights(), outputs);
   1153 
   1154   if (inputs.at(0).is_weights() && inputs.at(1).is_tensor())
   1155     return BinaryTensorOpWeight(ctx, node_def, inputs.at(1).tensor(),
   1156                                 inputs.at(0).weights(), outputs);
   1157 
   1158   if (inputs.at(0).is_tensor() && inputs.at(1).is_tensor())
   1159     return BinaryTensorOpTensor(ctx, node_def, inputs.at(0).tensor(),
   1160                                 inputs.at(1).tensor(), outputs);
   1161 
   1162   return tensorflow::errors::Unknown("Binary op input error, at " +
   1163                                      node_def.name());
   1164 }
   1165 
   1166 tensorflow::Status ConvertUnary(Converter& ctx,
   1167                                 const tensorflow::NodeDef& node_def,
   1168                                 std::vector<TRT_TensorOrWeights> const& inputs,
   1169                                 std::vector<TRT_TensorOrWeights>* outputs) {
   1170   if (inputs.size() != 1)
   1171     return tensorflow::errors::FailedPrecondition(
   1172         "Unary ops require single tensor input, at " + node_def.name());
   1173 
   1174   if (inputs.at(0).is_weights())
   1175     return ConstantFoldUnary(ctx, node_def, inputs, outputs);
   1176   else if (inputs.at(0).is_tensor())
   1177     return tensorflow::errors::Unimplemented(
   1178         "Unary op for tensor not supported, at " + node_def.name());
   1179 
   1180   return tensorflow::errors::Unknown("Binary op input error, at " +
   1181                                      node_def.name());
   1182 }
   1183 
   1184 tensorflow::Status ConvertReduce(Converter& ctx,
   1185                                  const tensorflow::NodeDef& node_def,
   1186                                  std::vector<TRT_TensorOrWeights> const& inputs,
   1187                                  std::vector<TRT_TensorOrWeights>* outputs) {
   1188   if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
   1189       !inputs.at(1).is_weights())
   1190     return tensorflow::errors::InvalidArgument(
   1191         "Input expects tensor and weights, at" + node_def.name());
   1192 
   1193   // Implement tensor binaryOp weight [channel wise] for now;
   1194   nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
   1195   auto dims = tensor->getDimensions();
   1196   // Restore implicit batch dimension
   1197   int nb_dims = dims.nbDims + 1;
   1198 
   1199   TRT_ShapedWeights index_list = inputs.at(1).weights();
   1200 
   1201   TFAttrs attrs(node_def);
   1202   // TODO(jie): handle data type.
   1203   // Index type here is done through TF type, so I can leverage their
   1204   // EnumToDataType for my cast
   1205   auto index_type = attrs.get<tensorflow::DataType>("Tidx");
   1206 
   1207   // Only expect to handle INT32 as attributes for now
   1208   if (index_type != tensorflow::DataType::DT_INT32)
   1209     return tensorflow::errors::Unimplemented("Tidx supports only DT_INT32");
   1210   auto index_list_data =
   1211       static_cast<int*>(const_cast<void*>(index_list.GetValues()));
   1212 
   1213   // Hack warning: have to fall back to pool layer since reduce is not in public
   1214   // TRT yet.
   1215   if (nb_dims != 4)
   1216     return tensorflow::errors::InvalidArgument(
   1217         "TRT only support reduce on 4 dimensional tensors, at" +
   1218         node_def.name());
   1219   if (index_list.count() > 2)
   1220     return tensorflow::errors::InvalidArgument(
   1221         "TRT cannot support reduce on more than 2 dimensions, at" +
   1222         node_def.name());
   1223 
   1224   std::set<int> idx_set;
   1225   // We cannot operate on Channel. permutation flag used to transpose tensor
   1226   int permuted_index = -1;
   1227   for (int i = 0; i < index_list.count(); i++) {
   1228     if (index_list_data[i] == 0)
   1229       return tensorflow::errors::InvalidArgument("TRT cannot reduce at 0, at" +
   1230                                                  node_def.name());
   1231     if (index_list_data[i] == 1) permuted_index = 1;
   1232     idx_set.emplace(index_list_data[i]);
   1233   }
   1234 
   1235   std::vector<int> permutation_order(nb_dims);
   1236   nvinfer1::DimsHW pool_kernel;
   1237   if (permuted_index == 1) {
   1238     for (int i = 2; i < nb_dims; i++) {
   1239       if (idx_set.count(i)) {
   1240         permuted_index = i;
   1241         break;
   1242       }
   1243     }
   1244     for (int i = 0; i < nb_dims; i++) permutation_order[i] = i;
   1245 
   1246     permutation_order[permuted_index] = 1;
   1247     permutation_order[1] = permuted_index;
   1248 
   1249     // Apply permutation before extracting dimension for pool_kernel
   1250     tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
   1251                                  permutation_order);
   1252   }
   1253 
   1254   // Apply permutation before extracting dimension for pool_kernel
   1255   pool_kernel.d[0] = (idx_set.count(2) || permuted_index == 2) ? dims.d[1] : 1;
   1256   pool_kernel.d[1] = (idx_set.count(3) || permuted_index == 3) ? dims.d[2] : 1;
   1257 
   1258   nvinfer1::ITensor* output_tensor;
   1259 
   1260   if (node_def.op() == "Mean") {
   1261     nvinfer1::IPoolingLayer* layer =
   1262         ctx.network()->addPooling(*const_cast<nvinfer1::ITensor*>(tensor),
   1263                                   nvinfer1::PoolingType::kAVERAGE, pool_kernel);
   1264     output_tensor = layer->getOutput(0);
   1265   } else {
   1266     return tensorflow::errors::Unimplemented(
   1267         "Op not supported " + node_def.op() + " , at " + node_def.name());
   1268   }
   1269   if (permuted_index != -1) {
   1270     // Apply permutation before extracting dimension for pool_kernel
   1271     output_tensor = ctx.TransposeTensor(
   1272         const_cast<nvinfer1::ITensor*>(output_tensor), permutation_order);
   1273   }
   1274   return tensorflow::Status::OK();
   1275 }
   1276 
   1277 tensorflow::Status ConvertPad(Converter& ctx,
   1278                               const tensorflow::NodeDef& node_def,
   1279                               std::vector<TRT_TensorOrWeights> const& inputs,
   1280                               std::vector<TRT_TensorOrWeights>* outputs) {
   1281   if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
   1282       !inputs.at(1).is_weights())
   1283     return tensorflow::errors::InvalidArgument(
   1284         "Input expects tensor and weights, at" + node_def.name());
   1285 
   1286   // Implement tensor binaryOp weight [channel wise] for now;
   1287   nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
   1288   auto dims = tensor->getDimensions();
   1289   // Restore implicit batch dimension
   1290   int nb_dims = dims.nbDims + 1;
   1291 
   1292   TRT_ShapedWeights pads = inputs.at(1).weights();
   1293 
   1294   TFAttrs attrs(node_def);
   1295   // Padding type here is done through TF type
   1296   //   so I can leverage their EnumToDataType for my cast
   1297   auto padding_type = attrs.get<tensorflow::DataType>("Tpaddings");
   1298   // TODO(jie): handle data type conversion for TRT?
   1299 
   1300   if (pads.shape_.d[0] != nb_dims || pads.shape_.d[1] != 2)
   1301     return tensorflow::errors::InvalidArgument(
   1302         "Pad only supports explicit padding on 4 dimensional tensor, at " +
   1303         node_def.name());
   1304 
   1305   // Only expect to handle INT32 as attributes for now
   1306   if (padding_type != tensorflow::DataType::DT_INT32)
   1307     return tensorflow::errors::Unimplemented(
   1308         "Tpaddings supports only DT_INT32");
   1309   auto pad_data = static_cast<int*>(const_cast<void*>(pads.GetValues()));
   1310 
   1311   std::vector<int32_t> pad_index;
   1312   for (int i = 0; i < nb_dims; i++) {
   1313     if (pad_data[2 * i] != 0 || pad_data[2 * i + 1] != 0)
   1314       pad_index.push_back(i);
   1315   }
   1316 
   1317   // No padding at all, we should exit
   1318   if (pad_index.size() == 0) {
   1319     outputs->push_back(inputs.at(0));
   1320     return tensorflow::Status::OK();
   1321   }
   1322 
   1323   // Only supports padding on less than 2 axis GIE-2579
   1324   if (pad_index.size() > 2)
   1325     return tensorflow::errors::InvalidArgument(
   1326         "Padding layer does not support padding on > 2");
   1327 
   1328   // Padding on batch dimension is not supported
   1329   if (pad_index[0] == 0)
   1330     return tensorflow::errors::InvalidArgument(
   1331         "Padding layer does not support padding on batch dimension");
   1332 
   1333   // Not doing the legit thing here. ignoring padding on dim 1 and 3;
   1334   // TODO(jie): implement pad as uff parser
   1335   if (pad_index.size() == 2 && pad_index[0] == 0 && pad_index[1] == 3)
   1336     return tensorflow::errors::Unimplemented(
   1337         "Padding layer does not support padding on dimension 1 and 3 yet");
   1338 
   1339   bool legit_pad = true;
   1340   nvinfer1::DimsHW pre_padding(0, 0);
   1341   nvinfer1::DimsHW post_padding(0, 0);
   1342 
   1343   std::vector<int32_t> permuted_pad_index(pad_index);
   1344   if (pad_index[0] == 1) {
   1345     legit_pad = false;
   1346     tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
   1347                                  {0, 3, 2, 1});
   1348     permuted_pad_index[0] = 3;
   1349   }
   1350 
   1351   for (size_t i = 0; i < pad_index.size(); i++) {
   1352     int index = pad_index[i];
   1353     if (permuted_pad_index[i] == 2) {
   1354       pre_padding.h() = pad_data[index * 2];
   1355       post_padding.h() = pad_data[index * 2 + 1];
   1356     } else if (permuted_pad_index[i] == 3) {
   1357       pre_padding.w() = pad_data[index * 2];
   1358       post_padding.w() = pad_data[index * 2 + 1];
   1359     }
   1360   }
   1361 
   1362   nvinfer1::IPaddingLayer* layer = ctx.network()->addPadding(
   1363       *const_cast<nvinfer1::ITensor*>(tensor), pre_padding, post_padding);
   1364   nvinfer1::ITensor* output_tensor = layer->getOutput(0);
   1365 
   1366   if (!legit_pad)
   1367     output_tensor = ctx.TransposeTensor(
   1368         const_cast<nvinfer1::ITensor*>(output_tensor), {0, 3, 2, 1});
   1369 
   1370   outputs->push_back(TRT_TensorOrWeights(output_tensor));
   1371   return tensorflow::Status::OK();
   1372 }
   1373 
   1374 void Converter::register_op_converters() {
   1375   // vgg_16 slim implementation
   1376   op_registry_["Placeholder"] = ConvertPlaceholder;
   1377   op_registry_["Conv2D"] = ConvertConv2D;
   1378   op_registry_["Relu"] = ConvertActivation;
   1379   op_registry_["MaxPool"] = ConvertPool;
   1380   // This could be really handled as ConvertBinary
   1381   op_registry_["BiasAdd"] = ConvertScale;
   1382   op_registry_["Const"] = ConvertConst;
   1383   // op_registry_["MatMul"] = ConvertFullyConnected;  // Not used in vgg
   1384   // TODO(ben,jie): this is a temp hack.
   1385   op_registry_["Identity"] = ConvertIdentity;  // Identity should be removed
   1386   // op_registry_["AvgPool"] = ConvertPool;
   1387 
   1388   // resnet_50_v1 slim implementation
   1389   op_registry_["Add"] = ConvertBinary;
   1390   op_registry_["Mul"] = ConvertBinary;
   1391   op_registry_["Sub"] = ConvertBinary;
   1392   op_registry_["Rsqrt"] = ConvertUnary;
   1393   op_registry_["Mean"] = ConvertReduce;
   1394   op_registry_["Pad"] = ConvertPad;
   1395   // TODO(ben,jie): Add more ops
   1396 }
   1397 
   1398 }  // namespace
   1399 
   1400 tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
   1401     const tensorflow::Graph& graph, const std::set<int>& subgraph_node_ids,
   1402     const std::vector<std::pair<int, int>>& input_inds,
   1403     const std::vector<std::pair<int, int>>& output_inds, size_t max_batch_size,
   1404     size_t max_workspace_size_bytes,
   1405     const tensorflow::grappler::GraphProperties& graph_properties,
   1406     tensorflow::NodeDef* trt_node) {
   1407   // Visit nodes in reverse topological order and construct the TRT network.
   1408 
   1409   // Toposort
   1410   std::vector<tensorflow::Node*> order_vec;
   1411   tensorflow::GetPostOrder(graph, &order_vec);
   1412   // Select just the subgraph
   1413   std::list<tensorflow::Node*> order;
   1414   for (tensorflow::Node* node : order_vec) {
   1415     if (subgraph_node_ids.count(node->id())) {
   1416       // We want topological order to contstruct the
   1417       // network layer by layer
   1418       order.push_front(node);
   1419     }
   1420   }
   1421   // Topological order is needed to build TRT network
   1422 
   1423   tensorflow::tensorrt::Logger trt_logger;
   1424 
   1425   auto trt_builder = infer_object(nvinfer1::createInferBuilder(trt_logger));
   1426   if (!trt_builder) {
   1427     return tensorflow::errors::Internal(
   1428         "Failed to create TensorRT builder object");
   1429   }
   1430 
   1431   auto trt_network = infer_object(trt_builder->createNetwork());
   1432   if (!trt_network) {
   1433     return tensorflow::errors::Internal(
   1434         "Failed to create TensorRT network object");
   1435   }
   1436 
   1437   // Build the network
   1438   Converter converter(trt_network.get());
   1439 
   1440   std::vector<string> input_names;
   1441   std::vector<tensorflow::DataType> input_dtypes;
   1442   for (std::pair<int, int> const& input : input_inds) {
   1443     int node_id = input.first;
   1444     int output_idx = input.second;
   1445     tensorflow::Node* node = graph.FindNodeId(node_id);
   1446     auto node_name = node->name();
   1447     input_names.push_back(node_name);  // Insert original node name without port
   1448     // TODO(jie): alternative :)
   1449     if (!graph_properties.HasOutputProperties(node_name))
   1450       return tensorflow::errors::Internal("Failed to find input node: " +
   1451                                           node_name);
   1452 
   1453     auto op_info_vec = graph_properties.GetOutputProperties(node_name);
   1454     if (static_cast<int>(op_info_vec.size()) < output_idx)
   1455       return tensorflow::errors::Internal(
   1456           "Accessing output index of: " + std::to_string(output_idx) +
   1457           ", at node: " + node_name + " with output entry from shape_map: " +
   1458           std::to_string(op_info_vec.size()));
   1459 
   1460     auto op_info = op_info_vec.at(output_idx);
   1461 
   1462     tensorflow::DataType tf_dtype = op_info.dtype();
   1463     input_dtypes.push_back(tf_dtype);
   1464 
   1465     nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT);
   1466     TF_CHECK_OK(ConvertDType(tf_dtype, &dtype));
   1467 
   1468     VLOG(2) << "Accessing output index of: " << std::to_string(output_idx)
   1469             << ", at node: " << node_name
   1470             << " with output entry from shape_map: "
   1471             << std::to_string(op_info_vec.size());
   1472 
   1473     // TODO(ben,jie): update TRT input format/dimension
   1474     nvinfer1::DimsCHW input_dim_psuedo_chw;
   1475     for (int i = 0; i < 3; i++) input_dim_psuedo_chw.d[i] = 1;
   1476 
   1477     for (int i = 1; i < op_info.shape().dim_size(); i++) {
   1478       VLOG(2) << "dimension: " << i
   1479               << " , size: " << op_info.shape().dim(i).size();
   1480       input_dim_psuedo_chw.d[i - 1] = op_info.shape().dim(i).size();
   1481     }
   1482 
   1483     // TODO(ben,jie): proper way to restore input tensor name?
   1484     auto input_tensor_name = node_name;
   1485     if (output_idx != 0)
   1486       input_tensor_name = node_name + ":" + std::to_string(output_idx);
   1487 
   1488     nvinfer1::ITensor* input_tensor = converter.network()->addInput(
   1489         input_tensor_name.c_str(), dtype, input_dim_psuedo_chw);
   1490 
   1491     if (!input_tensor)
   1492       return tensorflow::errors::InvalidArgument(
   1493           "Failed to create Input layer");
   1494     VLOG(2) << "Input tensor name :" << input_tensor_name;
   1495 
   1496     if (!converter.insert_input_tensor(input_tensor_name, input_tensor))
   1497       return tensorflow::errors::AlreadyExists(
   1498           "Output tensor already exists for op: " + input_tensor_name);
   1499   }
   1500 
   1501   VLOG(2) << "Finished sorting";
   1502 
   1503   for (const tensorflow::Node* node : order) {
   1504     const tensorflow::NodeDef& node_def = node->def();
   1505     VLOG(2) << "Converting node: " << node_def.name() << " , " << node_def.op();
   1506     TF_RETURN_IF_ERROR(converter.convert_node(node_def));
   1507   }
   1508 
   1509   VLOG(2) << "Finished conversion";
   1510 
   1511   // Gather output metadata
   1512   std::vector<string> output_names;
   1513   std::vector<tensorflow::DataType> output_dtypes;
   1514   for (std::pair<int, int> const& output : output_inds) {
   1515     int node_id = output.first;
   1516     int output_idx = output.second;
   1517     tensorflow::Node* node = graph.FindNodeId(node_id);
   1518     string op_name = node->name();
   1519     string tensor_name = op_name;
   1520     if (output_idx != 0)
   1521       tensor_name = tensor_name + ":" + std::to_string(output_idx);
   1522     VLOG(2) << "Output tensor name: " << tensor_name;
   1523     output_names.push_back(tensor_name);
   1524     auto tensor_or_weights = converter.get_tensor(tensor_name);
   1525     if (!tensor_or_weights.is_tensor()) {
   1526       return tensorflow::errors::InvalidArgument(
   1527           "Output node is weights not tensor");
   1528     }
   1529     nvinfer1::ITensor* tensor = tensor_or_weights.tensor();
   1530     if (!tensor) {
   1531       return tensorflow::errors::NotFound("Output tensor not found: " +
   1532                                           tensor_name);
   1533     }
   1534     converter.network()->markOutput(*tensor);
   1535     tensorflow::DataType tf_dtype = node->output_type(output_idx);
   1536     output_dtypes.push_back(tf_dtype);
   1537     nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT;
   1538     TF_RETURN_IF_ERROR(ConvertDType(tf_dtype, &trt_dtype));
   1539     tensor->setType(trt_dtype);
   1540   }
   1541 
   1542   VLOG(2) << "Finished output";
   1543   // TODO(jie): static_id is not thread safe.
   1544   static int static_id = 0;
   1545 
   1546   // Build the engine
   1547   trt_builder->setMaxBatchSize(max_batch_size);
   1548   trt_builder->setMaxWorkspaceSize(max_workspace_size_bytes);
   1549   VLOG(0) << "Starting build engine " << static_id;
   1550   // TODO(ben,jie): half2 and int8 mode support
   1551   string engine_plan_string;
   1552   {
   1553     auto trt_engine =
   1554         infer_object(trt_builder->buildCudaEngine(*converter.network()));
   1555     VLOG(0) << "Built network";
   1556     auto engine_plan = infer_object(trt_engine->serialize());
   1557     VLOG(0) << "Serialized engine";
   1558     const char* engine_plan_data =
   1559         static_cast<const char*>(engine_plan->data());
   1560     engine_plan_string =
   1561         string(engine_plan_data, engine_plan_data + engine_plan->size());
   1562   }
   1563 
   1564   VLOG(0) << "Finished engine";
   1565 
   1566   // Build the TRT op
   1567   // TODO(sami,ben,jie): proper naming!
   1568   tensorflow::NodeDefBuilder op_builder(
   1569       tensorflow::strings::StrCat("my_trt_op", static_id++), "TRTEngineOp");
   1570   std::vector<tensorflow::NodeDefBuilder::NodeOut> income_edges;
   1571   for (size_t i = 0; i < input_names.size(); ++i) {
   1572     int output_idx = input_inds.at(i).second;
   1573     // We wired up the input here already, it is redundant to do it again in
   1574     // ConvertSubGraphToTensorRT(convert_graph.cc)
   1575     auto incoming_edge = tensorflow::NodeDefBuilder::NodeOut(
   1576         input_names.at(i), output_idx, input_dtypes.at(i));
   1577     income_edges.push_back(incoming_edge);
   1578   }
   1579   tensorflow::gtl::ArraySlice<tensorflow::NodeDefBuilder::NodeOut> input_list(
   1580       income_edges);
   1581   op_builder.Input(input_list);
   1582 
   1583   VLOG(0) << "Finished op preparation";
   1584 
   1585   auto status = op_builder.Attr("serialized_engine", engine_plan_string)
   1586                     .Attr("input_nodes", input_names)
   1587                     .Attr("output_nodes", output_names)
   1588                     .Attr("OutT", output_dtypes)
   1589                     .Finalize(trt_node);
   1590 
   1591   VLOG(0) << status.ToString() << " finished op building";
   1592 
   1593   return tensorflow::Status::OK();
   1594 }
   1595 
   1596 }  // namespace convert
   1597 }  // namespace tensorrt
   1598 }  // namespace tensorflow
   1599 
   1600 #endif  // GOOGLE_TENSORRT
   1601 #endif  // GOOGLE_CUDA
   1602