Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/sdca_ops.cc.
     17 
     18 #define EIGEN_USE_THREADS
     19 
     20 #include <stdint.h>
     21 #include <atomic>
     22 #include <limits>
     23 #include <memory>
     24 #include <new>
     25 #include <string>
     26 #include <vector>
     27 
     28 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     29 #include "tensorflow/core/framework/device_base.h"
     30 #include "tensorflow/core/framework/kernel_def_builder.h"
     31 #include "tensorflow/core/framework/op.h"
     32 #include "tensorflow/core/framework/op_def_builder.h"
     33 #include "tensorflow/core/framework/op_kernel.h"
     34 #include "tensorflow/core/framework/tensor.h"
     35 #include "tensorflow/core/framework/tensor_shape.h"
     36 #include "tensorflow/core/framework/tensor_types.h"
     37 #include "tensorflow/core/framework/types.h"
     38 #include "tensorflow/core/kernels/hinge-loss.h"
     39 #include "tensorflow/core/kernels/logistic-loss.h"
     40 #include "tensorflow/core/kernels/loss.h"
     41 #include "tensorflow/core/kernels/sdca_internal.h"
     42 #include "tensorflow/core/kernels/smooth-hinge-loss.h"
     43 #include "tensorflow/core/kernels/squared-loss.h"
     44 #include "tensorflow/core/lib/core/coding.h"
     45 #include "tensorflow/core/lib/core/errors.h"
     46 #include "tensorflow/core/lib/core/status.h"
     47 #include "tensorflow/core/lib/core/stringpiece.h"
     48 #include "tensorflow/core/lib/gtl/inlined_vector.h"
     49 #include "tensorflow/core/lib/strings/stringprintf.h"
     50 #include "tensorflow/core/platform/fingerprint.h"
     51 #include "tensorflow/core/platform/macros.h"
     52 #include "tensorflow/core/platform/mutex.h"
     53 #include "tensorflow/core/platform/types.h"
     54 #include "tensorflow/core/util/work_sharder.h"
     55 
     56 namespace tensorflow {
     57 
     58 namespace {
     59 
     60 using sdca::Example;
     61 using sdca::Examples;
     62 using sdca::ExampleStatistics;
     63 using sdca::ModelWeights;
     64 using sdca::Regularizations;
     65 
     66 struct ComputeOptions {
     67   explicit ComputeOptions(OpKernelConstruction* const context) {
     68     string loss_type;
     69     OP_REQUIRES_OK(context, context->GetAttr("loss_type", &loss_type));
     70     if (loss_type == "logistic_loss") {
     71       loss_updater.reset(new LogisticLossUpdater);
     72     } else if (loss_type == "squared_loss") {
     73       loss_updater.reset(new SquaredLossUpdater);
     74     } else if (loss_type == "hinge_loss") {
     75       loss_updater.reset(new HingeLossUpdater);
     76     } else if (loss_type == "smooth_hinge_loss") {
     77       loss_updater.reset(new SmoothHingeLossUpdater);
     78     } else {
     79       OP_REQUIRES(
     80           context, false,
     81           errors::InvalidArgument("Unsupported loss type: ", loss_type));
     82     }
     83     OP_REQUIRES_OK(context, context->GetAttr("adaptative", &adaptative));
     84     OP_REQUIRES_OK(
     85         context, context->GetAttr("num_sparse_features", &num_sparse_features));
     86     OP_REQUIRES_OK(context, context->GetAttr("num_sparse_features_with_values",
     87                                              &num_sparse_features_with_values));
     88     OP_REQUIRES_OK(context,
     89                    context->GetAttr("num_dense_features", &num_dense_features));
     90     OP_REQUIRES(
     91         context, num_sparse_features + num_dense_features > 0,
     92         errors::InvalidArgument("Requires at least one feature to train."));
     93 
     94     OP_REQUIRES(context,
     95                 static_cast<int64>(num_sparse_features) +
     96                         static_cast<int64>(num_dense_features) <=
     97                     std::numeric_limits<int>::max(),
     98                 errors::InvalidArgument(
     99                     strings::Printf("Too many feature groups: %lld > %d",
    100                                     static_cast<int64>(num_sparse_features) +
    101                                         static_cast<int64>(num_dense_features),
    102                                     std::numeric_limits<int>::max())));
    103     OP_REQUIRES_OK(
    104         context, context->GetAttr("num_loss_partitions", &num_loss_partitions));
    105     OP_REQUIRES_OK(context, context->GetAttr("num_inner_iterations",
    106                                              &num_inner_iterations));
    107     OP_REQUIRES_OK(context, regularizations.Initialize(context));
    108   }
    109 
    110   std::unique_ptr<DualLossUpdater> loss_updater;
    111   int num_sparse_features = 0;
    112   int num_sparse_features_with_values = 0;
    113   int num_dense_features = 0;
    114   int num_inner_iterations = 0;
    115   int num_loss_partitions = 0;
    116   bool adaptative = false;
    117   Regularizations regularizations;
    118 };
    119 
    120 // TODO(shengx): The helper classes/methods are changed to support multiclass
    121 // SDCA, which lead to changes within this function. Need to revisit the
    122 // convergence once the multiclass SDCA is in.
    123 void DoCompute(const ComputeOptions& options, OpKernelContext* const context) {
    124   ModelWeights model_weights;
    125   OP_REQUIRES_OK(context, model_weights.Initialize(context));
    126 
    127   Examples examples;
    128   OP_REQUIRES_OK(
    129       context,
    130       examples.Initialize(context, model_weights, options.num_sparse_features,
    131                           options.num_sparse_features_with_values,
    132                           options.num_dense_features));
    133 
    134   const Tensor* example_state_data_t;
    135   OP_REQUIRES_OK(context,
    136                  context->input("example_state_data", &example_state_data_t));
    137   TensorShape expected_example_state_shape({examples.num_examples(), 4});
    138   OP_REQUIRES(context,
    139               example_state_data_t->shape() == expected_example_state_shape,
    140               errors::InvalidArgument(
    141                   "Expected shape ", expected_example_state_shape.DebugString(),
    142                   " for example_state_data, got ",
    143                   example_state_data_t->shape().DebugString()));
    144 
    145   Tensor mutable_example_state_data_t(*example_state_data_t);
    146   auto example_state_data = mutable_example_state_data_t.matrix<float>();
    147   OP_REQUIRES_OK(context, context->set_output("out_example_state_data",
    148                                               mutable_example_state_data_t));
    149 
    150   if (options.adaptative) {
    151     OP_REQUIRES_OK(context,
    152                    examples.SampleAdaptativeProbabilities(
    153                        options.num_loss_partitions, options.regularizations,
    154                        model_weights, example_state_data, options.loss_updater,
    155                        /*num_weight_vectors =*/1));
    156   }
    157 
    158   mutex mu;
    159   Status train_step_status GUARDED_BY(mu);
    160   std::atomic<std::int64_t> atomic_index(-1);
    161   auto train_step = [&](const int64 begin, const int64 end) {
    162     // The static_cast here is safe since begin and end can be at most
    163     // num_examples which is an int.
    164     for (int id = static_cast<int>(begin); id < end; ++id) {
    165       const int64 example_index =
    166           examples.sampled_index(++atomic_index, options.adaptative);
    167       const Example& example = examples.example(example_index);
    168       const float dual = example_state_data(example_index, 0);
    169       const float example_weight = example.example_weight();
    170       float example_label = example.example_label();
    171       const Status conversion_status =
    172           options.loss_updater->ConvertLabel(&example_label);
    173       if (!conversion_status.ok()) {
    174         mutex_lock l(mu);
    175         train_step_status = conversion_status;
    176         // Return from this worker thread - the calling thread is
    177         // responsible for checking context status and returning on error.
    178         return;
    179       }
    180 
    181       // Compute wx, example norm weighted by regularization, dual loss,
    182       // primal loss.
    183       // For binary SDCA, num_weight_vectors should be one.
    184       const ExampleStatistics example_statistics =
    185           example.ComputeWxAndWeightedExampleNorm(
    186               options.num_loss_partitions, model_weights,
    187               options.regularizations, 1 /* num_weight_vectors */);
    188 
    189       const double new_dual = options.loss_updater->ComputeUpdatedDual(
    190           options.num_loss_partitions, example_label, example_weight, dual,
    191           example_statistics.wx[0], example_statistics.normalized_squared_norm);
    192 
    193       // Compute new weights.
    194       const double normalized_bounded_dual_delta =
    195           (new_dual - dual) * example_weight /
    196           options.regularizations.symmetric_l2();
    197       model_weights.UpdateDeltaWeights(
    198           context->eigen_cpu_device(), example,
    199           std::vector<double>{normalized_bounded_dual_delta});
    200 
    201       // Update example data.
    202       example_state_data(example_index, 0) = new_dual;
    203       example_state_data(example_index, 1) =
    204           options.loss_updater->ComputePrimalLoss(
    205               example_statistics.prev_wx[0], example_label, example_weight);
    206       example_state_data(example_index, 2) =
    207           options.loss_updater->ComputeDualLoss(dual, example_label,
    208                                                 example_weight);
    209       example_state_data(example_index, 3) = example_weight;
    210     }
    211   };
    212   // TODO(sibyl-Aix6ihai): Tune this properly based on sparsity of the data,
    213   // number of cpus, and cost per example.
    214   const int64 kCostPerUnit = examples.num_features();
    215   const DeviceBase::CpuWorkerThreads& worker_threads =
    216       *context->device()->tensorflow_cpu_worker_threads();
    217 
    218   Shard(worker_threads.num_threads, worker_threads.workers,
    219         examples.num_examples(), kCostPerUnit, train_step);
    220   OP_REQUIRES_OK(context, train_step_status);
    221 }
    222 
    223 }  // namespace
    224 
    225 class SdcaOptimizer : public OpKernel {
    226  public:
    227   explicit SdcaOptimizer(OpKernelConstruction* const context)
    228       : OpKernel(context), options_(context) {}
    229 
    230   void Compute(OpKernelContext* context) override {
    231     DoCompute(options_, context);
    232   }
    233 
    234  private:
    235   // TODO(sibyl-Aix6ihai): We could use the type-constraint on loss_type, and
    236   // template the entire class to avoid the virtual table lookup penalty in
    237   // the inner loop.
    238   ComputeOptions options_;
    239 };
    240 REGISTER_KERNEL_BUILDER(Name("SdcaOptimizer").Device(DEVICE_CPU),
    241                         SdcaOptimizer);
    242 
    243 class SdcaShrinkL1 : public OpKernel {
    244  public:
    245   explicit SdcaShrinkL1(OpKernelConstruction* const context)
    246       : OpKernel(context) {
    247     OP_REQUIRES_OK(context, regularizations_.Initialize(context));
    248   }
    249 
    250   void Compute(OpKernelContext* context) override {
    251     OpMutableInputList weights_inputs;
    252     OP_REQUIRES_OK(context,
    253                    context->mutable_input_list("weights", &weights_inputs));
    254 
    255     auto do_work = [&](const int64 begin, const int64 end) {
    256       for (int i = begin; i < end; ++i) {
    257         auto prox_w = weights_inputs.at(i, /*lock_held=*/true).flat<float>();
    258         prox_w.device(context->eigen_cpu_device()) =
    259             regularizations_.EigenShrinkVector(prox_w);
    260       }
    261     };
    262 
    263     if (weights_inputs.size() > 0) {
    264       int64 num_weights = 0;
    265       for (int i = 0; i < weights_inputs.size(); ++i) {
    266         num_weights += weights_inputs.at(i, /*lock_held=*/true).NumElements();
    267       }
    268       // TODO(sibyl-Aix6ihai): Tune this value.
    269       const int64 kCostPerUnit = (num_weights * 50) / weights_inputs.size();
    270       const DeviceBase::CpuWorkerThreads& worker_threads =
    271           *context->device()->tensorflow_cpu_worker_threads();
    272       Shard(worker_threads.num_threads, worker_threads.workers,
    273             weights_inputs.size(), kCostPerUnit, do_work);
    274     }
    275   }
    276 
    277  private:
    278   Regularizations regularizations_;
    279 };
    280 REGISTER_KERNEL_BUILDER(Name("SdcaShrinkL1").Device(DEVICE_CPU), SdcaShrinkL1);
    281 
    282 // Computes platform independent, compact and unique (with very high
    283 // probability) representation of an example id. It shouldn't be put in
    284 // persistent storage, as its implementation may change in the future.
    285 //
    286 // The current probability of at least one collision for 1B example_ids is
    287 // approximately 10^-21 (ie 2^60 / 2^129).
    288 class SdcaFprint : public OpKernel {
    289  public:
    290   explicit SdcaFprint(OpKernelConstruction* const context)
    291       : OpKernel(context) {}
    292 
    293   void Compute(OpKernelContext* context) override {
    294     const Tensor& input = context->input(0);
    295     OP_REQUIRES(context, TensorShapeUtils::IsVector(input.shape()),
    296                 errors::InvalidArgument("Input must be a vector, got shape ",
    297                                         input.shape().DebugString()));
    298     Tensor* out;
    299     const int64 num_elements = input.NumElements();
    300     OP_REQUIRES_OK(context, context->allocate_output(
    301                                 0, TensorShape({num_elements, 2}), &out));
    302 
    303     const auto in_values = input.flat<string>();
    304     auto out_values = out->matrix<int64>();
    305 
    306     for (int64 i = 0; i < num_elements; ++i) {
    307       const Fprint128 fprint = Fingerprint128(in_values(i));
    308       // Never return 0 or 1 as the first value of the hash to allow these to
    309       // safely be used as sentinel values (e.g. dense hash table empty key).
    310       out_values(i, 0) = TF_PREDICT_TRUE(fprint.low64 >= 2)
    311                              ? fprint.low64
    312                              : fprint.low64 + ~static_cast<uint64>(1);
    313       out_values(i, 1) = fprint.high64;
    314     }
    315   }
    316 };
    317 REGISTER_KERNEL_BUILDER(Name("SdcaFprint").Device(DEVICE_CPU), SdcaFprint);
    318 
    319 }  // namespace tensorflow
    320