/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#define EIGEN_USE_THREADS

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/split_lib.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

// Kernel for the "Unpack" op: splits a rank-R input tensor into `num`
// rank-(R-1) output tensors along dimension `axis`.  `num` is implied by
// the op's output count (num_outputs()); `axis` may be negative, counting
// from the last dimension.
template <typename Device, typename T>
class UnpackOp : public OpKernel {
 public:
  explicit UnpackOp(OpKernelConstruction* context) : OpKernel(context) {
    // "axis" is stored raw (possibly negative) and normalized per Compute()
    // call, so error messages can report the user-supplied value.
    OP_REQUIRES_OK(context, context->GetAttr("axis", &axis_));
  }

  void Compute(OpKernelContext* context) override {
    const int32 num = num_outputs();
    const Tensor& input = context->input(0);
    const TensorShape& input_shape = input.shape();

    // Normalize a negative axis into [0, dims()).
    int axis = axis_;
    if (axis < 0) axis += input_shape.dims();

    OP_REQUIRES(context, 0 <= axis && axis < input_shape.dims(),
                errors::InvalidArgument("axis = ", axis_, " not in [",
                                        -input_shape.dims(), ", ",
                                        input_shape.dims(), ")"));

    // The size of the unpacked dimension must equal the number of outputs.
    // The dims() > 0 conjunct also rejects scalar inputs explicitly.
    OP_REQUIRES(
        context, input_shape.dims() > 0 && input_shape.dim_size(axis) == num,
        errors::InvalidArgument("Input shape axis ", axis, " must equal ", num,
                                ", got shape ", input_shape.DebugString()));

    // Each output has the input's shape with dimension `axis` removed.
    auto output_shape = input_shape;
    output_shape.RemoveDim(axis);
    const int64 output_size = output_shape.num_elements();
    // The Eigen reshapes below index with Eigen::DenseIndex, so the per-output
    // element count must fit in that type.
    OP_REQUIRES(
        context,
        FastBoundsCheck(output_size,
                        std::numeric_limits<Eigen::DenseIndex>::max()),
        errors::InvalidArgument("output size must fit in Eigen DenseIndex"));

// This optimization is currently not applicable for SYCL devices
#ifndef TENSORFLOW_USE_SYCL
    // Special case: Aligned, so we can share the underlying buffer.
    //
    // Apply this optimization conservatively: if input is aligned,
    // the resulting tensors must be aligned. It's conservative
    // because if the immediate consumer of the resulting tensors are
    // not using eigen for computation, its perfectly fine to avoid
    // the copying.
    if (axis == 0 &&
        (output_size == 0 || IsInnerDimsSizeAligned<T>(input_shape))) {
      for (int i = 0; i < num; ++i) {
        // CopyFrom shares the slice's buffer (no data copy); it only fails on
        // an element-count mismatch, which the checks above rule out.
        Tensor output;
        CHECK(output.CopyFrom(input.Slice(i, i + 1), output_shape));
        context->set_output(i, output);
      }
      return;
    }
#endif  // TENSORFLOW_USE_SYCL

    // View the input as a 3-D tensor [1, before_dim, axis_dim * after_dim],
    // where before_dim/after_dim are the products of the dimensions before
    // and after `axis`.
    int64 before_dim = 1;
    for (int i = 0; i < axis; ++i) {
      before_dim *= input_shape.dim_size(i);
    }

    int64 after_dim = 1;
    for (int i = axis + 1; i < input_shape.dims(); ++i) {
      after_dim *= input_shape.dim_size(i);
    }
    const int64 axis_dim = input_shape.dim_size(axis);

    // Except for shape, unpack is a special case of split, so we reuse the
    // same computational kernels.
    auto input_reshaped =
        input.shaped<T, 3>({1, before_dim, axis_dim * after_dim});

    for (int i = 0; i < num; ++i) {
      Tensor* output;
      OP_REQUIRES_OK(context,
                     context->allocate_output(i, output_shape, &output));

      if (output_shape.num_elements() > 0) {
        // Output i is the [i * after_dim, (i + 1) * after_dim) slab of the
        // innermost dimension of the reshaped input.
        auto output_shaped = output->shaped<T, 3>({1, before_dim, after_dim});
        Eigen::DSizes<Eigen::DenseIndex, 3> indices{0, 0, i * after_dim};
        Eigen::DSizes<Eigen::DenseIndex, 3> sizes{1, before_dim, after_dim};
        functor::Split<Device, T>()(context->eigen_device<Device>(),
                                    output_shaped, input_reshaped, indices,
                                    sizes);
      }
    }
  }

 private:
  int axis_;  // Raw "axis" attribute as supplied by the graph; may be negative.
};

#define REGISTER_UNPACK(type)                                      \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("Unpack").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      UnpackOp<CPUDevice, type>)

TF_CALL_ALL_TYPES(REGISTER_UNPACK);

#undef REGISTER_UNPACK

#if GOOGLE_CUDA

#define REGISTER_GPU(type)                                         \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("Unpack").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      UnpackOp<GPUDevice, type>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_bfloat16(REGISTER_GPU);
#undef REGISTER_GPU

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
// Note these run the CPUDevice implementation over host-resident buffers.
REGISTER_KERNEL_BUILDER(Name("Unpack")
                            .Device(DEVICE_GPU)
                            .HostMemory("value")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        UnpackOp<CPUDevice, int32>);
REGISTER_KERNEL_BUILDER(Name("Unpack")
                            .Device(DEVICE_GPU)
                            .HostMemory("value")
                            .HostMemory("output")
                            .TypeConstraint<int64>("T"),
                        UnpackOp<CPUDevice, int64>);

#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL(type)                                         \
  REGISTER_KERNEL_BUILDER(                                          \
      Name("Unpack").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
      UnpackOp<SYCLDevice, type>)

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL);

// As on CUDA devices, int32/int64 run the CPU kernel with all inputs and
// outputs pinned to host memory.
REGISTER_KERNEL_BUILDER(Name("Unpack")
                            .Device(DEVICE_SYCL)
                            .HostMemory("value")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        UnpackOp<CPUDevice, int32>);

REGISTER_KERNEL_BUILDER(Name("Unpack")
                            .Device(DEVICE_SYCL)
                            .HostMemory("value")
                            .HostMemory("output")
                            .TypeConstraint<int64>("T"),
                        UnpackOp<CPUDevice, int64>);
#undef REGISTER_SYCL
#endif  // TENSORFLOW_USE_SYCL

}  // end namespace tensorflow