1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #define EIGEN_USE_THREADS 17 18 #include "tensorflow/core/framework/op_kernel.h" 19 #include "tensorflow/core/framework/register_types.h" 20 #include "tensorflow/core/framework/tensor.h" 21 #include "tensorflow/core/framework/tensor_shape.h" 22 #include "tensorflow/core/kernels/fill_functor.h" 23 #include "tensorflow/core/kernels/inplace_ops_functor.h" 24 #include "tensorflow/core/lib/core/status.h" 25 26 namespace tensorflow { 27 typedef Eigen::ThreadPoolDevice CPUDevice; 28 #ifdef TENSORFLOW_USE_SYCL 29 typedef Eigen::SyclDevice SyclDevice; 30 #endif // TENSORFLOW_USE_SYCL 31 32 namespace functor { 33 34 template <typename Device, typename T> 35 Status DoParallelConcatUpdate(const Device& d, const Tensor& value, int32 loc, 36 Tensor* output) { 37 auto Tvalue = value.shaped<T, 2>({1, value.NumElements()}); 38 auto Toutput = output->flat_outer_dims<T>(); 39 auto nrows = Toutput.dimension(0); 40 auto r = (loc % nrows + nrows) % nrows; // Guard index range. 41 Toutput.template chip<0>(r).device(d) = Tvalue.template chip<0>(0); 42 return Status::OK(); 43 } 44 45 template <> 46 Status DoParallelConcat(const CPUDevice& d, const Tensor& value, int32 loc, 47 Tensor* output) { 48 CHECK_EQ(value.dtype(), output->dtype()); 49 switch (value.dtype()) { 50 #define CASE(type) \ 51 case DataTypeToEnum<type>::value: \ 52 return DoParallelConcatUpdate<CPUDevice, type>(d, value, loc, output); 53 TF_CALL_NUMBER_TYPES(CASE); 54 TF_CALL_string(CASE); 55 TF_CALL_variant(CASE); 56 #undef CASE 57 default: 58 return errors::InvalidArgument("Unsupported data type: ", value.dtype()); 59 } 60 } 61 62 #ifdef TENSORFLOW_USE_SYCL 63 template <> 64 Status DoParallelConcat(const SyclDevice& d, const Tensor& value, int32 loc, 65 Tensor* output) { 66 CHECK_EQ(value.dtype(), output->dtype()); 67 switch (value.dtype()) { 68 #define CASE(type) \ 69 case DataTypeToEnum<type>::value: \ 70 return DoParallelConcatUpdate<SyclDevice, type>(d, value, loc, output); 71 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(CASE); 72 #undef CASE 73 default: 74 return errors::InvalidArgument("Unsupported data type: ", value.dtype()); 75 } 76 } 77 #endif // TENSORFLOW_USE_SYCL 78 79 } // end namespace functor 80 81 namespace { 82 83 template <typename Device> 84 class ParallelConcatUpdate : public OpKernel { 85 public: 86 explicit ParallelConcatUpdate(OpKernelConstruction* ctx) : OpKernel(ctx) { 87 OP_REQUIRES_OK(ctx, ctx->GetAttr("loc", &loc_)); 88 } 89 90 void Compute(OpKernelContext* ctx) override { 91 auto value = ctx->input(0); 92 auto update = ctx->input(1); 93 94 OP_REQUIRES( 95 ctx, value.dims() == update.dims(), 96 errors::InvalidArgument("value and update shape doesn't match: ", 97 value.shape().DebugString(), " vs. ", 98 update.shape().DebugString())); 99 for (int i = 1; i < value.dims(); ++i) { 100 OP_REQUIRES( 101 ctx, value.dim_size(i) == update.dim_size(i), 102 errors::InvalidArgument("value and update shape doesn't match ", 103 value.shape().DebugString(), " vs. ", 104 update.shape().DebugString())); 105 } 106 OP_REQUIRES(ctx, 1 == update.dim_size(0), 107 errors::InvalidArgument("update shape doesn't match: ", 108 update.shape().DebugString())); 109 110 Tensor output = value; // This creates an alias intentionally. 111 const auto& d = ctx->eigen_device<Device>(); 112 OP_REQUIRES_OK( 113 ctx, ::tensorflow::functor::DoParallelConcat(d, update, loc_, &output)); 114 ctx->set_output(0, output); 115 } 116 117 private: 118 int32 loc_; 119 }; 120 121 template <typename Device, typename T> 122 class ParallelConcatStart : public OpKernel { 123 public: 124 explicit ParallelConcatStart(OpKernelConstruction* ctx) : OpKernel(ctx) { 125 OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &shape_)); 126 } 127 128 void Compute(OpKernelContext* ctx) override { 129 Tensor* out = nullptr; 130 // We do not know whether the output will be used on GPU. Setting it to be 131 // gpu-compatible for now. 132 AllocatorAttributes attr; 133 attr.set_gpu_compatible(true); 134 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, shape_, &out, attr)); 135 } 136 137 private: 138 TensorShape shape_; 139 }; 140 141 class FailureKernel : public OpKernel { 142 public: 143 explicit FailureKernel(OpKernelConstruction* ctx) : OpKernel(ctx) { 144 OP_REQUIRES_OK(ctx, 145 errors::Internal("Found instance of parallel_stack which " 146 "could not be properly replaced.")); 147 } 148 149 void Compute(OpKernelContext*) override {} 150 }; 151 152 #define REGISTER(type) \ 153 REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate") \ 154 .Device(DEVICE_CPU) \ 155 .TypeConstraint<type>("T"), \ 156 ParallelConcatUpdate<CPUDevice>); 157 TF_CALL_POD_STRING_TYPES(REGISTER) 158 #undef REGISTER 159 160 #define REGISTER_EMPTY(type) \ 161 REGISTER_KERNEL_BUILDER(Name("_ParallelConcatStart") \ 162 .Device(DEVICE_CPU) \ 163 .TypeConstraint<type>("dtype"), \ 164 ParallelConcatStart<CPUDevice, type>) 165 166 TF_CALL_POD_STRING_TYPES(REGISTER_EMPTY) 167 #undef REGISTER_EMPTY 168 169 #define REGISTER_PARALLEL_CONCAT(type) \ 170 REGISTER_KERNEL_BUILDER( \ 171 Name("ParallelConcat").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ 172 FailureKernel); 173 TF_CALL_POD_STRING_TYPES(REGISTER_PARALLEL_CONCAT); 174 #undef REGISTER_PARALLEL_CONCAT 175 176 #ifdef TENSORFLOW_USE_SYCL 177 #define REGISTER_EMPTY(type) \ 178 REGISTER_KERNEL_BUILDER(Name("_ParallelConcatStart") \ 179 .Device(DEVICE_SYCL) \ 180 .TypeConstraint<type>("dtype"), \ 181 ParallelConcatStart<SyclDevice, type>); 182 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_EMPTY) 183 #undef REGISTER_EMPTY 184 185 #define REGISTER_PARALLEL_CONCAT(type) \ 186 REGISTER_KERNEL_BUILDER( \ 187 Name("ParallelConcat").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ 188 FailureKernel); 189 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_PARALLEL_CONCAT); 190 #undef REGISTER_PARALLEL_CONCAT 191 192 #define REGISTER(type) \ 193 REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate") \ 194 .Device(DEVICE_SYCL) \ 195 .TypeConstraint<type>("T"), \ 196 ParallelConcatUpdate<SyclDevice>); 197 TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER) 198 #undef REGISTER 199 200 // Register versions that operate on int32 data on the CPU even though the op 201 // has been placed on the SYCL 202 203 REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate") 204 .Device(DEVICE_SYCL) 205 .HostMemory("value") 206 .HostMemory("update") 207 .HostMemory("output") 208 .TypeConstraint<int32>("T"), 209 ParallelConcatUpdate<CPUDevice>); 210 #endif // TENSORFLOW_USE_SYCL 211 212 #if GOOGLE_CUDA 213 214 typedef Eigen::GpuDevice GPUDevice; 215 216 #define REGISTER_EMPTY(type) \ 217 REGISTER_KERNEL_BUILDER(Name("_ParallelConcatStart") \ 218 .Device(DEVICE_GPU) \ 219 .TypeConstraint<type>("dtype"), \ 220 ParallelConcatStart<GPUDevice, type>); 221 TF_CALL_GPU_NUMBER_TYPES(REGISTER_EMPTY) 222 #undef REGISTER_EMPTY 223 224 #define REGISTER_PARALLEL_CONCAT(type) \ 225 REGISTER_KERNEL_BUILDER( \ 226 Name("ParallelConcat").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ 227 FailureKernel); 228 TF_CALL_GPU_NUMBER_TYPES(REGISTER_PARALLEL_CONCAT); 229 #undef REGISTER_PARALLEL_CONCAT 230 231 #define REGISTER(type) \ 232 REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate") \ 233 .Device(DEVICE_GPU) \ 234 .TypeConstraint<type>("T"), \ 235 ParallelConcatUpdate<GPUDevice>); 236 TF_CALL_GPU_NUMBER_TYPES(REGISTER) 237 #undef REGISTER 238 239 // Register versions that operate on int32 data on the CPU even though the op 240 // has been placed on the GPU 241 242 REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate") 243 .Device(DEVICE_GPU) 244 .HostMemory("value") 245 .HostMemory("update") 246 .HostMemory("output") 247 .TypeConstraint<int32>("T"), 248 ParallelConcatUpdate<CPUDevice>); 249 #endif 250 251 } // end namespace 252 } // end namespace tensorflow 253