Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <string>
     17 
     18 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     19 #include "tensorflow/core/framework/kernel_def_builder.h"
     20 #include "tensorflow/core/framework/op.h"
     21 #include "tensorflow/core/framework/op_kernel.h"
     22 #include "tensorflow/core/framework/tensor.h"
     23 #include "tensorflow/core/framework/tensor_shape.h"
     24 #include "tensorflow/core/framework/tensor_types.h"
     25 #include "tensorflow/core/framework/types.h"
     26 #include "tensorflow/core/kernels/bounds_check.h"
     27 #include "tensorflow/core/lib/core/errors.h"
     28 #include "tensorflow/core/util/bcast.h"
     29 
     30 namespace tensorflow {
     31 
     32 // Position/length can be 32 or 64-bit integers
     33 template <typename T>
     34 class SubstrOp : public OpKernel {
     35  public:
     36   using OpKernel::OpKernel;
     37 
     38   void Compute(OpKernelContext* context) override {
     39     // Get inputs
     40     const Tensor& input_tensor = context->input(0);
     41     const Tensor& pos_tensor = context->input(1);
     42     const Tensor& len_tensor = context->input(2);
     43     const TensorShape& input_shape = input_tensor.shape();
     44     const TensorShape& pos_shape = pos_tensor.shape();
     45 
     46     bool is_scalar = TensorShapeUtils::IsScalar(pos_shape);
     47 
     48     if (is_scalar || input_shape == pos_shape) {
     49       // pos/len are either scalar or match the shape of input_tensor
     50       // Do not need to do broadcasting
     51 
     52       // Reshape input
     53       auto input = input_tensor.flat<string>();
     54       // Allocate output
     55       Tensor* output_tensor = nullptr;
     56       OP_REQUIRES_OK(context,
     57                      context->allocate_output("output", input_tensor.shape(),
     58                                               &output_tensor));
     59       auto output = output_tensor->flat<string>();
     60       if (is_scalar) {
     61         // Perform Op with scalar pos/len
     62         const T pos =
     63             tensorflow::internal::SubtleMustCopy(pos_tensor.scalar<T>()());
     64         const T len =
     65             tensorflow::internal::SubtleMustCopy(len_tensor.scalar<T>()());
     66         for (size_t i = 0; i < input_tensor.NumElements(); ++i) {
     67           string in = input(i);
     68           OP_REQUIRES(
     69               context, FastBoundsCheck(pos, in.size() + 1),
     70               errors::InvalidArgument("pos ", pos, " out of range for string",
     71                                       "b'", in, "' at index ", i));
     72           output(i) = in.substr(pos, len);
     73         }
     74       } else {
     75         // Perform Op element-wise with tensor pos/len
     76         auto pos_flat = pos_tensor.flat<T>();
     77         auto len_flat = len_tensor.flat<T>();
     78         for (size_t i = 0; i < input_tensor.NumElements(); ++i) {
     79           string in = input(i);
     80           const T pos = tensorflow::internal::SubtleMustCopy(pos_flat(i));
     81           const T len = tensorflow::internal::SubtleMustCopy(len_flat(i));
     82           OP_REQUIRES(
     83               context, FastBoundsCheck(pos, in.size() + 1),
     84               errors::InvalidArgument("pos ", pos, " out of range for string",
     85                                       "b'", in, "' at index ", i));
     86           output(i) = in.substr(pos, len);
     87         }
     88       }
     89     } else {
     90       // Perform op with broadcasting
     91       // TODO: Use ternary broadcasting for once available in Eigen. Current
     92       //       implementation iterates through broadcasted ops element-wise;
     93       //       this should be parallelized.
     94 
     95       // Create BCast helper with shape of input and pos/len
     96       BCast bcast(BCast::FromShape(input_shape), BCast::FromShape(pos_shape));
     97       OP_REQUIRES(context, bcast.IsValid(),
     98                   errors::InvalidArgument(
     99                       "Incompatible shapes: ", input_shape.DebugString(),
    100                       " vs. ", pos_shape.DebugString()));
    101       TensorShape output_shape = BCast::ToShape(bcast.result_shape());
    102       int ndims = output_shape.dims();
    103       Tensor* output_tensor = nullptr;
    104       OP_REQUIRES_OK(context, context->allocate_output("output", output_shape,
    105                                                        &output_tensor));
    106       switch (ndims) {
    107         case 1: {
    108           // Reshape tensors according to BCast results
    109           auto input = input_tensor.shaped<string, 1>(bcast.x_reshape());
    110           auto output = output_tensor->shaped<string, 1>(bcast.result_shape());
    111           auto pos_shaped = pos_tensor.shaped<T, 1>(bcast.y_reshape());
    112           auto len_shaped = len_tensor.shaped<T, 1>(bcast.y_reshape());
    113 
    114           // Allocate temporary buffer for broadcasted input tensor
    115           Tensor input_buffer;
    116           OP_REQUIRES_OK(context, context->allocate_temp(
    117                                       DT_STRING, output_shape, &input_buffer));
    118           TTypes<string, 1>::Tensor input_bcast =
    119               input_buffer.shaped<string, 1>(bcast.result_shape());
    120           input_bcast =
    121               input.broadcast(BCast::ToIndexArray<1>(bcast.x_bcast()));
    122 
    123           // Allocate temporary buffer for broadcasted position tensor
    124           Tensor pos_buffer;
    125           OP_REQUIRES_OK(context,
    126                          context->allocate_temp(DataTypeToEnum<T>::v(),
    127                                                 output_shape, &pos_buffer));
    128           typename TTypes<T, 1>::Tensor pos_bcast(
    129               pos_buffer.shaped<T, 1>(bcast.result_shape()));
    130           pos_bcast =
    131               pos_shaped.broadcast(BCast::ToIndexArray<1>(bcast.y_bcast()));
    132 
    133           // Allocate temporary buffer for broadcasted length tensor
    134           Tensor len_buffer;
    135           OP_REQUIRES_OK(context,
    136                          context->allocate_temp(DataTypeToEnum<T>::v(),
    137                                                 output_shape, &len_buffer));
    138           typename TTypes<T, 1>::Tensor len_bcast(
    139               len_buffer.shaped<T, 1>(bcast.result_shape()));
    140           len_bcast =
    141               len_shaped.broadcast(BCast::ToIndexArray<1>(bcast.y_bcast()));
    142 
    143           // Iterate through broadcasted tensors and perform substr
    144           for (int i = 0; i < output_shape.dim_size(0); ++i) {
    145             string in = input_bcast(i);
    146             const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i));
    147             const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i));
    148             OP_REQUIRES(
    149                 context, FastBoundsCheck(pos, input_bcast(i).size() + 1),
    150                 errors::InvalidArgument("pos ", pos, " out of range for string",
    151                                         "b'", in, "' at index ", i));
    152             output(i) = in.substr(pos, len);
    153           }
    154           break;
    155         }
    156         case 2: {
    157           // Reshape tensors according to BCast results
    158           auto input = input_tensor.shaped<string, 2>(bcast.x_reshape());
    159           auto output = output_tensor->shaped<string, 2>(bcast.result_shape());
    160           auto pos_shaped = pos_tensor.shaped<T, 2>(bcast.y_reshape());
    161           auto len_shaped = len_tensor.shaped<T, 2>(bcast.y_reshape());
    162 
    163           // Allocate temporary buffer for broadcasted input tensor
    164           Tensor input_buffer;
    165           OP_REQUIRES_OK(context, context->allocate_temp(
    166                                       DT_STRING, output_shape, &input_buffer));
    167           TTypes<string, 2>::Tensor input_bcast =
    168               input_buffer.shaped<string, 2>(bcast.result_shape());
    169           input_bcast =
    170               input.broadcast(BCast::ToIndexArray<2>(bcast.x_bcast()));
    171 
    172           // Allocate temporary buffer for broadcasted position tensor
    173           Tensor pos_buffer;
    174           OP_REQUIRES_OK(context,
    175                          context->allocate_temp(DataTypeToEnum<T>::v(),
    176                                                 output_shape, &pos_buffer));
    177           typename TTypes<T, 2>::Tensor pos_bcast(
    178               pos_buffer.shaped<T, 2>(bcast.result_shape()));
    179           pos_bcast =
    180               pos_shaped.broadcast(BCast::ToIndexArray<2>(bcast.y_bcast()));
    181 
    182           // Allocate temporary buffer for broadcasted length tensor
    183           Tensor len_buffer;
    184           OP_REQUIRES_OK(context,
    185                          context->allocate_temp(DataTypeToEnum<T>::v(),
    186                                                 output_shape, &len_buffer));
    187           typename TTypes<T, 2>::Tensor len_bcast(
    188               len_buffer.shaped<T, 2>(bcast.result_shape()));
    189           len_bcast =
    190               len_shaped.broadcast(BCast::ToIndexArray<2>(bcast.y_bcast()));
    191 
    192           // Iterate through broadcasted tensors and perform substr
    193           for (int i = 0; i < output_shape.dim_size(0); ++i) {
    194             for (int j = 0; j < output_shape.dim_size(1); ++j) {
    195               string in = input_bcast(i, j);
    196               const T pos =
    197                   tensorflow::internal::SubtleMustCopy(pos_bcast(i, j));
    198               const T len =
    199                   tensorflow::internal::SubtleMustCopy(len_bcast(i, j));
    200               OP_REQUIRES(context, FastBoundsCheck(pos, in.size() + 1),
    201                           errors::InvalidArgument(
    202                               "pos ", pos, " out of range for ", "string b'",
    203                               in, "' at index (", i, ", ", j, ")"));
    204               output(i, j) = in.substr(pos, len);
    205             }
    206           }
    207           break;
    208         }
    209         default: {
    210           context->SetStatus(errors::Unimplemented(
    211               "Substr broadcast not implemented for ", ndims, " dimensions"));
    212         }
    213       }
    214     }
    215   }
    216 };
    217 
    218 #define REGISTER_SUBSTR(type)                                      \
    219   REGISTER_KERNEL_BUILDER(                                         \
    220       Name("Substr").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
    221       SubstrOp<type>);
    222 REGISTER_SUBSTR(int32);
    223 REGISTER_SUBSTR(int64);
    224 }  // namespace tensorflow
    225