Home | History | Annotate | Download | only in custom_plugin_examples
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_plugin.h"

#include <cstring>

#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
#include "tensorflow/contrib/tensorrt/custom_plugin_examples/inc_op_kernel.h"
     20 
     21 #if GOOGLE_CUDA
     22 #if GOOGLE_TENSORRT
     23 
     24 namespace tensorflow {
     25 namespace tensorrt {
     26 
// Name of this plugin type. It is passed to REGISTER_TRT_PLUGIN below and is
// used to initialize plugin_name_ in both IncOpPlugin constructors.
// NOTE(review): declared as a mutable pointer with external linkage; keep the
// declaration in sync with any `extern` declaration in inc_op_plugin.h before
// tightening it to `const char* const`.
const char* kPluginName = "IncPluginTRT";
     28 
     29 IncOpPlugin* CreateIncPlugin() { return new IncOpPlugin(); }
     30 
     31 IncOpPlugin* CreateIncPluginDeserialize(const void* buffer, size_t length) {
     32   return new IncOpPlugin(buffer, length);
     33 }
     34 
// Hands the plugin name and both factory callbacks to the TRT plugin factory
// macro from trt_plugin_factory.h: CreateIncPluginDeserialize rebuilds a
// plugin from serialized bytes, CreateIncPlugin makes a fresh one.
// NOTE(review): presumably this registers the plugin during static
// initialization — confirm against the macro definition.
REGISTER_TRT_PLUGIN(kPluginName, CreateIncPluginDeserialize, CreateIncPlugin);
     36 
     37 IncOpPlugin::IncOpPlugin() : plugin_name_(kPluginName) {}
     38 
     39 IncOpPlugin::IncOpPlugin(const void* serialized_data, size_t length)
     40     : PluginTensorRT(serialized_data, length), plugin_name_(kPluginName) {
     41   // account for the consumed pointer.
     42   size_t consumed_data = PluginTensorRT::getSerializationSize();
     43   assert(length - consumed_data >= sizeof(float));
     44   const char* buffer = reinterpret_cast<const char*>(serialized_data);
     45   SetAttribute("inc", buffer + consumed_data, sizeof(float));
     46 }
     47 
     48 bool IncOpPlugin::SetAttribute(const string& key, const void* ptr,
     49                                const size_t size) {
     50   if (strcmp(key.c_str(), "inc") == 0 && size == sizeof(float)) {
     51     StoreAttribute(key, ptr, size);  // save the attribute to own the data;
     52     inc_ = *static_cast<const float*>(ptr);
     53     return true;
     54   }
     55   return false;
     56 }
     57 
     58 bool IncOpPlugin::GetAttribute(const string& key, const void** ptr,
     59                                size_t* size) const {
     60   const auto& iter = attr_map_.find(key);
     61   if (iter != attr_map_.end()) {
     62     *ptr = iter->second.data();
     63     *size = iter->second.size();
     64     return true;
     65   }
     66   return false;
     67 }
     68 
     69 int IncOpPlugin::enqueue(int batch_size, const void* const* inputs,
     70                          void** outputs, void*, cudaStream_t stream) {
     71   int count = 1;
     72   for (int i = 0; i < input_dim_list_[0].nbDims; i++) {
     73     count *= input_dim_list_[0].d[i];
     74   }
     75   count *= batch_size;
     76   const float* input = reinterpret_cast<const float*>(inputs[0]);
     77   float* output = reinterpret_cast<float*>(outputs[0]);
     78   IncrementKernel(input, inc_, output, count, stream);
     79   return 0;
     80 }
     81 
     82 }  // namespace tensorrt
     83 }  // namespace tensorflow
     84 
#endif  // GOOGLE_TENSORRT
#endif  // GOOGLE_CUDA
     87