/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/contrib/tensorrt/convert/convert_nodes.h"

#include <algorithm>
#include <list>
#include <map>
#include <memory>
#include <set>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"  // NOLINT
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/platform/types.h"

#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
#include "tensorrt/include/NvInfer.h"

// Check that two types are equal. Cast to int first so that failure log
// messages print the enum values.
#define CHECK_EQ_TYPE(val1, val2) CHECK_EQ((int)val1, (int)val2)

namespace tensorflow {
namespace tensorrt {
namespace convert {

namespace {

inline tensorflow::Status ConvertDType(tensorflow::DataType tf_dtype,
                                       nvinfer1::DataType* trt_dtype) {
  switch (tf_dtype) {
    case tensorflow::DataType::DT_FLOAT:
      *trt_dtype = nvinfer1::DataType::kFLOAT;
      break;
    case tensorflow::DataType::DT_INT8:
      *trt_dtype = nvinfer1::DataType::kINT8;
      break;
    case tensorflow::DataType::DT_HALF:
      *trt_dtype = nvinfer1::DataType::kHALF;
      break;
    default:
      return tensorflow::errors::InvalidArgument("Unsupported data type");
  }
  return tensorflow::Status::OK();
}

inline nvinfer1::Dims GetTensorShape(const tensorflow::Tensor& tensor) {
  nvinfer1::Dims dims;
  dims.nbDims = tensor.dims();
  for (int i = 0; i < dims.nbDims; i++) {
    dims.d[i] = tensor.dim_size(i);
  }
  return dims;
}

inline int64_t GetShapeSize(nvinfer1::Dims shape) {
  // Returns the total number of elements in shape.
  int64_t count = 1;
  for (int d = 0; d < shape.nbDims; ++d) {
    count *= shape.d[d];
  }
  return count;
}
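
// CreateSamePadding below mirrors TensorFlow's "SAME" padding rule: pad just
// enough that the output size is ceil(input / stride). As an illustrative
// example (not from the original code): input = 6, stride = 2, kernel = 3
// gives p = ((6 - 1) / 2) * 2 + 3 - 6 = 1, split as left = 0 and right = 1;
// the surplus element goes on the right, matching TensorFlow's convention.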
static std::vector<std::pair<int, int>> CreateSamePadding(
    const nvinfer1::DimsHW& stride, const nvinfer1::DimsHW& kernel,
    const std::vector<int64_t>& input_dims) {
  std::vector<std::pair<int, int>> padding(input_dims.size());
  CHECK_EQ((size_t)stride.nbDims, input_dims.size());  // TODO(jie): N+C? NC+?

  for (size_t i = 0; i < input_dims.size(); ++i) {
    // Formula to calculate the padding
    int p = ((input_dims[i] - 1) / stride.d[i]) * stride.d[i] + kernel.d[i] -
            input_dims[i];
    p = (p > 0) ? p : 0;

    // Right precedence padding, like in TensorFlow
    int left = p / 2;
    int right = p - left;

    VLOG(2) << "PADDING_" << i << " pre: " << left << ", post: " << right
            << ", paras: " << input_dims[i] << ", " << stride.d[i] << ", "
            << "kernel: " << kernel.d[i];
    padding[i] = {left, right};
  }
  return padding;
}

class TRT_ShapedWeights {
 public:
  TRT_ShapedWeights(tensorflow::DataType type, const void* values,
                    nvinfer1::Dims shape)
      : shape_(shape), type_(type), values_(values), empty_weight_flag_(false) {
    // Note: this->shape.type[] is not used.
  }

  explicit TRT_ShapedWeights(tensorflow::DataType type)
      : shape_(), type_(type), values_(nullptr), empty_weight_flag_(true) {}

  TRT_ShapedWeights(const TRT_ShapedWeights& rhs)
      : shape_(rhs.shape_),
        type_(rhs.type_),
        values_(rhs.values_),
        empty_weight_flag_(rhs.empty_weight_flag_) {}

  int64_t count() const {
    int64_t c = 1;
    for (int i = 0; i < shape_.nbDims; i++) c *= shape_.d[i];
    return c;
  }

  nvinfer1::Weights GetWeightsForTRT() const {
    nvinfer1::DataType trt_type(nvinfer1::DataType::kFLOAT);
    TF_CHECK_OK(ConvertDType(type_, &trt_type));
    if (empty_weight_flag_) return nvinfer1::Weights{trt_type, nullptr, 0};

    // Note: this->shape.type[] is not used.
    return nvinfer1::Weights{trt_type, GetValues(), GetShapeSize(shape_)};
  }

  const void* GetValues() const { return values_; }

  void SetValues(const void* values) { values_ = values; }

  size_t size_bytes() const {
    int type_size = tensorflow::DataTypeSize(this->type_);
    return this->count() * type_size;
  }

  // Default converter
  operator nvinfer1::Weights() const { return GetWeightsForTRT(); }

  nvinfer1::Dims shape_;
  tensorflow::DataType type_;

 private:
  const void* values_;
  bool empty_weight_flag_;
};
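
// TRT_TensorOrWeights is a small tagged union: a converter input or output is
// either an nvinfer1::ITensor produced by an upstream layer, or a
// TRT_ShapedWeights constant captured from the TF graph. Callers are expected
// to check is_tensor()/is_weights() first; the accessors CHECK on a
// mismatched variant.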
class TRT_TensorOrWeights {
 public:
  explicit TRT_TensorOrWeights(nvinfer1::ITensor* tensor)
      : tensor_(tensor), weights_(DT_FLOAT), variant_(TRT_NODE_TENSOR) {}
  explicit TRT_TensorOrWeights(const TRT_ShapedWeights& weights)
      : tensor_(nullptr), weights_(weights), variant_(TRT_NODE_WEIGHTS) {}
  TRT_TensorOrWeights(const TRT_TensorOrWeights& rhs)
      : tensor_(rhs.tensor_), weights_(rhs.weights_), variant_(rhs.variant_) {}
  ~TRT_TensorOrWeights() {}

  bool is_tensor() const { return variant_ == TRT_NODE_TENSOR; }
  bool is_weights() const { return variant_ == TRT_NODE_WEIGHTS; }

  nvinfer1::ITensor* tensor() {
    CHECK_EQ(is_tensor(), true);
    return tensor_;
  }
  const nvinfer1::ITensor* tensor() const {
    CHECK_EQ(is_tensor(), true);
    return tensor_;
  }
  TRT_ShapedWeights& weights() {
    CHECK_EQ(is_weights(), true);
    return weights_;
  }
  const TRT_ShapedWeights& weights() const {
    CHECK_EQ(is_weights(), true);
    return weights_;
  }
  nvinfer1::Dims shape() const {
    if (is_tensor()) {
      return tensor()->getDimensions();
    } else {
      return weights().shape_;
    }
  }

 private:
  nvinfer1::ITensor* tensor_;
  TRT_ShapedWeights weights_;
  enum { TRT_NODE_TENSOR, TRT_NODE_WEIGHTS } variant_;
};

class TFAttrs {
 public:
  explicit TFAttrs(const tensorflow::NodeDef& tf_node) {
    for (const auto& attr : tf_node.attr()) {
      attrs_.insert({attr.first, &attr.second});
    }
  }
  bool count(string key) const { return attrs_.count(key); }
  tensorflow::AttrValue const* at(string key) const {
    if (!attrs_.count(key)) {
      LOG(FATAL) << "Attribute not found: " << key;
    }
    return attrs_.at(key);
  }
  template <typename T>
  T get(string key) const;
  template <typename T>
  T get(string key, const T& default_value) const {
    return attrs_.count(key) ? this->get<T>(key) : default_value;
  }

 private:
  typedef std::map<string, tensorflow::AttrValue const*> AttrMap;
  AttrMap attrs_;
};

template <>
string TFAttrs::get<string>(string key) const {
  return this->at(key)->s();
}

template <>
std::vector<int> TFAttrs::get<std::vector<int>>(string key) const {
  auto attr = this->at(key)->list().i();
  return std::vector<int>(attr.begin(), attr.end());
}

template <>
nvinfer1::Dims TFAttrs::get<nvinfer1::Dims>(string key) const {
  auto values = this->get<std::vector<int>>(key);
  nvinfer1::Dims dims;
  dims.nbDims = values.size();
  std::copy(values.begin(), values.end(), dims.d);
  // Note: No dimension type information is included
  return dims;
}

template <>
nvinfer1::DataType TFAttrs::get<nvinfer1::DataType>(string key) const {
  nvinfer1::DataType trt_dtype(nvinfer1::DataType::kFLOAT);
  TF_CHECK_OK(ConvertDType(this->at(key)->type(), &trt_dtype));
  return trt_dtype;
}

template <>
tensorflow::DataType TFAttrs::get<tensorflow::DataType>(string key) const {
  return this->at(key)->type();
}

template <typename T>
void Reorder4(nvinfer1::DimsNCHW shape, const T* idata,
              nvinfer1::DimsNCHW istrides, T* odata,
              nvinfer1::DimsNCHW ostrides) {
  for (int n = 0; n < shape.n(); ++n) {
    for (int c = 0; c < shape.c(); ++c) {
      for (int h = 0; h < shape.h(); ++h) {
        for (int w = 0; w < shape.w(); ++w) {
          odata[n * ostrides.n() + c * ostrides.c() + h * ostrides.h() +
                w * ostrides.w()] = idata[n * istrides.n() + c * istrides.c() +
                                          h * istrides.h() + w * istrides.w()];
        }
      }
    }
  }
}
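
// ReorderRSCKToKCRS converts TF convolution kernels, stored as RSCK
// ([height, width, in_channels, out_channels]), into TensorRT's KCRS layout
// ([out_channels, in_channels, height, width]). Reorder4 walks the output in
// KCRS order, so the input strides describe where element (k, c, r, s) lives
// in the flat RSCK buffer:
//   index = r * (S*C*K) + s * (C*K) + c * K + k,
// i.e. the istrides {1, K, S*C*K, C*K} below, with the n/c/h/w slots holding
// the K/C/R/S strides respectively.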
void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights,
                       TRT_ShapedWeights* oweights) {
  CHECK_EQ(iweights.type_, oweights->type_);
  CHECK_EQ(iweights.size_bytes(), oweights->size_bytes());
  int r = iweights.shape_.d[0];
  int s = iweights.shape_.d[1];
  int c = iweights.shape_.d[2];
  int k = iweights.shape_.d[3];
  oweights->shape_.d[0] = k;
  oweights->shape_.d[1] = c;
  oweights->shape_.d[2] = r;
  oweights->shape_.d[3] = s;
  nvinfer1::DimsNCHW istrides = {1, k, s * k * c, c * k};
  nvinfer1::DimsNCHW ostrides = {c * r * s, r * s, s, 1};
  switch (iweights.type_) {
    case tensorflow::DataType::DT_FLOAT:
      Reorder4({k, c, r, s}, static_cast<float const*>(iweights.GetValues()),
               istrides,
               static_cast<float*>(const_cast<void*>(oweights->GetValues())),
               ostrides);
      break;
    default:
      LOG(FATAL) << "Unsupported type in weight reordering: "
                 << tensorflow::DataTypeString(iweights.type_);
  }
}

struct InferDeleter {
  template <typename T>
  void operator()(T* obj) const {
    if (obj) {
      obj->destroy();
    }
  }
};

template <typename T>
inline std::shared_ptr<T> infer_object(T* obj) {
  return std::shared_ptr<T>(obj, InferDeleter());
}

class Converter;

using OpConverter =
    std::function<tensorflow::Status(Converter&, const tensorflow::NodeDef&,
                                     std::vector<TRT_TensorOrWeights> const&,
                                     std::vector<TRT_TensorOrWeights>*)>;

class Converter {
  std::unordered_map<string, TRT_TensorOrWeights> trt_tensors_;
  std::unordered_map<string, OpConverter> op_registry_;
  nvinfer1::INetworkDefinition* trt_network_;
  std::list<std::vector<uint8_t>> temp_bufs_;

  void register_op_converters();

  std::vector<TRT_TensorOrWeights> get_inputs(
      const tensorflow::NodeDef& node_def) {
    std::vector<TRT_TensorOrWeights> inputs;
    for (const auto& input_name : node_def.input()) {
      VLOG(2) << "Retrieve input: " << input_name;
      inputs.push_back(trt_tensors_.at(input_name));
    }
    return inputs;
  }

 public:
  explicit Converter(nvinfer1::INetworkDefinition* trt_network)
      : trt_network_(trt_network) {
    this->register_op_converters();
  }

  TRT_ShapedWeights get_temp_weights(tensorflow::DataType type,
                                     nvinfer1::Dims shape) {
    TRT_ShapedWeights weights(type, nullptr, shape);
    // TODO(jie): check weights size_bytes. 0 means type error
    temp_bufs_.push_back(std::vector<uint8_t>(weights.size_bytes()));
    weights.SetValues(temp_bufs_.back().data());
    return weights;
  }

  TRT_ShapedWeights get_temp_weights_like(const TRT_ShapedWeights& weights) {
    return this->get_temp_weights(weights.type_, weights.shape_);
  }

  tensorflow::Status convert_node(const tensorflow::NodeDef& node_def) {
    std::vector<TRT_TensorOrWeights> inputs = this->get_inputs(node_def);
    string op = node_def.op();
    if (!op_registry_.count(op)) {
      return tensorflow::errors::Unimplemented(
          "No converter registered for op: " + op);
    }
    OpConverter op_converter = op_registry_.at(op);
    std::vector<TRT_TensorOrWeights> outputs;
    TF_RETURN_IF_ERROR(op_converter(*this, node_def, inputs, &outputs));
    for (size_t i = 0; i < outputs.size(); ++i) {
      TRT_TensorOrWeights output = outputs.at(i);
      // TODO(jie): tf protobuf seems to be omitting the :0 suffix
      string output_name = node_def.name();
      if (i != 0) output_name = output_name + ":" + std::to_string(i);
      if (output.is_tensor()) {
        output.tensor()->setName(output_name.c_str());
      }
      VLOG(2) << "Write out tensor: " << output_name;
      if (!trt_tensors_.insert({output_name, output}).second) {
        return tensorflow::errors::AlreadyExists(
            "Output tensor already exists for op: " + op);
      }
    }
    return tensorflow::Status::OK();
  }

  nvinfer1::INetworkDefinition* network() { return trt_network_; }

  TRT_TensorOrWeights get_tensor(string name) {
    if (!trt_tensors_.count(name)) {
      return TRT_TensorOrWeights(nullptr);
    }
    return trt_tensors_.at(name);
  }

  bool insert_input_tensor(string name, nvinfer1::ITensor* tensor) {
    return trt_tensors_.insert({name, TRT_TensorOrWeights(tensor)}).second;
  }
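
  // TransposeTensor permutes a tensor using an nvinfer1::IShuffleLayer. The
  // `order` argument follows TF convention and includes the batch dimension
  // (e.g. {0, 3, 1, 2} for NHWC -> NCHW), while the TRT network operates on
  // batch-free tensors, hence the order[i + 1] - 1 shift below. A reshape
  // dimension of 0 tells TensorRT to copy the corresponding input dimension.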
  nvinfer1::ITensor* TransposeTensor(nvinfer1::ITensor* input_tensor,
                                     std::vector<int> order) {
    auto dims = input_tensor->getDimensions();

    // TODO(jie): change the return to status and properly exit
    if (order.size() - 1 != size_t(dims.nbDims))
      LOG(ERROR) << "Permutation order size does not match tensor rank";

    nvinfer1::IShuffleLayer* layer = this->network()->addShuffle(*input_tensor);
    nvinfer1::Permutation permutation;
    for (int32_t i = 0; i < dims.nbDims; ++i) {
      permutation.order[i] = order[i + 1] - 1;
    }
    layer->setFirstTranspose(permutation);

    nvinfer1::Dims reshape_dims;
    reshape_dims.nbDims = dims.nbDims;
    for (int32_t i = 0; i < reshape_dims.nbDims; ++i) {
      reshape_dims.d[i] = 0;
      reshape_dims.type[i] = dims.type[i];
    }
    layer->setReshapeDimensions(reshape_dims);
    return layer->getOutput(0);
  }
};

// ****************************************************************************
// Constant folding functions
// TODO(jie): once the optimizer kicks in, constant folding should have been
// done there.
// ****************************************************************************
struct LambdaFactory {
  enum class OP_CATEGORY : int { RSQRT = 0, NEG, ADD, MUL, SUB };
  OP_CATEGORY op;

  template <typename T>
  std::function<T(T)> unary() {
    switch (op) {
      case OP_CATEGORY::RSQRT: {
        VLOG(2) << "RSQRT GETS DONE";
        return [](T t) -> T { return 1.0 / std::sqrt(t); };
      }
      case OP_CATEGORY::NEG:
        return [](T t) -> T { return -t; };
      default:
        VLOG(2) << "Unsupported op for unary: " << static_cast<int>(op);
        return nullptr;
    }
  }

  template <typename T>
  std::function<T(T, T)> binary() {
    switch (op) {
      case OP_CATEGORY::ADD:
        return [](T l, T r) -> T { return l + r; };
      case OP_CATEGORY::SUB:
        return [](T l, T r) -> T { return l - r; };
      case OP_CATEGORY::MUL:
        return [](T l, T r) -> T { return l * r; };
      default:
        LOG(WARNING) << "Unsupported op for binary: " << static_cast<int>(op);
    }
    return [](T l, T r) -> T {
      LOG(FATAL) << "Unsupported op type";
      return l;
    };
  }

  template <typename T>
  std::function<T(T)> broadcast_r(T val) {
    VLOG(2) << "LAMBDA VAL : " << val;
    switch (op) {
      case OP_CATEGORY::ADD:
        return [val](T l) -> T {
          VLOG(2) << "LAMBDA VAL : " << val;
          return l + val;
        };
      case OP_CATEGORY::SUB:
        return [val](T l) -> T {
          VLOG(2) << "LAMBDA VAL : " << val;
          return l - val;
        };
      case OP_CATEGORY::MUL:
        return [val](T l) -> T {
          VLOG(2) << "LAMBDA VAL : " << val;
          return l * val;
        };
      default:
        LOG(WARNING) << "Unsupported op for binary: " << static_cast<int>(op);
    }
    return [val](T l) -> T {
      LOG(FATAL) << "Unsupported op type";
      return l;
    };
  }

  template <typename T>
  std::function<T(T)> broadcast_l(T val) {
    VLOG(2) << "LAMBDA VAL : " << val;
    switch (op) {
      case OP_CATEGORY::ADD:
        return [val](T l) -> T {
          VLOG(2) << "LAMBDA VAL : " << val;
          return val + l;
        };
      case OP_CATEGORY::SUB:
        return [val](T l) -> T {
          VLOG(2) << "LAMBDA VAL : " << val;
          return val - l;
        };
      case OP_CATEGORY::MUL:
        return [val](T l) -> T {
          VLOG(2) << "LAMBDA VAL : " << val;
          return val * l;
        };
      default:
        LOG(ERROR) << "Unsupported op for binary: " << static_cast<int>(op);
    }
    return [val](T l) -> T {
      LOG(FATAL) << "Unsupported op type";
      return l;
    };
  }
};
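
// UnaryCompute and BinaryCompute below evaluate element-wise ops on weights
// at conversion time. LambdaFactory maps an OP_CATEGORY to the matching
// lambda; broadcast_l/broadcast_r bind a scalar as the left or right operand
// so a rank-zero constant can be folded against a whole buffer. Illustrative
// usage (not from the original code):
//   LambdaFactory f;
//   f.op = LambdaFactory::OP_CATEGORY::SUB;
//   auto minus_two = f.broadcast_r<float>(2.0f);  // minus_two(x) == x - 2.0f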
tensorflow::Status UnaryCompute(const TRT_ShapedWeights& iweights,
                                TRT_ShapedWeights* oweights,
                                LambdaFactory unary_op) {
  CHECK_EQ(iweights.type_, oweights->type_);
  switch (iweights.type_) {
    case tensorflow::DataType::DT_FLOAT: {
      auto inp = static_cast<float const*>(iweights.GetValues());
      auto oup = static_cast<float*>(const_cast<void*>(oweights->GetValues()));
      std::transform(inp, inp + iweights.count(), oup, unary_op.unary<float>());
      break;
    }
    default:
      return tensorflow::errors::Unimplemented(
          "Data type not supported: " +
          tensorflow::DataTypeString(iweights.type_));
  }
  return tensorflow::Status::OK();
}

tensorflow::Status BinaryCompute(const TRT_ShapedWeights& iweights_l,
                                 const TRT_ShapedWeights& iweights_r,
                                 TRT_ShapedWeights* oweights,
                                 LambdaFactory binary_op) {
  // Assume iweights_l.type == iweights_r.type
  CHECK_EQ(iweights_l.type_, oweights->type_);
  CHECK_EQ(iweights_r.type_, oweights->type_);
  VLOG(2) << "SANITY CHECK!";

  switch (iweights_l.type_) {
    case tensorflow::DataType::DT_FLOAT: {
      auto inp_l = static_cast<const float*>(iweights_l.GetValues());
      auto inp_r = static_cast<const float*>(iweights_r.GetValues());
      auto oup = static_cast<float*>(const_cast<void*>(oweights->GetValues()));

      if (iweights_l.count() != iweights_r.count()) {
        // Only rank-zero (scalar) broadcast is supported.
        if (iweights_l.count() == 1) {
          VLOG(2) << "Broadcasting left scalar operand: " << (*inp_l);
          std::transform(inp_r, inp_r + iweights_r.count(), oup,
                         binary_op.broadcast_l<float>(*inp_l));
        } else if (iweights_r.count() == 1) {
          VLOG(2) << "Broadcasting right scalar operand: " << (*inp_r);
          std::transform(inp_l, inp_l + iweights_l.count(), oup,
                         binary_op.broadcast_r<float>(*inp_r));
        } else {
          return tensorflow::errors::Unimplemented(
              "Binary op with non-rank-zero broadcast not supported");
        }
      } else {
        std::transform(inp_l, inp_l + iweights_l.count(), inp_r, oup,
                       binary_op.binary<float>());
      }
      break;
    }
    default:
      return tensorflow::errors::Unimplemented(
          "Data type not supported: " +
          tensorflow::DataTypeString(iweights_l.type_));
  }

  return tensorflow::Status::OK();
}

tensorflow::Status ConstantFoldUnary(
    Converter& ctx, const tensorflow::NodeDef& node_def,
    std::vector<TRT_TensorOrWeights> const& inputs,
    std::vector<TRT_TensorOrWeights>* outputs) {
  TRT_ShapedWeights weights_input = inputs.at(0).weights();

  // Allocate output weights
  TRT_ShapedWeights weights_output = ctx.get_temp_weights_like(weights_input);

  // FIXME assume type matches input weights
  // Get trt type & shape
  // Maybe this part has to be moved into the block of rsqrt later
  // Check type consistency
  CHECK_EQ(weights_input.type_,
           TFAttrs(node_def).get<tensorflow::DataType>("T"));

  // Maybe I should do a switch
  LambdaFactory unary_op;
  if (node_def.op() == "Rsqrt") {
    // Compute rsqrt
    unary_op.op = LambdaFactory::OP_CATEGORY::RSQRT;
    auto ret = UnaryCompute(weights_input, &weights_output, unary_op);
    // Pass the output
    if (ret == tensorflow::Status::OK()) {
      outputs->push_back(TRT_TensorOrWeights(weights_output));
    }
    return ret;
  } else {
    return tensorflow::errors::Unimplemented("Unary op not supported: " +
                                             node_def.op());
  }
}

// TODO(jie,ben): broadcast is needed yet not implemented.
// Let's get the simple stuff working first. Maybe we should fall back to the
// TF approach for constant folding.
tensorflow::Status ConstantFoldBinary(
    Converter& ctx, const tensorflow::NodeDef& node_def,
    std::vector<TRT_TensorOrWeights> const& inputs,
    std::vector<TRT_TensorOrWeights>* outputs) {
  TRT_ShapedWeights weights_input_l = inputs.at(0).weights();
  TRT_ShapedWeights weights_input_r = inputs.at(1).weights();

  // Check type consistency
  CHECK_EQ(weights_input_l.type_, weights_input_r.type_);

  if (weights_input_l.shape_.nbDims != weights_input_r.shape_.nbDims)
    return tensorflow::errors::Unimplemented(
        "Binary op implicit broadcast not supported: " + node_def.op());

  // TODO(jie): constant fold should really fall back to TF.
  int nb_dims = weights_input_l.shape_.nbDims;
  nvinfer1::Dims output_shape;
  output_shape.nbDims = nb_dims;
  VLOG(2) << "nb_dims: " << nb_dims
          << ", the other: " << weights_input_r.shape_.nbDims;
  for (int i = 0; i < nb_dims; i++) {
    if (weights_input_l.shape_.d[i] == weights_input_r.shape_.d[i]) {
      output_shape.d[i] = weights_input_l.shape_.d[i];
    } else if (weights_input_l.shape_.d[i] == 1 ||
               weights_input_r.shape_.d[i] == 1) {
      output_shape.d[i] =
          std::max(weights_input_l.shape_.d[i], weights_input_r.shape_.d[i]);
    } else {
      return tensorflow::errors::Unimplemented(
          "Binary op with incompatible shape, at " + node_def.op());
    }
    VLOG(2) << "left: " << weights_input_l.shape_.d[i]
            << ", right: " << weights_input_r.shape_.d[i]
            << ", output: " << output_shape.d[i];
  }

  // FIXME assume type matches input weights
  // Get trt type & shape
  TFAttrs attrs(node_def);
  // Maybe this part has to be moved into the block of rsqrt later
  tensorflow::DataType dtype = attrs.get<tensorflow::DataType>("T");

  // Allocate output weights
  TRT_ShapedWeights weights_output = ctx.get_temp_weights(dtype, output_shape);

  // Maybe I should do a switch
  LambdaFactory binary_op;
  if (node_def.op() == "Sub") {
    binary_op.op = LambdaFactory::OP_CATEGORY::SUB;
  } else if (node_def.op() == "Mul") {
    binary_op.op = LambdaFactory::OP_CATEGORY::MUL;
  } else if (node_def.op() == "Add") {
    binary_op.op = LambdaFactory::OP_CATEGORY::ADD;
  } else {
    return tensorflow::errors::Unimplemented("Binary op not supported: " +
                                             node_def.op());
  }
  auto ret = BinaryCompute(weights_input_l, weights_input_r, &weights_output,
                           binary_op);

  // Pass the output
  if (ret == tensorflow::Status::OK()) {
    outputs->push_back(TRT_TensorOrWeights(weights_output));
  }

  return ret;
}

// TODO(jie): broadcast is needed yet not implemented.
// Only implemented channel-wise for the time being.
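
// TensorRT's IScaleLayer computes (x * scale + shift)^power element-wise, so
// BinaryTensorOpWeight folds a tensor-vs-constant op into that form: Sub
// negates the constant and uses it as shift, Mul uses it as scale, and Add
// uses it as shift directly; the remaining parameters stay empty (identity).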
tensorflow::Status BinaryTensorOpWeight(
    Converter& ctx, const tensorflow::NodeDef& node_def,
    const nvinfer1::ITensor* tensor, TRT_ShapedWeights weights,
    std::vector<TRT_TensorOrWeights>* outputs) {
  // FIXME assume type matches input weights
  // Get trt type & shape
  // Maybe this part has to be moved into the block of rsqrt later

  // Check type consistency
  auto dtype = TFAttrs(node_def).get<nvinfer1::DataType>("T");
  CHECK_EQ_TYPE(tensor->getType(), dtype);  // Cast to int for error messages
  nvinfer1::DataType ttype;
  TF_CHECK_OK(ConvertDType(weights.type_, &ttype));
  CHECK_EQ_TYPE(ttype, dtype);  // Cast to int for error messages

  // Check scale mode
  auto dims_w = weights.shape_;
  auto dims_t = tensor->getDimensions();

  // Default to element-wise
  auto scale_mode = nvinfer1::ScaleMode::kELEMENTWISE;

  if (weights.count() == 1) {
    VLOG(2) << "UNIFORM";
    scale_mode = nvinfer1::ScaleMode::kUNIFORM;
  } else {
    // No broadcasting on the batch dimension.
    assert(dims_w.d[0] == 1);

    // Broadcasting on the channel dimension is only allowed in kUNIFORM mode;
    // otherwise the channel dimensions must match.
    assert(dims_w.d[1] == dims_t.d[0]);
    assert(dims_w.nbDims == dims_t.nbDims);

    // Default is element-wise; fall back to channel-wise when the spatial
    // dimensions differ.
    for (int i = 2; i < dims_w.nbDims; i++) {
      if (dims_w.d[i] != dims_t.d[i - 1]) {
        scale_mode = nvinfer1::ScaleMode::kCHANNEL;
        break;
      }
    }
    if (scale_mode == nvinfer1::ScaleMode::kCHANNEL) {
      // Channel-wise weights must be 1 along all spatial dimensions.
      for (int i = 2; i < dims_w.nbDims; i++) {
        if (dims_w.d[i] != 1)
          return tensorflow::errors::InvalidArgument(
              "Weight shape not compatible, at " + node_def.name());
      }
    }
  }

  // Prepare weights
  TRT_ShapedWeights shift_weights(weights.type_);
  TRT_ShapedWeights scale_weights(weights.type_);
  TRT_ShapedWeights power_weights(weights.type_);

  // Maybe I should do a switch
  if (node_def.op() == "Sub") {
    TRT_ShapedWeights neg_weights = ctx.get_temp_weights_like(weights);
    LambdaFactory unary_op;
    unary_op.op = LambdaFactory::OP_CATEGORY::NEG;
    TF_RETURN_IF_ERROR(UnaryCompute(weights, &neg_weights, unary_op));
    shift_weights = neg_weights;
  } else if (node_def.op() == "Mul") {
    scale_weights = weights;
  } else if (node_def.op() == "Add") {
    shift_weights = weights;
  } else {
    return tensorflow::errors::Unimplemented("Binary op not supported: " +
                                             node_def.op());
  }

  nvinfer1::IScaleLayer* layer = ctx.network()->addScale(
      *const_cast<nvinfer1::ITensor*>(tensor), scale_mode, shift_weights,
      scale_weights, power_weights);

  nvinfer1::ITensor* output_tensor = layer->getOutput(0);

  // Pass the output
  outputs->push_back(TRT_TensorOrWeights(output_tensor));
  return tensorflow::Status::OK();
}

tensorflow::Status BinaryTensorOpTensor(
    Converter& ctx, const tensorflow::NodeDef& node_def,
    const nvinfer1::ITensor* tensor_l, const nvinfer1::ITensor* tensor_r,
    std::vector<TRT_TensorOrWeights>* outputs) {
  static const std::unordered_map<string, nvinfer1::ElementWiseOperation> ops{
      {"Add", nvinfer1::ElementWiseOperation::kSUM},
      {"Mul", nvinfer1::ElementWiseOperation::kPROD},
      // {"max", nvinfer1::ElementWiseOperation::kMAX},
      // {"min", nvinfer1::ElementWiseOperation::kMIN},
      {"Sub", nvinfer1::ElementWiseOperation::kSUB},
      {"Div", nvinfer1::ElementWiseOperation::kDIV},
  };

  // FIXME assume type matches input weights
  // Get trt type & shape
  TFAttrs attrs(node_def);
  // Maybe this part has to be moved into the block of rsqrt later
  nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("T");

  // Check type consistency
  CHECK_EQ_TYPE(tensor_l->getType(), dtype);
  CHECK_EQ_TYPE(tensor_r->getType(), dtype);
  auto op_pair = ops.find(node_def.op());
  if (op_pair == ops.end())
    return tensorflow::errors::Unimplemented("binary op: " + node_def.op() +
                                             " not supported at: " +
                                             node_def.name());

  nvinfer1::IElementWiseLayer* layer = ctx.network()->addElementWise(
      *const_cast<nvinfer1::ITensor*>(tensor_l),
      *const_cast<nvinfer1::ITensor*>(tensor_r), op_pair->second);

  nvinfer1::ITensor* output_tensor = layer->getOutput(0);

  // Pass the output
  outputs->push_back(TRT_TensorOrWeights(output_tensor));
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertPlaceholder(
    Converter& ctx, const tensorflow::NodeDef& node_def,
    std::vector<TRT_TensorOrWeights> const& inputs,
    std::vector<TRT_TensorOrWeights>* outputs) {
  VLOG(2) << "Placeholder should have been replaced already";
  return tensorflow::errors::Unimplemented("Cannot convert Placeholder op");
  // OK, this makes sense, since we are supposed to replace it with an input.
  TFAttrs attrs(node_def);
  nvinfer1::DataType dtype = attrs.get<nvinfer1::DataType>("dtype");
  nvinfer1::Dims dims = attrs.get<nvinfer1::Dims>("shape");

  // Strip the implicit batch dimension; TRT shapes are batch-free.
  dims.nbDims--;
  for (int i = 0; i < dims.nbDims; i++) dims.d[i] = dims.d[i + 1];

  nvinfer1::ITensor* output =
      ctx.network()->addInput(node_def.name().c_str(), dtype, dims);
  if (!output) {
    return tensorflow::errors::InvalidArgument("Failed to create Input layer");
  }
  outputs->push_back(TRT_TensorOrWeights(output));
  return tensorflow::Status::OK();
}
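
// ConvertConv2D assumes TensorRT convolutions operate on batch-free NCHW
// tensors: an NHWC input is transposed to NCHW on entry and back on exit,
// the TF RSCK kernel is reordered to KCRS, and asymmetric SAME padding is
// peeled off into an explicit IPaddingLayer, since IConvolutionLayer only
// accepts symmetric padding.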
tensorflow::Status ConvertConv2D(Converter& ctx,
                                 const tensorflow::NodeDef& node_def,
                                 const std::vector<TRT_TensorOrWeights>& inputs,
                                 std::vector<TRT_TensorOrWeights>* outputs) {
  nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
  // TODO(jie): handle NHWC/NCHW transpose;
  TRT_ShapedWeights weights_rsck = inputs.at(1).weights();
  TRT_ShapedWeights weights = ctx.get_temp_weights_like(weights_rsck);
  ReorderRSCKToKCRS(weights_rsck, &weights);
  TRT_ShapedWeights biases(weights.type_);
  int noutput = weights.shape_.d[0];
  nvinfer1::DimsHW kernel_size;
  kernel_size.h() = weights.shape_.d[2];
  kernel_size.w() = weights.shape_.d[3];
  TFAttrs attrs(node_def);

  int h_index = 2;
  int w_index = 3;
  auto data_format = attrs.get<string>("data_format");
  if (data_format == "NHWC") {
    tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
                                 {0, 3, 1, 2});
    h_index = 1;
    w_index = 2;
    // TODO(jie): transpose it
  }

  // TODO(jie): stride. (NHWC/NCHW)
  auto tf_stride = attrs.get<std::vector<int>>("strides");
  nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);

  auto tensor_dim = tensor->getDimensions();
  std::vector<std::pair<int, int>> padding;
  // TODO(jie): padding.
  if (attrs.get<string>("padding") == "SAME") {
    // This is an NCHW tensor with no batch dimension:
    // 1 -> h
    // 2 -> w
    padding = CreateSamePadding(
        stride, kernel_size,
        {static_cast<int>(tensor_dim.d[1]), static_cast<int>(tensor_dim.d[2])});
  } else {
    padding = {{0, 0}, {0, 0}};
  }

  if (padding[0].first != padding[0].second ||
      padding[1].first != padding[1].second) {
    // TODO(jie): handle asymmetric padding
    VLOG(2) << "Asymmetric padding: " << padding[0].first << ", "
            << padding[0].second << ", " << padding[1].first << ", "
            << padding[1].second;

    auto dim_before = tensor->getDimensions();
    VLOG(2) << "TENSOR before: " << dim_before.d[0] << ", " << dim_before.d[1]
            << ", " << dim_before.d[2] << ", " << dim_before.d[3];
    auto pad_layer = ctx.network()->addPadding(
        *const_cast<nvinfer1::ITensor*>(tensor),
        nvinfer1::DimsHW(padding[0].first, padding[1].first),
        nvinfer1::DimsHW(padding[0].second, padding[1].second));
    padding = {{0, 0}, {0, 0}};
    tensor = pad_layer->getOutput(0);
    auto dim_after = tensor->getDimensions();
    VLOG(2) << "TENSOR after: " << dim_after.d[0] << ", " << dim_after.d[1]
            << ", " << dim_after.d[2] << ", " << dim_after.d[3];
  }

  nvinfer1::IConvolutionLayer* layer =
      ctx.network()->addConvolution(*const_cast<nvinfer1::ITensor*>(tensor),
                                    noutput, kernel_size, weights, biases);

  layer->setStride(stride);
  layer->setPadding({padding[0].first, padding[1].first});
  layer->setName(node_def.name().c_str());
  nvinfer1::ITensor* output_tensor = layer->getOutput(0);

  auto dim_after = output_tensor->getDimensions();
  VLOG(2) << "TENSOR out: " << dim_after.d[0] << ", " << dim_after.d[1] << ", "
          << dim_after.d[2] << ", " << dim_after.d[3];

  if (data_format == "NHWC") {
    // TODO(jie): transpose it back!
    output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1});
  } else {
    VLOG(2) << "NCHW !!!!";
  }
  outputs->push_back(TRT_TensorOrWeights(output_tensor));
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertPool(Converter& ctx,
                               const tensorflow::NodeDef& node_def,
                               std::vector<TRT_TensorOrWeights> const& inputs,
                               std::vector<TRT_TensorOrWeights>* outputs) {
  nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
  TFAttrs attrs(node_def);

  int h_index = 2;
  int w_index = 3;
  auto data_format = attrs.get<string>("data_format");
  if (data_format == "NHWC") {
    h_index = 1;
    w_index = 2;
    tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
                                 {0, 3, 1, 2});
  } else {
    VLOG(2) << "NCHW !!!!";
  }
  nvinfer1::PoolingType type;
  // TODO(jie): support other pooling type
  if (node_def.op() == "MaxPool")
    type = nvinfer1::PoolingType::kMAX;
  else
    return tensorflow::errors::Unimplemented("Only supports Max pool");

  // TODO(jie): NCHW
  auto tf_stride = attrs.get<std::vector<int>>("strides");
  nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);

  auto tf_kernel = attrs.get<std::vector<int>>("ksize");
  nvinfer1::DimsHW ksize(tf_kernel[h_index], tf_kernel[w_index]);

  auto tensor_dim = tensor->getDimensions();
  std::vector<std::pair<int, int>> padding;
  // TODO(jie): padding.
  if (attrs.get<string>("padding") == "SAME") {
    // This is an NCHW tensor with no batch dimension:
    // 1 -> h
    // 2 -> w
    padding = CreateSamePadding(
        stride, ksize,
        {static_cast<int>(tensor_dim.d[1]), static_cast<int>(tensor_dim.d[2])});
  } else if (attrs.get<string>("padding") == "VALID") {
    // No padding is needed for VALID.
    VLOG(2) << "No padding added for VALID padding in pool "
            << node_def.name();
    padding = {{0, 0}, {0, 0}};
  } else {
    return tensorflow::errors::Unimplemented(
        "Current MaxPool cannot support padding other than SAME and VALID");
  }

  if (padding[0].first != padding[0].second ||
      padding[1].first != padding[1].second) {
    // TODO(jie): handle asymmetric padding
    VLOG(2) << "Asymmetric padding: " << padding[0].first << ", "
            << padding[0].second << ", " << padding[1].first << ", "
            << padding[1].second;
    auto pad_layer = ctx.network()->addPadding(
        *const_cast<nvinfer1::ITensor*>(tensor),
        nvinfer1::DimsHW(padding[0].first, padding[1].first),
        nvinfer1::DimsHW(padding[0].second, padding[1].second));
    padding = {{0, 0}, {0, 0}};
    tensor = pad_layer->getOutput(0);
  }

  nvinfer1::IPoolingLayer* layer = ctx.network()->addPooling(
      *const_cast<nvinfer1::ITensor*>(tensor), type, ksize);

  layer->setStride(stride);
  layer->setPadding({padding[0].first, padding[1].first});
  layer->setName(node_def.name().c_str());
  nvinfer1::ITensor* output_tensor = layer->getOutput(0);

  if (data_format == "NHWC") {
    // TODO(jie): transpose it back!
    output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1});
  } else {
    VLOG(2) << "NCHW !!!!";
  }
  outputs->push_back(TRT_TensorOrWeights(output_tensor));
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertActivation(
    Converter& ctx, const tensorflow::NodeDef& node_def,
    std::vector<TRT_TensorOrWeights> const& inputs,
    std::vector<TRT_TensorOrWeights>* outputs) {
  nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
  nvinfer1::IActivationLayer* layer = ctx.network()->addActivation(
      *const_cast<nvinfer1::ITensor*>(tensor), nvinfer1::ActivationType::kRELU);
  nvinfer1::ITensor* output_tensor = layer->getOutput(0);
  outputs->push_back(TRT_TensorOrWeights(output_tensor));
  return tensorflow::Status::OK();
}
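
// ConvertScale handles BiasAdd. Note the addScale argument order:
// (input, mode, shift, scale, power); the bias vector is passed as the shift
// term of a kCHANNEL scale layer with empty scale/power weights, so the layer
// computes x + bias per channel.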
tensorflow::Status ConvertScale(Converter& ctx,
                                const tensorflow::NodeDef& node_def,
                                std::vector<TRT_TensorOrWeights> const& inputs,
                                std::vector<TRT_TensorOrWeights>* outputs) {
  if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
      !inputs.at(1).is_weights())
    return tensorflow::errors::Unimplemented(
        "Only supports tensor op weight for now, at " + node_def.name());
  // Implement tensor binaryOp weight [channel wise] for now;
  nvinfer1::ITensor const* tensor = inputs.at(0).tensor();

  // TODO(jie): handle NHWC/NCHW transpose;
  TRT_ShapedWeights weights = inputs.at(1).weights();
  TRT_ShapedWeights empty_weights(weights.type_);

  TFAttrs attrs(node_def);

  // Transpose NHWC
  auto data_format = attrs.get<string>("data_format");
  if (data_format == "NHWC") {
    tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
                                 {0, 3, 1, 2});
    // TODO(jie): transpose it
  } else {
    VLOG(2) << "NCHW !!!!";
  }
  nvinfer1::IScaleLayer* layer = ctx.network()->addScale(
      *const_cast<nvinfer1::ITensor*>(tensor), nvinfer1::ScaleMode::kCHANNEL,
      weights, empty_weights, empty_weights);

  nvinfer1::ITensor* output_tensor = layer->getOutput(0);
  if (data_format == "NHWC") {
    // TODO(jie): transpose it back!
    output_tensor = ctx.TransposeTensor(output_tensor, {0, 2, 3, 1});
  } else {
    VLOG(2) << "NCHW !!!!";
  }
  outputs->push_back(TRT_TensorOrWeights(output_tensor));
  return tensorflow::Status::OK();
}
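
// ConvertConst materializes a Const node as TRT_ShapedWeights. Note the
// weights do not own their storage: float_val-backed protos are referenced
// directly inside the NodeDef, while tensor_content-backed protos are copied
// into converter-owned scratch buffers; both the NodeDef and the converter
// are expected to stay alive until the engine is built.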
tensorflow::Status ConvertConst(Converter& ctx,
                                const tensorflow::NodeDef& node_def,
                                std::vector<TRT_TensorOrWeights> const& inputs,
                                std::vector<TRT_TensorOrWeights>* outputs) {
  const auto& weights_tensor = node_def.attr().at("value").tensor();

  // Get trt type & shape
  TFAttrs attrs(node_def);
  const tensorflow::DataType dtype = attrs.get<tensorflow::DataType>("dtype");

  // Create shaped weights as output
  tensorflow::Tensor tensor;
  if (!tensor.FromProto(weights_tensor))
    return tensorflow::errors::Internal("Cannot parse weight tensor proto: " +
                                        node_def.name());

  TRT_ShapedWeights weights(dtype);
  if (!weights_tensor.float_val().empty()) {
    VLOG(2) << "SCALAR!!! " << node_def.name();
    nvinfer1::Dims scalar_shape;
    if (tensor.dims() > 0) {
      VLOG(2) << "Dimensions: " << tensor.dims();
      weights = TRT_ShapedWeights(dtype, weights_tensor.float_val().data(),
                                  GetTensorShape(tensor));
    } else {
      VLOG(2) << "Dimensions: " << tensor.dims();
      scalar_shape.nbDims = 1;
      scalar_shape.d[0] = 1;
      scalar_shape.type[0] = nvinfer1::DimensionType::kSPATIAL;
      for (int i = 1; i < nvinfer1::Dims::MAX_DIMS; i++) {
        scalar_shape.d[i] = 0;
        scalar_shape.type[i] = nvinfer1::DimensionType::kSPATIAL;
      }
      weights = TRT_ShapedWeights(dtype, weights_tensor.float_val().data(),
                                  scalar_shape);
    }
  } else if (!weights_tensor.tensor_content().empty()) {
    VLOG(2) << "TENSOR!!! " << node_def.name();
    const auto& content = weights_tensor.tensor_content();

    weights = ctx.get_temp_weights(dtype, GetTensorShape(tensor));
    if (content.size() > 0) {
      const int dtype_size = tensorflow::DataTypeSize(dtype);
      CHECK_EQ(0, content.size() % dtype_size)
          << "Tensor content size (" << content.size()
          << ") is not a multiple of " << dtype_size;
      port::CopyToArray(
          content, static_cast<char*>(const_cast<void*>(weights.GetValues())));
    }
  } else {
    return tensorflow::errors::Unimplemented(
        "Not supported constant type, at " + node_def.name());
  }
  // Pass the output
  outputs->push_back(TRT_TensorOrWeights(weights));
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertIdentity(
    Converter& ctx, const tensorflow::NodeDef& node_def,
    std::vector<TRT_TensorOrWeights> const& inputs,
    std::vector<TRT_TensorOrWeights>* outputs) {
  outputs->push_back(inputs.at(0));
  return tensorflow::Status::OK();
}

tensorflow::Status ConvertBinary(Converter& ctx,
                                 const tensorflow::NodeDef& node_def,
                                 std::vector<TRT_TensorOrWeights> const& inputs,
                                 std::vector<TRT_TensorOrWeights>* outputs) {
  if (inputs.size() != 2)
    return tensorflow::errors::FailedPrecondition(
        "Binary ops require two inputs, at " + node_def.name());

  if (inputs.at(0).is_weights() && inputs.at(1).is_weights())
    return ConstantFoldBinary(ctx, node_def, inputs, outputs);

  if (inputs.at(0).is_tensor() && inputs.at(1).is_weights())
    return BinaryTensorOpWeight(ctx, node_def, inputs.at(0).tensor(),
                                inputs.at(1).weights(), outputs);

  if (inputs.at(0).is_weights() && inputs.at(1).is_tensor())
    return BinaryTensorOpWeight(ctx, node_def, inputs.at(1).tensor(),
                                inputs.at(0).weights(), outputs);

  if (inputs.at(0).is_tensor() && inputs.at(1).is_tensor())
    return BinaryTensorOpTensor(ctx, node_def, inputs.at(0).tensor(),
                                inputs.at(1).tensor(), outputs);

  return tensorflow::errors::Unknown("Binary op input error, at " +
                                     node_def.name());
}

tensorflow::Status ConvertUnary(Converter& ctx,
                                const tensorflow::NodeDef& node_def,
                                std::vector<TRT_TensorOrWeights> const& inputs,
                                std::vector<TRT_TensorOrWeights>* outputs) {
  if (inputs.size() != 1)
    return tensorflow::errors::FailedPrecondition(
        "Unary ops require a single input, at " + node_def.name());

  if (inputs.at(0).is_weights())
    return ConstantFoldUnary(ctx, node_def, inputs, outputs);
  else if (inputs.at(0).is_tensor())
    return tensorflow::errors::Unimplemented(
        "Unary op for tensor not supported, at " + node_def.name());

  return tensorflow::errors::Unknown("Unary op input error, at " +
                                     node_def.name());
}
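
// ConvertReduce below only handles Mean, and only by emulation: reducing over
// spatial axes (with the reduced dimensions kept as size 1) is equivalent to
// average pooling with a kernel spanning the reduced axes. TRT pooling cannot
// reduce the channel axis, so a requested reduction over axis 1 is first
// swapped into a spatial slot via transpose and swapped back after the pool.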
tensorflow::Status ConvertReduce(Converter& ctx,
                                 const tensorflow::NodeDef& node_def,
                                 std::vector<TRT_TensorOrWeights> const& inputs,
                                 std::vector<TRT_TensorOrWeights>* outputs) {
  if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
      !inputs.at(1).is_weights())
    return tensorflow::errors::InvalidArgument(
        "Input expects tensor and weights, at " + node_def.name());

  nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
  auto dims = tensor->getDimensions();
  // Restore implicit batch dimension
  int nb_dims = dims.nbDims + 1;

  TRT_ShapedWeights index_list = inputs.at(1).weights();

  TFAttrs attrs(node_def);
  // TODO(jie): handle data type.
  // Index type here is done through TF type, so I can leverage their
  // EnumToDataType for my cast
  auto index_type = attrs.get<tensorflow::DataType>("Tidx");

  // Only expect to handle INT32 as attributes for now
  if (index_type != tensorflow::DataType::DT_INT32)
    return tensorflow::errors::Unimplemented("Tidx supports only DT_INT32");
  auto index_list_data =
      static_cast<int*>(const_cast<void*>(index_list.GetValues()));

  // Hack warning: have to fall back to pool layer since reduce is not in
  // public TRT yet.
  if (nb_dims != 4)
    return tensorflow::errors::InvalidArgument(
        "TRT only supports reduce on 4-dimensional tensors, at " +
        node_def.name());
  if (index_list.count() > 2)
    return tensorflow::errors::InvalidArgument(
        "TRT cannot support reduce on more than 2 dimensions, at " +
        node_def.name());

  std::set<int> idx_set;
  // We cannot reduce over the channel dimension; permuted_index flags that a
  // transpose is needed.
  int permuted_index = -1;
  for (int i = 0; i < index_list.count(); i++) {
    if (index_list_data[i] == 0)
      return tensorflow::errors::InvalidArgument(
          "TRT cannot reduce at batch dimension 0, at " + node_def.name());
    if (index_list_data[i] == 1) permuted_index = 1;
    idx_set.emplace(index_list_data[i]);
  }

  std::vector<int> permutation_order(nb_dims);
  nvinfer1::DimsHW pool_kernel;
  if (permuted_index == 1) {
    for (int i = 2; i < nb_dims; i++) {
      if (idx_set.count(i)) {
        permuted_index = i;
        break;
      }
    }
    for (int i = 0; i < nb_dims; i++) permutation_order[i] = i;

    permutation_order[permuted_index] = 1;
    permutation_order[1] = permuted_index;

    // Apply permutation before extracting dimension for pool_kernel
    tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
                                 permutation_order);
  }

  // Apply permutation before extracting dimension for pool_kernel
  pool_kernel.d[0] = (idx_set.count(2) || permuted_index == 2) ? dims.d[1] : 1;
  pool_kernel.d[1] = (idx_set.count(3) || permuted_index == 3) ? dims.d[2] : 1;

  nvinfer1::ITensor* output_tensor;

  if (node_def.op() == "Mean") {
    nvinfer1::IPoolingLayer* layer =
        ctx.network()->addPooling(*const_cast<nvinfer1::ITensor*>(tensor),
                                  nvinfer1::PoolingType::kAVERAGE, pool_kernel);
    output_tensor = layer->getOutput(0);
  } else {
    return tensorflow::errors::Unimplemented(
        "Op not supported " + node_def.op() + ", at " + node_def.name());
  }
  if (permuted_index != -1) {
    // Apply the permutation again to restore the original layout (the swap
    // is its own inverse).
    output_tensor = ctx.TransposeTensor(
        const_cast<nvinfer1::ITensor*>(output_tensor), permutation_order);
  }
  outputs->push_back(TRT_TensorOrWeights(output_tensor));
  return tensorflow::Status::OK();
}
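
// ConvertPad maps tf.pad (explicit zero padding) onto IPaddingLayer, which
// can only pad the last two (H, W) dimensions. Padding on the channel
// dimension is handled by first transposing C into the W slot and undoing
// the transpose afterwards; padding the batch dimension is rejected.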
tensorflow::Status ConvertPad(Converter& ctx,
                              const tensorflow::NodeDef& node_def,
                              std::vector<TRT_TensorOrWeights> const& inputs,
                              std::vector<TRT_TensorOrWeights>* outputs) {
  if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
      !inputs.at(1).is_weights())
    return tensorflow::errors::InvalidArgument(
        "Input expects tensor and weights, at " + node_def.name());

  nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
  auto dims = tensor->getDimensions();
  // Restore implicit batch dimension
  int nb_dims = dims.nbDims + 1;

  TRT_ShapedWeights pads = inputs.at(1).weights();

  TFAttrs attrs(node_def);
  // Padding type here is done through TF type
  // so I can leverage their EnumToDataType for my cast
  auto padding_type = attrs.get<tensorflow::DataType>("Tpaddings");
  // TODO(jie): handle data type conversion for TRT?

  if (pads.shape_.d[0] != nb_dims || pads.shape_.d[1] != 2)
    return tensorflow::errors::InvalidArgument(
        "Pad only supports explicit padding on 4-dimensional tensor, at " +
        node_def.name());

  // Only expect to handle INT32 as attributes for now
  if (padding_type != tensorflow::DataType::DT_INT32)
    return tensorflow::errors::Unimplemented(
        "Tpaddings supports only DT_INT32");
  auto pad_data = static_cast<int*>(const_cast<void*>(pads.GetValues()));

  std::vector<int32_t> pad_index;
  for (int i = 0; i < nb_dims; i++) {
    if (pad_data[2 * i] != 0 || pad_data[2 * i + 1] != 0)
      pad_index.push_back(i);
  }

  // No padding at all, we should exit
  if (pad_index.size() == 0) {
    outputs->push_back(inputs.at(0));
    return tensorflow::Status::OK();
  }

  // Only supports padding on at most 2 axes (GIE-2579)
  if (pad_index.size() > 2)
    return tensorflow::errors::InvalidArgument(
        "Padding layer does not support padding on more than 2 dimensions");

  // Padding on batch dimension is not supported
  if (pad_index[0] == 0)
    return tensorflow::errors::InvalidArgument(
        "Padding layer does not support padding on batch dimension");

  // Not doing the legit thing here: ignoring padding on dimensions 1 and 3.
  // TODO(jie): implement pad as uff parser
  if (pad_index.size() == 2 && pad_index[0] == 1 && pad_index[1] == 3)
    return tensorflow::errors::Unimplemented(
        "Padding layer does not support padding on dimension 1 and 3 yet");

  bool legit_pad = true;
  nvinfer1::DimsHW pre_padding(0, 0);
  nvinfer1::DimsHW post_padding(0, 0);

  std::vector<int32_t> permuted_pad_index(pad_index);
  if (pad_index[0] == 1) {
    legit_pad = false;
    tensor = ctx.TransposeTensor(const_cast<nvinfer1::ITensor*>(tensor),
                                 {0, 3, 2, 1});
    permuted_pad_index[0] = 3;
  }

  for (size_t i = 0; i < pad_index.size(); i++) {
    int index = pad_index[i];
    if (permuted_pad_index[i] == 2) {
      pre_padding.h() = pad_data[index * 2];
      post_padding.h() = pad_data[index * 2 + 1];
    } else if (permuted_pad_index[i] == 3) {
      pre_padding.w() = pad_data[index * 2];
      post_padding.w() = pad_data[index * 2 + 1];
    }
  }

  nvinfer1::IPaddingLayer* layer = ctx.network()->addPadding(
      *const_cast<nvinfer1::ITensor*>(tensor), pre_padding, post_padding);
  nvinfer1::ITensor* output_tensor = layer->getOutput(0);

  if (!legit_pad)
    output_tensor = ctx.TransposeTensor(
        const_cast<nvinfer1::ITensor*>(output_tensor), {0, 3, 2, 1});

  outputs->push_back(TRT_TensorOrWeights(output_tensor));
  return tensorflow::Status::OK();
}

void Converter::register_op_converters() {
  // vgg_16 slim implementation
  op_registry_["Placeholder"] = ConvertPlaceholder;
  op_registry_["Conv2D"] = ConvertConv2D;
  op_registry_["Relu"] = ConvertActivation;
  op_registry_["MaxPool"] = ConvertPool;
  // This could be really handled as ConvertBinary
  op_registry_["BiasAdd"] = ConvertScale;
  op_registry_["Const"] = ConvertConst;
  // op_registry_["MatMul"] = ConvertFullyConnected;  // Not used in vgg
  // TODO(ben,jie): this is a temp hack.
  op_registry_["Identity"] = ConvertIdentity;  // Identity should be removed
  // op_registry_["AvgPool"] = ConvertPool;

  // resnet_50_v1 slim implementation
  op_registry_["Add"] = ConvertBinary;
  op_registry_["Mul"] = ConvertBinary;
  op_registry_["Sub"] = ConvertBinary;
  op_registry_["Rsqrt"] = ConvertUnary;
  op_registry_["Mean"] = ConvertReduce;
  op_registry_["Pad"] = ConvertPad;
  // TODO(ben,jie): Add more ops
}

}  // namespace
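
// ConvertSubGraphToTensorRTNodeDef drives the whole conversion: topologically
// sort the subgraph, declare its inputs as TRT network inputs, run the
// registered converter for each node, mark the requested outputs, build and
// serialize the CUDA engine, and emit a single TRTEngineOp NodeDef that
// carries the serialized engine plus input/output bindings as attributes.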
tensorflow::Status ConvertSubGraphToTensorRTNodeDef(
    const tensorflow::Graph& graph, const std::set<int>& subgraph_node_ids,
    const std::vector<std::pair<int, int>>& input_inds,
    const std::vector<std::pair<int, int>>& output_inds, size_t max_batch_size,
    size_t max_workspace_size_bytes,
    const tensorflow::grappler::GraphProperties& graph_properties,
    tensorflow::NodeDef* trt_node) {
  // Visit nodes in reverse topological order and construct the TRT network.

  // Toposort
  std::vector<tensorflow::Node*> order_vec;
  tensorflow::GetPostOrder(graph, &order_vec);
  // Select just the subgraph
  std::list<tensorflow::Node*> order;
  for (tensorflow::Node* node : order_vec) {
    if (subgraph_node_ids.count(node->id())) {
      // We want topological order to construct the
      // network layer by layer
      order.push_front(node);
    }
  }
  // Topological order is needed to build TRT network

  tensorflow::tensorrt::Logger trt_logger;

  auto trt_builder = infer_object(nvinfer1::createInferBuilder(trt_logger));
  if (!trt_builder) {
    return tensorflow::errors::Internal(
        "Failed to create TensorRT builder object");
  }

  auto trt_network = infer_object(trt_builder->createNetwork());
  if (!trt_network) {
    return tensorflow::errors::Internal(
        "Failed to create TensorRT network object");
  }

  // Build the network
  Converter converter(trt_network.get());

  std::vector<string> input_names;
  std::vector<tensorflow::DataType> input_dtypes;
  for (std::pair<int, int> const& input : input_inds) {
    int node_id = input.first;
    int output_idx = input.second;
    tensorflow::Node* node = graph.FindNodeId(node_id);
    auto node_name = node->name();
    input_names.push_back(node_name);  // Insert original node name without port
    // TODO(jie): alternative :)
    if (!graph_properties.HasOutputProperties(node_name))
      return tensorflow::errors::Internal("Failed to find input node: " +
                                          node_name);

    auto op_info_vec = graph_properties.GetOutputProperties(node_name);
    if (static_cast<int>(op_info_vec.size()) <= output_idx)
      return tensorflow::errors::Internal(
          "Accessing output index of: " + std::to_string(output_idx) +
          ", at node: " + node_name + " with output entry from shape_map: " +
          std::to_string(op_info_vec.size()));

    auto op_info = op_info_vec.at(output_idx);

    tensorflow::DataType tf_dtype = op_info.dtype();
    input_dtypes.push_back(tf_dtype);

    nvinfer1::DataType dtype(nvinfer1::DataType::kFLOAT);
    TF_CHECK_OK(ConvertDType(tf_dtype, &dtype));

    VLOG(2) << "Accessing output index of: " << std::to_string(output_idx)
            << ", at node: " << node_name
            << " with output entry from shape_map: "
            << std::to_string(op_info_vec.size());

    // TODO(ben,jie): update TRT input format/dimension
    nvinfer1::DimsCHW input_dim_pseudo_chw;
    for (int i = 0; i < 3; i++) input_dim_pseudo_chw.d[i] = 1;

    // Drop the leading (batch) dimension of the TF shape.
    for (int i = 1; i < op_info.shape().dim_size(); i++) {
      VLOG(2) << "dimension: " << i
              << " , size: " << op_info.shape().dim(i).size();
      input_dim_pseudo_chw.d[i - 1] = op_info.shape().dim(i).size();
    }

    // TODO(ben,jie): proper way to restore input tensor name?
    auto input_tensor_name = node_name;
    if (output_idx != 0)
      input_tensor_name = node_name + ":" + std::to_string(output_idx);

    nvinfer1::ITensor* input_tensor = converter.network()->addInput(
        input_tensor_name.c_str(), dtype, input_dim_pseudo_chw);

    if (!input_tensor)
      return tensorflow::errors::InvalidArgument(
          "Failed to create Input layer");
    VLOG(2) << "Input tensor name: " << input_tensor_name;

    if (!converter.insert_input_tensor(input_tensor_name, input_tensor))
      return tensorflow::errors::AlreadyExists(
          "Output tensor already exists for op: " + input_tensor_name);
  }
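
  // A note on naming: TensorFlow refers to an op's first output as plain
  // "op_name" and to output k as "op_name:k". The input tensors above and the
  // outputs registered in Converter::convert_node follow the same convention,
  // so graph edges resolve directly to TRT tensor names.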

  VLOG(2) << "Finished sorting";

  for (const tensorflow::Node* node : order) {
    const tensorflow::NodeDef& node_def = node->def();
    VLOG(2) << "Converting node: " << node_def.name() << " , " << node_def.op();
    TF_RETURN_IF_ERROR(converter.convert_node(node_def));
  }

  VLOG(2) << "Finished conversion";

  // Gather output metadata
  std::vector<string> output_names;
  std::vector<tensorflow::DataType> output_dtypes;
  for (std::pair<int, int> const& output : output_inds) {
    int node_id = output.first;
    int output_idx = output.second;
    tensorflow::Node* node = graph.FindNodeId(node_id);
    string op_name = node->name();
    string tensor_name = op_name;
    if (output_idx != 0)
      tensor_name = tensor_name + ":" + std::to_string(output_idx);
    VLOG(2) << "Output tensor name: " << tensor_name;
    output_names.push_back(tensor_name);
    auto tensor_or_weights = converter.get_tensor(tensor_name);
    if (!tensor_or_weights.is_tensor()) {
      return tensorflow::errors::InvalidArgument(
          "Output node is weights not tensor");
    }
    nvinfer1::ITensor* tensor = tensor_or_weights.tensor();
    if (!tensor) {
      return tensorflow::errors::NotFound("Output tensor not found: " +
                                          tensor_name);
    }
    converter.network()->markOutput(*tensor);
    tensorflow::DataType tf_dtype = node->output_type(output_idx);
    output_dtypes.push_back(tf_dtype);
    nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT;
    TF_RETURN_IF_ERROR(ConvertDType(tf_dtype, &trt_dtype));
    tensor->setType(trt_dtype);
  }

  VLOG(2) << "Finished output";
  // TODO(jie): static_id is not thread safe.
  static int static_id = 0;

  // Build the engine
  trt_builder->setMaxBatchSize(max_batch_size);
  trt_builder->setMaxWorkspaceSize(max_workspace_size_bytes);
  VLOG(0) << "Starting build engine " << static_id;
  // TODO(ben,jie): half2 and int8 mode support
  string engine_plan_string;
  {
    auto trt_engine =
        infer_object(trt_builder->buildCudaEngine(*converter.network()));
    VLOG(0) << "Built network";
    auto engine_plan = infer_object(trt_engine->serialize());
    VLOG(0) << "Serialized engine";
    const char* engine_plan_data =
        static_cast<const char*>(engine_plan->data());
    engine_plan_string =
        string(engine_plan_data, engine_plan_data + engine_plan->size());
  }

  VLOG(0) << "Finished engine";

  // Build the TRT op
  // TODO(sami,ben,jie): proper naming!
  tensorflow::NodeDefBuilder op_builder(
      tensorflow::strings::StrCat("my_trt_op", static_id++), "TRTEngineOp");
  std::vector<tensorflow::NodeDefBuilder::NodeOut> income_edges;
  for (size_t i = 0; i < input_names.size(); ++i) {
    int output_idx = input_inds.at(i).second;
    // We wired up the input here already; it is redundant to do it again in
    // ConvertSubGraphToTensorRT (convert_graph.cc).
    auto incoming_edge = tensorflow::NodeDefBuilder::NodeOut(
        input_names.at(i), output_idx, input_dtypes.at(i));
    income_edges.push_back(incoming_edge);
  }
  tensorflow::gtl::ArraySlice<tensorflow::NodeDefBuilder::NodeOut> input_list(
      income_edges);
  op_builder.Input(input_list);

  VLOG(0) << "Finished op preparation";

  auto status = op_builder.Attr("serialized_engine", engine_plan_string)
                    .Attr("input_nodes", input_names)
                    .Attr("output_nodes", output_names)
                    .Attr("OutT", output_dtypes)
                    .Finalize(trt_node);

  VLOG(0) << status.ToString() << " finished op building";

  return tensorflow::Status::OK();
}

}  // namespace convert
}  // namespace tensorrt
}  // namespace tensorflow

#endif  // GOOGLE_TENSORRT
#endif  // GOOGLE_CUDA