1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/platform/logging.h" 17 #include "tensorflow/core/platform/test.h" 18 19 #if GOOGLE_CUDA 20 #if GOOGLE_TENSORRT 21 #include "cuda/include/cuda.h" 22 #include "cuda/include/cuda_runtime_api.h" 23 #include "tensorrt/include/NvInfer.h" 24 25 namespace tensorflow { 26 namespace { 27 28 class Logger : public nvinfer1::ILogger { 29 public: 30 void log(nvinfer1::ILogger::Severity severity, const char* msg) override { 31 switch (severity) { 32 case Severity::kINFO: 33 LOG(INFO) << msg; 34 break; 35 case Severity::kWARNING: 36 LOG(WARNING) << msg; 37 break; 38 case Severity::kINTERNAL_ERROR: 39 case Severity::kERROR: 40 LOG(ERROR) << msg; 41 break; 42 default: 43 break; 44 } 45 } 46 }; 47 48 class ScopedWeights { 49 public: 50 ScopedWeights(float value) : value_(value) { 51 w.type = nvinfer1::DataType::kFLOAT; 52 w.values = &value_; 53 w.count = 1; 54 } 55 const nvinfer1::Weights& get() { return w; } 56 57 private: 58 float value_; 59 nvinfer1::Weights w; 60 }; 61 62 const char* kInputTensor = "input"; 63 const char* kOutputTensor = "output"; 64 65 // Creates a network to compute y=2x+3. 66 nvinfer1::IHostMemory* CreateNetwork() { 67 Logger logger; 68 nvinfer1::IBuilder* builder = nvinfer1::createInferBuilder(logger); 69 ScopedWeights weights(2.0); 70 ScopedWeights bias(3.0); 71 72 nvinfer1::INetworkDefinition* network = builder->createNetwork(); 73 // Add the input. 74 auto input = network->addInput(kInputTensor, nvinfer1::DataType::kFLOAT, 75 nvinfer1::DimsCHW{1, 1, 1}); 76 EXPECT_NE(input, nullptr); 77 // Add the hidden layer. 78 auto layer = network->addFullyConnected(*input, 1, weights.get(), bias.get()); 79 EXPECT_NE(layer, nullptr); 80 // Mark the output. 81 auto output = layer->getOutput(0); 82 output->setName(kOutputTensor); 83 network->markOutput(*output); 84 // Build the engine 85 builder->setMaxBatchSize(1); 86 builder->setMaxWorkspaceSize(1 << 10); 87 auto engine = builder->buildCudaEngine(*network); 88 EXPECT_NE(engine, nullptr); 89 // Serialize the engine to create a model, then close everything. 90 nvinfer1::IHostMemory* model = engine->serialize(); 91 network->destroy(); 92 engine->destroy(); 93 builder->destroy(); 94 return model; 95 } 96 97 // Executes the network. 98 void Execute(nvinfer1::IExecutionContext& context, const float* input, 99 float* output) { 100 const nvinfer1::ICudaEngine& engine = context.getEngine(); 101 102 // We have two bindings: input and output. 103 ASSERT_EQ(engine.getNbBindings(), 2); 104 const int input_index = engine.getBindingIndex(kInputTensor); 105 const int output_index = engine.getBindingIndex(kOutputTensor); 106 107 // Create GPU buffers and a stream 108 void* buffers[2]; 109 ASSERT_EQ(0, cudaMalloc(&buffers[input_index], sizeof(float))); 110 ASSERT_EQ(0, cudaMalloc(&buffers[output_index], sizeof(float))); 111 cudaStream_t stream; 112 ASSERT_EQ(0, cudaStreamCreate(&stream)); 113 114 // Copy the input to the GPU, execute the network, and copy the output back. 115 // 116 // Note that since the host buffer was not created as pinned memory, these 117 // async copies are turned into sync copies. So the following synchronization 118 // could be removed. 119 ASSERT_EQ(0, cudaMemcpyAsync(buffers[input_index], input, sizeof(float), 120 cudaMemcpyHostToDevice, stream)); 121 context.enqueue(1, buffers, stream, nullptr); 122 ASSERT_EQ(0, cudaMemcpyAsync(output, buffers[output_index], sizeof(float), 123 cudaMemcpyDeviceToHost, stream)); 124 cudaStreamSynchronize(stream); 125 126 // Release the stream and the buffers 127 cudaStreamDestroy(stream); 128 ASSERT_EQ(0, cudaFree(buffers[input_index])); 129 ASSERT_EQ(0, cudaFree(buffers[output_index])); 130 } 131 132 TEST(TensorrtTest, BasicFunctions) { 133 // Create the network model. 134 nvinfer1::IHostMemory* model = CreateNetwork(); 135 // Use the model to create an engine and then an execution context. 136 Logger logger; 137 nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(logger); 138 nvinfer1::ICudaEngine* engine = 139 runtime->deserializeCudaEngine(model->data(), model->size(), nullptr); 140 model->destroy(); 141 nvinfer1::IExecutionContext* context = engine->createExecutionContext(); 142 143 // Execute the network. 144 float input = 1234; 145 float output; 146 Execute(*context, &input, &output); 147 EXPECT_EQ(output, input * 2 + 3); 148 149 // Destroy the engine. 150 context->destroy(); 151 engine->destroy(); 152 runtime->destroy(); 153 } 154 155 } // namespace 156 } // namespace tensorflow 157 158 #endif // GOOGLE_TENSORRT 159 #endif // GOOGLE_CUDA 160