/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"

#if GOOGLE_CUDA
#if GOOGLE_TENSORRT
#include "cuda/include/cuda.h"
#include "cuda/include/cuda_runtime_api.h"
#include "tensorrt/include/NvInfer.h"

namespace tensorflow {
namespace {

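// Forwards TensorRT log messages to the corresponding TensorFlow log levels.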
class Logger : public nvinfer1::ILogger {
 public:
  void log(nvinfer1::ILogger::Severity severity, const char* msg) override {
    switch (severity) {
      case Severity::kINFO:
        LOG(INFO) << msg;
        break;
      case Severity::kWARNING:
        LOG(WARNING) << msg;
        break;
      case Severity::kINTERNAL_ERROR:
      case Severity::kERROR:
        LOG(ERROR) << msg;
        break;
      default:
        break;
    }
  }
};

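// Wraps a single float as nvinfer1::Weights. The wrapper owns the value, so
// the weights remain valid for as long as the wrapper is in scope.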
class ScopedWeights {
 public:
  explicit ScopedWeights(float value) : value_(value) {
    w.type = nvinfer1::DataType::kFLOAT;
    w.values = &value_;
    w.count = 1;
  }
  const nvinfer1::Weights& get() const { return w; }

 private:
  float value_;
  nvinfer1::Weights w;
};

const char* kInputTensor = "input";
const char* kOutputTensor = "output";

// Creates a network to compute y = 2x + 3.
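// The caller takes ownership of the returned serialized model and must call
// destroy() on it when done.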
nvinfer1::IHostMemory* CreateNetwork() {
  Logger logger;
  nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(logger);
  ScopedWeights weights(2.0f);
  ScopedWeights bias(3.0f);

  nvinfer1::INetworkDefinition* network = builder->createNetwork();
  // Add the input.
  auto input = network->addInput(kInputTensor, nvinfer1::DataType::kFLOAT,
                                 nvinfer1::DimsCHW{1, 1, 1});
  EXPECT_NE(input, nullptr);
  // Add the hidden layer.
  auto layer = network->addFullyConnected(*input, 1, weights.get(), bias.get());
  EXPECT_NE(layer, nullptr);
  // Mark the output.
  auto output = layer->getOutput(0);
  output->setName(kOutputTensor);
  network->markOutput(*output);
  // Build the engine.
  builder->setMaxBatchSize(1);
  builder->setMaxWorkspaceSize(1 << 10);
  auto engine = builder->buildCudaEngine(*network);
  EXPECT_NE(engine, nullptr);
  // Serialize the engine into a model, then release the builder objects.
  nvinfer1::IHostMemory* model = engine->serialize();
  network->destroy();
  engine->destroy();
  builder->destroy();
  return model;
}

// Executes the network.
void Execute(nvinfer1::IExecutionContext& context, const float* input,
             float* output) {
  const nvinfer1::ICudaEngine& engine = context.getEngine();

  // We have two bindings: input and output.
  ASSERT_EQ(engine.getNbBindings(), 2);
  const int input_index = engine.getBindingIndex(kInputTensor);
  const int output_index = engine.getBindingIndex(kOutputTensor);

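  // The device buffers passed to enqueue() must be ordered by binding index.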
  // Create GPU buffers and a stream.
  void* buffers[2];
  ASSERT_EQ(0, cudaMalloc(&buffers[input_index], sizeof(float)));
  ASSERT_EQ(0, cudaMalloc(&buffers[output_index], sizeof(float)));
  cudaStream_t stream;
  ASSERT_EQ(0, cudaStreamCreate(&stream));

  // Copy the input to the GPU, execute the network, and copy the output back.
  //
  // Note that since the host buffers were not allocated as pinned memory,
  // these async copies run synchronously, so the stream synchronization below
  // is redundant; it is kept to make the intended ordering explicit.
  ASSERT_EQ(0, cudaMemcpyAsync(buffers[input_index], input, sizeof(float),
                               cudaMemcpyHostToDevice, stream));
  ASSERT_TRUE(context.enqueue(1, buffers, stream, nullptr));
  ASSERT_EQ(0, cudaMemcpyAsync(output, buffers[output_index], sizeof(float),
                               cudaMemcpyDeviceToHost, stream));
  ASSERT_EQ(0, cudaStreamSynchronize(stream));

  // Release the stream and the buffers.
  ASSERT_EQ(0, cudaStreamDestroy(stream));
  ASSERT_EQ(0, cudaFree(buffers[input_index]));
  ASSERT_EQ(0, cudaFree(buffers[output_index]));
}

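// End-to-end smoke test: build and serialize the y = 2x + 3 engine, then
// deserialize it and run one inference.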
TEST(TensorrtTest, BasicFunctions) {
  // Create the network model.
  nvinfer1::IHostMemory* model = CreateNetwork();
  // Use the model to create an engine and then an execution context.
  Logger logger;
  nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(logger);
  nvinfer1::ICudaEngine* engine =
      runtime->deserializeCudaEngine(model->data(), model->size(), nullptr);
  ASSERT_NE(engine, nullptr);
  model->destroy();
  nvinfer1::IExecutionContext* context = engine->createExecutionContext();
  ASSERT_NE(context, nullptr);

  // Execute the network: y = 2 * 1234 + 3 = 2471.
  float input = 1234;
  float output;
  Execute(*context, &input, &output);
  EXPECT_EQ(output, input * 2 + 3);

  // Release the context, the engine, and the runtime.
  context->destroy();
  engine->destroy();
  runtime->destroy();
}

}  // namespace
}  // namespace tensorflow

#endif  // GOOGLE_TENSORRT
#endif  // GOOGLE_CUDA