1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <string> 17 18 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 19 #include "tensorflow/core/framework/kernel_def_builder.h" 20 #include "tensorflow/core/framework/op.h" 21 #include "tensorflow/core/framework/op_kernel.h" 22 #include "tensorflow/core/framework/tensor.h" 23 #include "tensorflow/core/framework/tensor_shape.h" 24 #include "tensorflow/core/framework/tensor_types.h" 25 #include "tensorflow/core/framework/types.h" 26 #include "tensorflow/core/kernels/bounds_check.h" 27 #include "tensorflow/core/lib/core/errors.h" 28 #include "tensorflow/core/util/bcast.h" 29 30 namespace tensorflow { 31 32 // Position/length can be 32 or 64-bit integers 33 template <typename T> 34 class SubstrOp : public OpKernel { 35 public: 36 using OpKernel::OpKernel; 37 38 void Compute(OpKernelContext* context) override { 39 // Get inputs 40 const Tensor& input_tensor = context->input(0); 41 const Tensor& pos_tensor = context->input(1); 42 const Tensor& len_tensor = context->input(2); 43 const TensorShape& input_shape = input_tensor.shape(); 44 const TensorShape& pos_shape = pos_tensor.shape(); 45 46 bool is_scalar = TensorShapeUtils::IsScalar(pos_shape); 47 48 if (is_scalar || input_shape == pos_shape) { 49 // pos/len are either scalar or match the shape of input_tensor 50 // Do not need to do broadcasting 51 52 // Reshape input 53 auto input = input_tensor.flat<string>(); 54 // Allocate output 55 Tensor* output_tensor = nullptr; 56 OP_REQUIRES_OK(context, 57 context->allocate_output("output", input_tensor.shape(), 58 &output_tensor)); 59 auto output = output_tensor->flat<string>(); 60 if (is_scalar) { 61 // Perform Op with scalar pos/len 62 const T pos = 63 tensorflow::internal::SubtleMustCopy(pos_tensor.scalar<T>()()); 64 const T len = 65 tensorflow::internal::SubtleMustCopy(len_tensor.scalar<T>()()); 66 for (size_t i = 0; i < input_tensor.NumElements(); ++i) { 67 string in = input(i); 68 OP_REQUIRES( 69 context, FastBoundsCheck(pos, in.size() + 1), 70 errors::InvalidArgument("pos ", pos, " out of range for string", 71 "b'", in, "' at index ", i)); 72 output(i) = in.substr(pos, len); 73 } 74 } else { 75 // Perform Op element-wise with tensor pos/len 76 auto pos_flat = pos_tensor.flat<T>(); 77 auto len_flat = len_tensor.flat<T>(); 78 for (size_t i = 0; i < input_tensor.NumElements(); ++i) { 79 string in = input(i); 80 const T pos = tensorflow::internal::SubtleMustCopy(pos_flat(i)); 81 const T len = tensorflow::internal::SubtleMustCopy(len_flat(i)); 82 OP_REQUIRES( 83 context, FastBoundsCheck(pos, in.size() + 1), 84 errors::InvalidArgument("pos ", pos, " out of range for string", 85 "b'", in, "' at index ", i)); 86 output(i) = in.substr(pos, len); 87 } 88 } 89 } else { 90 // Perform op with broadcasting 91 // TODO: Use ternary broadcasting for once available in Eigen. Current 92 // implementation iterates through broadcasted ops element-wise; 93 // this should be parallelized. 94 95 // Create BCast helper with shape of input and pos/len 96 BCast bcast(BCast::FromShape(input_shape), BCast::FromShape(pos_shape)); 97 OP_REQUIRES(context, bcast.IsValid(), 98 errors::InvalidArgument( 99 "Incompatible shapes: ", input_shape.DebugString(), 100 " vs. ", pos_shape.DebugString())); 101 TensorShape output_shape = BCast::ToShape(bcast.result_shape()); 102 int ndims = output_shape.dims(); 103 Tensor* output_tensor = nullptr; 104 OP_REQUIRES_OK(context, context->allocate_output("output", output_shape, 105 &output_tensor)); 106 switch (ndims) { 107 case 1: { 108 // Reshape tensors according to BCast results 109 auto input = input_tensor.shaped<string, 1>(bcast.x_reshape()); 110 auto output = output_tensor->shaped<string, 1>(bcast.result_shape()); 111 auto pos_shaped = pos_tensor.shaped<T, 1>(bcast.y_reshape()); 112 auto len_shaped = len_tensor.shaped<T, 1>(bcast.y_reshape()); 113 114 // Allocate temporary buffer for broadcasted input tensor 115 Tensor input_buffer; 116 OP_REQUIRES_OK(context, context->allocate_temp( 117 DT_STRING, output_shape, &input_buffer)); 118 TTypes<string, 1>::Tensor input_bcast = 119 input_buffer.shaped<string, 1>(bcast.result_shape()); 120 input_bcast = 121 input.broadcast(BCast::ToIndexArray<1>(bcast.x_bcast())); 122 123 // Allocate temporary buffer for broadcasted position tensor 124 Tensor pos_buffer; 125 OP_REQUIRES_OK(context, 126 context->allocate_temp(DataTypeToEnum<T>::v(), 127 output_shape, &pos_buffer)); 128 typename TTypes<T, 1>::Tensor pos_bcast( 129 pos_buffer.shaped<T, 1>(bcast.result_shape())); 130 pos_bcast = 131 pos_shaped.broadcast(BCast::ToIndexArray<1>(bcast.y_bcast())); 132 133 // Allocate temporary buffer for broadcasted length tensor 134 Tensor len_buffer; 135 OP_REQUIRES_OK(context, 136 context->allocate_temp(DataTypeToEnum<T>::v(), 137 output_shape, &len_buffer)); 138 typename TTypes<T, 1>::Tensor len_bcast( 139 len_buffer.shaped<T, 1>(bcast.result_shape())); 140 len_bcast = 141 len_shaped.broadcast(BCast::ToIndexArray<1>(bcast.y_bcast())); 142 143 // Iterate through broadcasted tensors and perform substr 144 for (int i = 0; i < output_shape.dim_size(0); ++i) { 145 string in = input_bcast(i); 146 const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i)); 147 const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i)); 148 OP_REQUIRES( 149 context, FastBoundsCheck(pos, input_bcast(i).size() + 1), 150 errors::InvalidArgument("pos ", pos, " out of range for string", 151 "b'", in, "' at index ", i)); 152 output(i) = in.substr(pos, len); 153 } 154 break; 155 } 156 case 2: { 157 // Reshape tensors according to BCast results 158 auto input = input_tensor.shaped<string, 2>(bcast.x_reshape()); 159 auto output = output_tensor->shaped<string, 2>(bcast.result_shape()); 160 auto pos_shaped = pos_tensor.shaped<T, 2>(bcast.y_reshape()); 161 auto len_shaped = len_tensor.shaped<T, 2>(bcast.y_reshape()); 162 163 // Allocate temporary buffer for broadcasted input tensor 164 Tensor input_buffer; 165 OP_REQUIRES_OK(context, context->allocate_temp( 166 DT_STRING, output_shape, &input_buffer)); 167 TTypes<string, 2>::Tensor input_bcast = 168 input_buffer.shaped<string, 2>(bcast.result_shape()); 169 input_bcast = 170 input.broadcast(BCast::ToIndexArray<2>(bcast.x_bcast())); 171 172 // Allocate temporary buffer for broadcasted position tensor 173 Tensor pos_buffer; 174 OP_REQUIRES_OK(context, 175 context->allocate_temp(DataTypeToEnum<T>::v(), 176 output_shape, &pos_buffer)); 177 typename TTypes<T, 2>::Tensor pos_bcast( 178 pos_buffer.shaped<T, 2>(bcast.result_shape())); 179 pos_bcast = 180 pos_shaped.broadcast(BCast::ToIndexArray<2>(bcast.y_bcast())); 181 182 // Allocate temporary buffer for broadcasted length tensor 183 Tensor len_buffer; 184 OP_REQUIRES_OK(context, 185 context->allocate_temp(DataTypeToEnum<T>::v(), 186 output_shape, &len_buffer)); 187 typename TTypes<T, 2>::Tensor len_bcast( 188 len_buffer.shaped<T, 2>(bcast.result_shape())); 189 len_bcast = 190 len_shaped.broadcast(BCast::ToIndexArray<2>(bcast.y_bcast())); 191 192 // Iterate through broadcasted tensors and perform substr 193 for (int i = 0; i < output_shape.dim_size(0); ++i) { 194 for (int j = 0; j < output_shape.dim_size(1); ++j) { 195 string in = input_bcast(i, j); 196 const T pos = 197 tensorflow::internal::SubtleMustCopy(pos_bcast(i, j)); 198 const T len = 199 tensorflow::internal::SubtleMustCopy(len_bcast(i, j)); 200 OP_REQUIRES(context, FastBoundsCheck(pos, in.size() + 1), 201 errors::InvalidArgument( 202 "pos ", pos, " out of range for ", "string b'", 203 in, "' at index (", i, ", ", j, ")")); 204 output(i, j) = in.substr(pos, len); 205 } 206 } 207 break; 208 } 209 default: { 210 context->SetStatus(errors::Unimplemented( 211 "Substr broadcast not implemented for ", ndims, " dimensions")); 212 } 213 } 214 } 215 } 216 }; 217 218 #define REGISTER_SUBSTR(type) \ 219 REGISTER_KERNEL_BUILDER( \ 220 Name("Substr").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ 221 SubstrOp<type>); 222 REGISTER_SUBSTR(int32); 223 REGISTER_SUBSTR(int64); 224 } // namespace tensorflow 225