1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_KERNELS_CONV_OPS_H_ 17 #define TENSORFLOW_KERNELS_CONV_OPS_H_ 18 19 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 20 #include "tensorflow/core/framework/resource_mgr.h" 21 #include "tensorflow/core/platform/mem.h" 22 #include "tensorflow/core/util/tensor_format.h" 23 24 #if GOOGLE_CUDA 25 #include "tensorflow/core/kernels/conv_ops_gpu.h" 26 #include "tensorflow/core/platform/stream_executor.h" 27 #endif // GOOGLE_CUDA 28 29 namespace tensorflow { 30 31 // Forward declaration. 32 class OpKernelContext; 33 34 template <typename Device, typename T> 35 struct LaunchConv2DOp { 36 void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, 37 const Tensor& input, const Tensor& filter, int row_dilation, 38 int col_dilation, int row_stride, int col_stride, 39 const Padding& padding, Tensor* output, 40 TensorFormat data_format); 41 }; 42 43 #ifdef GOOGLE_CUDA 44 template <typename T> 45 struct LaunchConv2DOp<Eigen::GpuDevice, T> { 46 void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, 47 const Tensor& input, const Tensor& filter, int row_dilation, 48 int col_dilation, int row_stride, int col_stride, 49 const Padding& padding, Tensor* output, 50 TensorFormat data_format); 51 }; 52 #endif // GOOGLE_CUDA 53 54 // Used to keep track of persistent memory buffers used within the op. 55 // It uses malloc and free to avoid the time cost of initializing the memory. 56 template <class T, size_t size> 57 struct Im2ColBufferResource : public ResourceBase { 58 Im2ColBufferResource<T, size>() { 59 data = static_cast<T*>(port::Malloc(size * sizeof(T))); 60 } 61 ~Im2ColBufferResource<T, size>() { port::Free(data); } 62 // This mutex ensures that only a single operation at a time is able to use 63 // the buffer memory held by this resource. 64 mutex mu; 65 T* data; 66 string DebugString() { return "Im2ColBufferResource"; } 67 }; 68 69 } // namespace tensorflow 70 71 #endif // TENSORFLOW_KERNELS_CONV_OPS_H 72