Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/core/framework/op_kernel.h"
     17 #include "tensorflow/core/framework/tensor.h"
     18 #include "tensorflow/core/framework/tensor_shape.h"
     19 #include "tensorflow/core/framework/types.h"
     20 #include "tensorflow/core/lib/core/errors.h"
     21 #include "tensorflow/core/lib/strings/numbers.h"
     22 #include "tensorflow/core/lib/strings/str_util.h"
     23 
     24 namespace tensorflow {
     25 
     26 template <typename T, typename Tlabel>
     27 class DecodeLibsvmOp : public OpKernel {
     28  public:
     29   explicit DecodeLibsvmOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
     30     OP_REQUIRES_OK(ctx, ctx->GetAttr("num_features", &num_features_));
     31     OP_REQUIRES(ctx, (num_features_ >= 1),
     32                 errors::InvalidArgument("Invalid number of features \"",
     33                                         num_features_, "\""));
     34   }
     35 
     36   void Compute(OpKernelContext* ctx) override {
     37     const Tensor* input_tensor;
     38     OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
     39     const auto& input_flat = input_tensor->flat<string>();
     40 
     41     Tensor* label_tensor;
     42     OP_REQUIRES_OK(
     43         ctx, ctx->allocate_output(0, input_tensor->shape(), &label_tensor));
     44     auto label = label_tensor->flat<Tlabel>();
     45 
     46     std::vector<T> out_values;
     47     std::vector<std::pair<int64, int64>> out_indices;
     48     for (int i = 0; i < input_flat.size(); ++i) {
     49       StringPiece line(input_flat(i));
     50       str_util::RemoveWhitespaceContext(&line);
     51 
     52       StringPiece piece;
     53       OP_REQUIRES(ctx, str_util::ConsumeNonWhitespace(&line, &piece),
     54                   errors::InvalidArgument("No label found for input[", i,
     55                                           "]: \"", input_flat(i), "\""));
     56 
     57       Tlabel label_value;
     58       OP_REQUIRES(ctx,
     59                   strings::SafeStringToNumeric<Tlabel>(piece, &label_value),
     60                   errors::InvalidArgument("Label format incorrect: ", piece));
     61 
     62       label(i) = label_value;
     63 
     64       str_util::RemoveLeadingWhitespace(&line);
     65       while (str_util::ConsumeNonWhitespace(&line, &piece)) {
     66         size_t p = piece.find(':');
     67         OP_REQUIRES(ctx, (p != StringPiece::npos),
     68                     errors::InvalidArgument("Invalid feature \"", piece, "\""));
     69 
     70         int64 feature_index;
     71         OP_REQUIRES(
     72             ctx, strings::safe_strto64(piece.substr(0, p), &feature_index),
     73             errors::InvalidArgument("Feature format incorrect: ", piece));
     74         OP_REQUIRES(ctx, (feature_index >= 0),
     75                     errors::InvalidArgument(
     76                         "Feature index should be >= 0, got ", feature_index));
     77 
     78         T feature_value;
     79         OP_REQUIRES(
     80 
     81             ctx,
     82             strings::SafeStringToNumeric<T>(piece.substr(p + 1),
     83                                             &feature_value),
     84             errors::InvalidArgument("Feature format incorrect: ", piece));
     85 
     86         out_values.emplace_back(feature_value);
     87         out_indices.emplace_back(std::pair<int64, int64>(i, feature_index));
     88 
     89         str_util::RemoveLeadingWhitespace(&line);
     90       }
     91     }
     92 
     93     Tensor* indices_tensor;
     94     OP_REQUIRES_OK(ctx, ctx->allocate_output(
     95                             1,
     96                             TensorShape({static_cast<int64>(out_indices.size()),
     97                                          input_tensor->shape().dims() + 1}),
     98                             &indices_tensor));
     99     auto indices = indices_tensor->matrix<int64>();
    100     // Translate flat index to shaped index like np.unravel_index
    101     // Calculate factors for each dimension
    102     std::vector<int64> factors(input_tensor->shape().dims());
    103     factors[input_tensor->shape().dims() - 1] = 1;
    104     for (int j = input_tensor->shape().dims() - 2; j >= 0; j--) {
    105       factors[j] = factors[j + 1] * input_tensor->shape().dim_size(j + 1);
    106     }
    107     for (int i = 0; i < out_indices.size(); i++) {
    108       indices(i, 0) = out_indices[i].first;
    109       int64 value = out_indices[i].first;
    110       for (int j = 0; j < input_tensor->shape().dims(); j++) {
    111         indices(i, j) = value / factors[j];
    112         value = value % factors[j];
    113       }
    114       indices(i, input_tensor->shape().dims()) = out_indices[i].second;
    115     }
    116 
    117     Tensor* values_tensor;
    118     OP_REQUIRES_OK(ctx,
    119                    ctx->allocate_output(
    120                        2, TensorShape({static_cast<int64>(out_values.size())}),
    121                        &values_tensor));
    122     auto values = values_tensor->vec<T>();
    123     std::copy_n(out_values.begin(), out_values.size(), &values(0));
    124 
    125     Tensor* shape_tensor;
    126     OP_REQUIRES_OK(ctx, ctx->allocate_output(
    127                             3, TensorShape({input_tensor->shape().dims() + 1}),
    128                             &shape_tensor));
    129     auto shape = shape_tensor->flat<int64>();
    130     for (int i = 0; i < input_tensor->shape().dims(); i++) {
    131       shape(i) = input_tensor->shape().dim_size(i);
    132     }
    133     shape(input_tensor->shape().dims()) = num_features_;
    134   }
    135 
    136  private:
    137   int64 num_features_;
    138 };
    139 
    140 #define REGISTER_KERNEL(type)                                         \
    141   REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm")                        \
    142                               .Device(DEVICE_CPU)                     \
    143                               .TypeConstraint<type>("dtype")          \
    144                               .TypeConstraint<int32>("label_dtype"),  \
    145                           DecodeLibsvmOp<type, int32>);               \
    146   REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm")                        \
    147                               .Device(DEVICE_CPU)                     \
    148                               .TypeConstraint<type>("dtype")          \
    149                               .TypeConstraint<int64>("label_dtype"),  \
    150                           DecodeLibsvmOp<type, int64>);               \
    151   REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm")                        \
    152                               .Device(DEVICE_CPU)                     \
    153                               .TypeConstraint<type>("dtype")          \
    154                               .TypeConstraint<float>("label_dtype"),  \
    155                           DecodeLibsvmOp<type, float>);               \
    156   REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm")                        \
    157                               .Device(DEVICE_CPU)                     \
    158                               .TypeConstraint<type>("dtype")          \
    159                               .TypeConstraint<double>("label_dtype"), \
    160                           DecodeLibsvmOp<type, double>);
    161 
    162 REGISTER_KERNEL(float);
    163 REGISTER_KERNEL(double);
    164 REGISTER_KERNEL(int32);
    165 REGISTER_KERNEL(int64);
    166 #undef REGISTER_KERNEL
    167 
    168 }  // namespace tensorflow
    169