1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_ 16 #define TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_ 17 18 #define EIGEN_USE_THREADS 19 #if GOOGLE_CUDA 20 #define EIGEN_USE_GPU 21 #endif // GOOGLE_CUDA 22 23 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 24 #include "tensorflow/core/framework/op_kernel.h" 25 #include "tensorflow/core/framework/register_types.h" 26 #include "tensorflow/core/framework/tensor.h" 27 #include "tensorflow/core/framework/tensor_types.h" 28 #include "tensorflow/core/framework/variant.h" 29 #include "tensorflow/core/framework/variant_op_registry.h" 30 #include "tensorflow/core/kernels/concat_lib.h" 31 #include "tensorflow/core/lib/core/coding.h" 32 #include "tensorflow/core/lib/core/errors.h" 33 #include "tensorflow/core/util/util.h" 34 35 namespace tensorflow { 36 37 // Variant compatible type for a list of tensors. This is mutable but instances 38 // should never be mutated after stored in a variant tensor. 39 struct TensorList { 40 public: 41 TensorList() {} 42 TensorList(const TensorList& other); 43 44 static const char kTypeName[]; 45 string TypeName() const { return kTypeName; } 46 47 void Encode(VariantTensorData* data) const; 48 49 bool Decode(const VariantTensorData& data); 50 51 // TODO(apassos) fill this out 52 string DebugString() const { return "TensorList"; } 53 54 std::vector<Tensor> tensors; 55 PartialTensorShape element_shape; 56 DataType element_dtype; 57 }; 58 59 Status TensorShapeFromTensor(const Tensor& t, PartialTensorShape* out); 60 61 template <typename Device, typename T> 62 class TensorListStack : public OpKernel { 63 public: 64 typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>> 65 ConstMatrixVector; 66 explicit TensorListStack(OpKernelConstruction* c) : OpKernel(c) { 67 OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_)); 68 OP_REQUIRES_OK(c, c->GetAttr("num_elements", &num_elements_)); 69 } 70 71 ~TensorListStack() {} 72 73 void Compute(OpKernelContext* c) override { 74 const TensorList* l = c->input(0).scalar<Variant>()().get<TensorList>(); 75 OP_REQUIRES(c, l != nullptr, 76 errors::InvalidArgument( 77 "Input handle is not a list. Saw: '", 78 c->input(0).scalar<Variant>()().DebugString(), "'")); 79 OP_REQUIRES(c, element_dtype_ == l->element_dtype, 80 errors::InvalidArgument("Invalid data types; op elements ", 81 DataTypeString(element_dtype_), 82 " but list elements ", 83 DataTypeString(l->element_dtype))); 84 OP_REQUIRES(c, l->element_shape.IsFullyDefined(), 85 errors::InvalidArgument("Tried to stack elements from a list " 86 "with non-fully-defined shape.")); 87 if (num_elements_ != -1) { 88 OP_REQUIRES(c, l->tensors.size() == num_elements_, 89 errors::InvalidArgument("Operation expected a list with ", 90 num_elements_, 91 " elements but got a list with ", 92 l->tensors.size(), " elements.")); 93 } 94 TensorShape resulting_shape; 95 resulting_shape.AddDim(l->tensors.size()); 96 for (TensorShapeDim s : l->element_shape) { 97 resulting_shape.AddDim(s.size); 98 } 99 Tensor* output; 100 OP_REQUIRES_OK(c, c->allocate_output(0, resulting_shape, &output)); 101 if (output->NumElements() == 0) { 102 return; 103 } 104 105 ConstMatrixVector inputs_flat; 106 inputs_flat.reserve(l->tensors.size()); 107 for (const auto& t : l->tensors) { 108 OP_REQUIRES( 109 c, l->element_shape.IsCompatibleWith(t.shape()), 110 errors::InvalidArgument( 111 "Tensor with invalid shape in list. List element shape shape: ", 112 l->element_shape.DebugString(), 113 " and tensor shape: ", t.shape().DebugString())); 114 inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix( 115 t.shaped<T, 2>({1, t.NumElements()}))); 116 } 117 auto output_flat = output->shaped<T, 2>({1, output->NumElements()}); 118 119 #if GOOGLE_CUDA 120 if (std::is_same<Device, Eigen::GpuDevice>::value) { 121 ConcatGPU<T>(c, inputs_flat, output, &output_flat); 122 return; 123 } 124 #endif // GOOGLE_CUDA 125 ConcatCPU<T>(c->device(), inputs_flat, &output_flat); 126 } 127 128 private: 129 int num_elements_; 130 DataType element_dtype_; 131 }; 132 133 template <typename Device, typename T> 134 class TensorListFromTensor : public OpKernel { 135 public: 136 TensorListFromTensor(OpKernelConstruction* c) : OpKernel(c) {} 137 138 void Compute(OpKernelContext* c) override { 139 Tensor* output_tensor; 140 AllocatorAttributes attr; 141 attr.set_on_host(true); 142 OP_REQUIRES_OK(c, c->allocate_output(0, {}, &output_tensor, attr)); 143 PartialTensorShape element_shape; 144 OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(1), &element_shape)); 145 TensorList output_list; 146 const Tensor& t = c->input(0); 147 output_list.element_dtype = t.dtype(); 148 TensorShape output_shape(t.shape()); 149 output_shape.RemoveDim(0); 150 OP_REQUIRES(c, element_shape.IsCompatibleWith(output_shape), 151 errors::InvalidArgument( 152 "Specified a list with shape ", element_shape.DebugString(), 153 " from a tensor with shape ", output_shape.DebugString())); 154 output_list.element_shape = element_shape; 155 output_list.tensors.reserve(t.shape().dim_size(0)); 156 for (int i = 0; i < t.shape().dim_size(0); ++i) { 157 Tensor tmp = t.Slice(i, i + 1); 158 TensorShape tmp_shape = tmp.shape(); 159 tmp_shape.RemoveDim(0); 160 OP_REQUIRES(c, tmp.CopyFrom(tmp, tmp_shape), 161 errors::Unknown("Unexpected shape error.")); 162 if (tmp.IsAligned() || !DataTypeCanUseMemcpy(DataTypeToEnum<T>::value)) { 163 output_list.tensors.push_back(tmp); 164 } else { 165 Tensor aligned; 166 OP_REQUIRES_OK(c, c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned)); 167 aligned.flat<T>().device(c->eigen_device<Device>()) = 168 tmp.unaligned_flat<T>(); 169 output_list.tensors.push_back(aligned); 170 } 171 } 172 output_tensor->scalar<Variant>()() = std::move(output_list); 173 } 174 }; 175 176 template <typename Device> 177 Status TensorListBinaryAdd(OpKernelContext* c, const TensorList& a, 178 const TensorList& b, TensorList* out) { 179 if (a.element_dtype != b.element_dtype) { 180 return errors::InvalidArgument( 181 "Trying to add two lists of tensors of different dtypes. One is ", 182 DataTypeString(a.element_dtype), " and the other is ", 183 DataTypeString(b.element_dtype)); 184 } 185 out->element_dtype = a.element_dtype; 186 if (!a.element_shape.IsCompatibleWith(b.element_shape)) { 187 return errors::InvalidArgument( 188 "Trying to add two lists of tensors with incompatible element shapes. " 189 "One is ", 190 a.element_shape.DebugString(), " and the other is ", 191 b.element_shape.DebugString()); 192 } 193 194 TF_RETURN_IF_ERROR( 195 a.element_shape.MergeWith(b.element_shape, &out->element_shape)); 196 if (a.tensors.size() != b.tensors.size()) { 197 return errors::InvalidArgument( 198 "Trying to add two lists of tensors with different lengths. One is ", 199 a.tensors.size(), " and the other is ", b.tensors.size()); 200 } 201 out->tensors.reserve(a.tensors.size()); 202 for (int i = 0; i < a.tensors.size(); ++i) { 203 const Tensor& a_tensor = a.tensors[i]; 204 const Tensor& b_tensor = b.tensors[i]; 205 if (a_tensor.dtype() == DT_INVALID) { 206 out->tensors.push_back(b_tensor); 207 continue; 208 } 209 if (b_tensor.dtype() == DT_INVALID) { 210 out->tensors.push_back(a_tensor); 211 continue; 212 } 213 if (a_tensor.shape() != b_tensor.shape()) { 214 // TODO(apassos) support broadcasting additions here? 215 return errors::InvalidArgument( 216 "Trying to add two tensors with incompatible element shapes. " 217 "One is ", 218 a_tensor.shape().DebugString(), " and the other is ", 219 b_tensor.shape().DebugString(), " in position ", i); 220 } 221 Tensor out_tensor; 222 TF_RETURN_IF_ERROR( 223 c->allocate_temp(a_tensor.dtype(), a_tensor.shape(), &out_tensor)); 224 out->tensors.push_back(out_tensor); 225 switch (out_tensor.dtype()) { 226 #define DTYPE_CASE(dtype) \ 227 case DataTypeToEnum<dtype>::value: \ 228 out_tensor.flat<dtype>().device(c->eigen_device<Device>()) = \ 229 a_tensor.flat<dtype>() + b_tensor.flat<dtype>(); \ 230 break; 231 232 TF_CALL_NUMBER_TYPES(DTYPE_CASE) 233 234 #undef DTYPE_CASE 235 default: 236 return errors::InvalidArgument("Trying to add unsupported dtype ", 237 out_tensor.dtype()); 238 } 239 } 240 return Status::OK(); 241 } 242 243 template <typename Device> 244 Status TensorListZerosLike(OpKernelContext* c, const TensorList& x, 245 TensorList* y) { 246 y->element_dtype = x.element_dtype; 247 y->element_shape = x.element_shape; 248 y->tensors.reserve(x.tensors.size()); 249 for (const Tensor& t : x.tensors) { 250 Tensor out_tensor; 251 TF_RETURN_IF_ERROR(c->allocate_temp(t.dtype(), t.shape(), &out_tensor)); 252 switch (out_tensor.dtype()) { 253 #define DTYPE_CASE(dtype) \ 254 case DataTypeToEnum<dtype>::value: \ 255 out_tensor.flat<dtype>().device(c->eigen_device<Device>()) = \ 256 out_tensor.flat<dtype>().constant(dtype(0)); \ 257 break; 258 259 TF_CALL_NUMBER_TYPES(DTYPE_CASE) 260 261 #undef DTYPE_CASE 262 default: 263 return errors::InvalidArgument( 264 "Trying to compute zeros_like for unsupported dtype", 265 out_tensor.dtype()); 266 } 267 } 268 return Status::OK(); 269 } 270 271 } // namespace tensorflow 272 273 #endif // TENSORFLOW_CORE_KERNELS_LIST_KERNELS_H_ 274