/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#include <limits>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/concat_lib.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
#if GOOGLE_CUDA
typedef Eigen::GpuDevice GPUDevice;
#endif  // GOOGLE_CUDA
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif  // TENSORFLOW_USE_SYCL

// --------------------------------------------------------------------------
// Kernel for the "Pack" op: stacks a list of N tensors of identical shape
// into a single tensor of rank R+1, inserting the new dimension (of size N)
// at position `axis`. Implemented by flattening each input into a 2-D matrix
// and reusing the concat computational kernels.
template <typename Device, typename T>
class PackOp : public OpKernel {
 public:
  typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
      ConstMatrixVector;

  explicit PackOp(OpKernelConstruction* context) : OpKernel(context) {
    // `axis` may be negative; it is normalized per-call in Compute() because
    // the input rank is not known until then.
    OP_REQUIRES_OK(context, context->GetAttr("axis", &axis_));
  }

  void Compute(OpKernelContext* c) override {
    OpInputList values;
    OP_REQUIRES_OK(c, c->input_list("values", &values));
    const int num = values.size();

    // Verify that all input shapes match
    for (int i = 1; i < num; i++) {
      OP_REQUIRES(c, values[0].shape().IsSameSize(values[i].shape()),
                  errors::InvalidArgument(
                      "Shapes of all inputs must match: values[0].shape = ",
                      values[0].shape().DebugString(), " != values[", i,
                      "].shape = ", values[i].shape().DebugString()));
    }

    // The output has one more dimension than the inputs; a negative axis
    // counts from the end of the *expanded* (rank+1) shape.
    int expanded_num_dims = values[0].dims() + 1;
    int axis = axis_;
    if (axis < 0) axis += expanded_num_dims;

    OP_REQUIRES(c, 0 <= axis && axis < expanded_num_dims,
                errors::InvalidArgument("axis = ", axis_, " not in [",
                                        -expanded_num_dims, ", ",
                                        expanded_num_dims, ")"));

    TensorShape output_shape(values[0].shape());
    output_shape.InsertDim(axis, num);

    // In the num = 1 case, just reshape the input. The inserted dimension
    // has size 1, so the element counts match and CopyFrom shares the input
    // buffer rather than copying data.
    if (num == 1) {
      Tensor output;
      CHECK(output.CopyFrom(values[0], output_shape));
      c->set_output(0, output);
      return;
    }

    // Allocate output
    Tensor* output;
    OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));

    // Product of the output dimensions before `axis` (row count of the
    // flattened 2-D view).
    int64 before_dim = 1;
    for (int i = 0; i < axis; ++i) {
      before_dim *= output_shape.dim_size(i);
    }

    // Product of the output dimensions after `axis`.
    int64 after_dim = 1;
    for (int i = axis + 1; i < output_shape.dims(); ++i) {
      after_dim *= output_shape.dim_size(i);
    }

    const int64 axis_dim = output_shape.dim_size(axis);

    // Nothing to do for an empty output (some dimension is 0).
    const int64 output_size = output->NumElements();
    if (output_size > 0) {
      // View the output as [before_dim, axis_dim * after_dim]; each input is
      // viewed as [before_dim, after_dim], so stacking along `axis` becomes a
      // concatenation of the inputs along the column dimension.
      auto output_flat =
          output->shaped<T, 2>({before_dim, after_dim * axis_dim});

      // Except for shapes, pack is a special case of concat, so we reuse the
      // same computational kernels.
      ConstMatrixVector inputs_flat;
      inputs_flat.reserve(num);
      for (int i = 0; i < num; ++i) {
        inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
            values[i].shaped<T, 2>({before_dim, after_dim})));
      }
      // Dispatch to the device-specific concat kernel; the CPU path is the
      // fall-through when no device-specific branch is taken.
#if GOOGLE_CUDA
      if (std::is_same<Device, GPUDevice>::value) {
        ConcatGPU<T>(c, inputs_flat, output, &output_flat);
        return;
      }
#endif  // GOOGLE_CUDA
#ifdef TENSORFLOW_USE_SYCL
      if (std::is_same<Device, SYCLDevice>::value) {
        ConcatSYCL<T>(c->eigen_sycl_device(), inputs_flat, &output_flat);
        return;
      }
#endif  // TENSORFLOW_USE_SYCL
      ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
    }
  }

 private:
  int axis_;  // Stacking axis as given by the "axis" attr (may be negative).
};

#define REGISTER_PACK(type)                                      \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("Pack").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      PackOp<CPUDevice, type>)

TF_CALL_ALL_TYPES(REGISTER_PACK);
TF_CALL_QUANTIZED_TYPES(REGISTER_PACK);

#if defined(IS_MOBILE_PLATFORM) && !defined(SUPPORT_SELECTIVE_REGISTRATION)
// Primarily used for SavedModel support on mobile.
REGISTER_PACK(string);
#endif  // defined(IS_MOBILE_PLATFORM) &&
        // !defined(SUPPORT_SELECTIVE_REGISTRATION)

#undef REGISTER_PACK

#if GOOGLE_CUDA

#define REGISTER_GPU(type)                                       \
  REGISTER_KERNEL_BUILDER(                                       \
      Name("Pack").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      PackOp<GPUDevice, type>)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_bfloat16(REGISTER_GPU);
TF_CALL_int64(REGISTER_GPU);
REGISTER_GPU(bool);
#undef REGISTER_GPU

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("Pack")
                            .Device(DEVICE_GPU)
                            .HostMemory("values")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        PackOp<CPUDevice, int32>);

#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
#define REGISTER_SYCL(type)                                       \
  REGISTER_KERNEL_BUILDER(                                        \
      Name("Pack").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
      PackOp<SYCLDevice, type>)

TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL);
// As on GPU: int32 tensors stay in host memory, so the CPU kernel is used.
REGISTER_KERNEL_BUILDER(Name("Pack")
                            .Device(DEVICE_SYCL)
                            .HostMemory("values")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        PackOp<CPUDevice, int32>);
#undef REGISTER_SYCL
#endif  // TENSORFLOW_USE_SYCL
}  // namespace tensorflow