1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/sdca_ops.cc. 17 18 #define EIGEN_USE_THREADS 19 20 #include <stdint.h> 21 #include <atomic> 22 #include <limits> 23 #include <memory> 24 #include <new> 25 #include <string> 26 #include <vector> 27 28 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 29 #include "tensorflow/core/framework/device_base.h" 30 #include "tensorflow/core/framework/kernel_def_builder.h" 31 #include "tensorflow/core/framework/op.h" 32 #include "tensorflow/core/framework/op_def_builder.h" 33 #include "tensorflow/core/framework/op_kernel.h" 34 #include "tensorflow/core/framework/tensor.h" 35 #include "tensorflow/core/framework/tensor_shape.h" 36 #include "tensorflow/core/framework/tensor_types.h" 37 #include "tensorflow/core/framework/types.h" 38 #include "tensorflow/core/kernels/hinge-loss.h" 39 #include "tensorflow/core/kernels/logistic-loss.h" 40 #include "tensorflow/core/kernels/loss.h" 41 #include "tensorflow/core/kernels/sdca_internal.h" 42 #include "tensorflow/core/kernels/smooth-hinge-loss.h" 43 #include "tensorflow/core/kernels/squared-loss.h" 44 #include "tensorflow/core/lib/core/coding.h" 45 #include "tensorflow/core/lib/core/errors.h" 46 #include "tensorflow/core/lib/core/status.h" 47 #include "tensorflow/core/lib/core/stringpiece.h" 48 #include "tensorflow/core/lib/gtl/inlined_vector.h" 49 #include "tensorflow/core/lib/strings/stringprintf.h" 50 #include "tensorflow/core/platform/fingerprint.h" 51 #include "tensorflow/core/platform/macros.h" 52 #include "tensorflow/core/platform/mutex.h" 53 #include "tensorflow/core/platform/types.h" 54 #include "tensorflow/core/util/work_sharder.h" 55 56 namespace tensorflow { 57 58 namespace { 59 60 using sdca::Example; 61 using sdca::Examples; 62 using sdca::ExampleStatistics; 63 using sdca::ModelWeights; 64 using sdca::Regularizations; 65 66 struct ComputeOptions { 67 explicit ComputeOptions(OpKernelConstruction* const context) { 68 string loss_type; 69 OP_REQUIRES_OK(context, context->GetAttr("loss_type", &loss_type)); 70 if (loss_type == "logistic_loss") { 71 loss_updater.reset(new LogisticLossUpdater); 72 } else if (loss_type == "squared_loss") { 73 loss_updater.reset(new SquaredLossUpdater); 74 } else if (loss_type == "hinge_loss") { 75 loss_updater.reset(new HingeLossUpdater); 76 } else if (loss_type == "smooth_hinge_loss") { 77 loss_updater.reset(new SmoothHingeLossUpdater); 78 } else { 79 OP_REQUIRES( 80 context, false, 81 errors::InvalidArgument("Unsupported loss type: ", loss_type)); 82 } 83 OP_REQUIRES_OK(context, context->GetAttr("adaptative", &adaptative)); 84 OP_REQUIRES_OK( 85 context, context->GetAttr("num_sparse_features", &num_sparse_features)); 86 OP_REQUIRES_OK(context, context->GetAttr("num_sparse_features_with_values", 87 &num_sparse_features_with_values)); 88 OP_REQUIRES_OK(context, 89 context->GetAttr("num_dense_features", &num_dense_features)); 90 OP_REQUIRES( 91 context, num_sparse_features + num_dense_features > 0, 92 errors::InvalidArgument("Requires at least one feature to train.")); 93 94 OP_REQUIRES(context, 95 static_cast<int64>(num_sparse_features) + 96 static_cast<int64>(num_dense_features) <= 97 std::numeric_limits<int>::max(), 98 errors::InvalidArgument( 99 strings::Printf("Too many feature groups: %lld > %d", 100 static_cast<int64>(num_sparse_features) + 101 static_cast<int64>(num_dense_features), 102 std::numeric_limits<int>::max()))); 103 OP_REQUIRES_OK( 104 context, context->GetAttr("num_loss_partitions", &num_loss_partitions)); 105 OP_REQUIRES_OK(context, context->GetAttr("num_inner_iterations", 106 &num_inner_iterations)); 107 OP_REQUIRES_OK(context, regularizations.Initialize(context)); 108 } 109 110 std::unique_ptr<DualLossUpdater> loss_updater; 111 int num_sparse_features = 0; 112 int num_sparse_features_with_values = 0; 113 int num_dense_features = 0; 114 int num_inner_iterations = 0; 115 int num_loss_partitions = 0; 116 bool adaptative = false; 117 Regularizations regularizations; 118 }; 119 120 // TODO(shengx): The helper classes/methods are changed to support multiclass 121 // SDCA, which lead to changes within this function. Need to revisit the 122 // convergence once the multiclass SDCA is in. 123 void DoCompute(const ComputeOptions& options, OpKernelContext* const context) { 124 ModelWeights model_weights; 125 OP_REQUIRES_OK(context, model_weights.Initialize(context)); 126 127 Examples examples; 128 OP_REQUIRES_OK( 129 context, 130 examples.Initialize(context, model_weights, options.num_sparse_features, 131 options.num_sparse_features_with_values, 132 options.num_dense_features)); 133 134 const Tensor* example_state_data_t; 135 OP_REQUIRES_OK(context, 136 context->input("example_state_data", &example_state_data_t)); 137 TensorShape expected_example_state_shape({examples.num_examples(), 4}); 138 OP_REQUIRES(context, 139 example_state_data_t->shape() == expected_example_state_shape, 140 errors::InvalidArgument( 141 "Expected shape ", expected_example_state_shape.DebugString(), 142 " for example_state_data, got ", 143 example_state_data_t->shape().DebugString())); 144 145 Tensor mutable_example_state_data_t(*example_state_data_t); 146 auto example_state_data = mutable_example_state_data_t.matrix<float>(); 147 OP_REQUIRES_OK(context, context->set_output("out_example_state_data", 148 mutable_example_state_data_t)); 149 150 if (options.adaptative) { 151 OP_REQUIRES_OK(context, 152 examples.SampleAdaptativeProbabilities( 153 options.num_loss_partitions, options.regularizations, 154 model_weights, example_state_data, options.loss_updater, 155 /*num_weight_vectors =*/1)); 156 } 157 158 mutex mu; 159 Status train_step_status GUARDED_BY(mu); 160 std::atomic<std::int64_t> atomic_index(-1); 161 auto train_step = [&](const int64 begin, const int64 end) { 162 // The static_cast here is safe since begin and end can be at most 163 // num_examples which is an int. 164 for (int id = static_cast<int>(begin); id < end; ++id) { 165 const int64 example_index = 166 examples.sampled_index(++atomic_index, options.adaptative); 167 const Example& example = examples.example(example_index); 168 const float dual = example_state_data(example_index, 0); 169 const float example_weight = example.example_weight(); 170 float example_label = example.example_label(); 171 const Status conversion_status = 172 options.loss_updater->ConvertLabel(&example_label); 173 if (!conversion_status.ok()) { 174 mutex_lock l(mu); 175 train_step_status = conversion_status; 176 // Return from this worker thread - the calling thread is 177 // responsible for checking context status and returning on error. 178 return; 179 } 180 181 // Compute wx, example norm weighted by regularization, dual loss, 182 // primal loss. 183 // For binary SDCA, num_weight_vectors should be one. 184 const ExampleStatistics example_statistics = 185 example.ComputeWxAndWeightedExampleNorm( 186 options.num_loss_partitions, model_weights, 187 options.regularizations, 1 /* num_weight_vectors */); 188 189 const double new_dual = options.loss_updater->ComputeUpdatedDual( 190 options.num_loss_partitions, example_label, example_weight, dual, 191 example_statistics.wx[0], example_statistics.normalized_squared_norm); 192 193 // Compute new weights. 194 const double normalized_bounded_dual_delta = 195 (new_dual - dual) * example_weight / 196 options.regularizations.symmetric_l2(); 197 model_weights.UpdateDeltaWeights( 198 context->eigen_cpu_device(), example, 199 std::vector<double>{normalized_bounded_dual_delta}); 200 201 // Update example data. 202 example_state_data(example_index, 0) = new_dual; 203 example_state_data(example_index, 1) = 204 options.loss_updater->ComputePrimalLoss( 205 example_statistics.prev_wx[0], example_label, example_weight); 206 example_state_data(example_index, 2) = 207 options.loss_updater->ComputeDualLoss(dual, example_label, 208 example_weight); 209 example_state_data(example_index, 3) = example_weight; 210 } 211 }; 212 // TODO(sibyl-Aix6ihai): Tune this properly based on sparsity of the data, 213 // number of cpus, and cost per example. 214 const int64 kCostPerUnit = examples.num_features(); 215 const DeviceBase::CpuWorkerThreads& worker_threads = 216 *context->device()->tensorflow_cpu_worker_threads(); 217 218 Shard(worker_threads.num_threads, worker_threads.workers, 219 examples.num_examples(), kCostPerUnit, train_step); 220 OP_REQUIRES_OK(context, train_step_status); 221 } 222 223 } // namespace 224 225 class SdcaOptimizer : public OpKernel { 226 public: 227 explicit SdcaOptimizer(OpKernelConstruction* const context) 228 : OpKernel(context), options_(context) {} 229 230 void Compute(OpKernelContext* context) override { 231 DoCompute(options_, context); 232 } 233 234 private: 235 // TODO(sibyl-Aix6ihai): We could use the type-constraint on loss_type, and 236 // template the entire class to avoid the virtual table lookup penalty in 237 // the inner loop. 238 ComputeOptions options_; 239 }; 240 REGISTER_KERNEL_BUILDER(Name("SdcaOptimizer").Device(DEVICE_CPU), 241 SdcaOptimizer); 242 243 class SdcaShrinkL1 : public OpKernel { 244 public: 245 explicit SdcaShrinkL1(OpKernelConstruction* const context) 246 : OpKernel(context) { 247 OP_REQUIRES_OK(context, regularizations_.Initialize(context)); 248 } 249 250 void Compute(OpKernelContext* context) override { 251 OpMutableInputList weights_inputs; 252 OP_REQUIRES_OK(context, 253 context->mutable_input_list("weights", &weights_inputs)); 254 255 auto do_work = [&](const int64 begin, const int64 end) { 256 for (int i = begin; i < end; ++i) { 257 auto prox_w = weights_inputs.at(i, /*lock_held=*/true).flat<float>(); 258 prox_w.device(context->eigen_cpu_device()) = 259 regularizations_.EigenShrinkVector(prox_w); 260 } 261 }; 262 263 if (weights_inputs.size() > 0) { 264 int64 num_weights = 0; 265 for (int i = 0; i < weights_inputs.size(); ++i) { 266 num_weights += weights_inputs.at(i, /*lock_held=*/true).NumElements(); 267 } 268 // TODO(sibyl-Aix6ihai): Tune this value. 269 const int64 kCostPerUnit = (num_weights * 50) / weights_inputs.size(); 270 const DeviceBase::CpuWorkerThreads& worker_threads = 271 *context->device()->tensorflow_cpu_worker_threads(); 272 Shard(worker_threads.num_threads, worker_threads.workers, 273 weights_inputs.size(), kCostPerUnit, do_work); 274 } 275 } 276 277 private: 278 Regularizations regularizations_; 279 }; 280 REGISTER_KERNEL_BUILDER(Name("SdcaShrinkL1").Device(DEVICE_CPU), SdcaShrinkL1); 281 282 // Computes platform independent, compact and unique (with very high 283 // probability) representation of an example id. It shouldn't be put in 284 // persistent storage, as its implementation may change in the future. 285 // 286 // The current probability of at least one collision for 1B example_ids is 287 // approximately 10^-21 (ie 2^60 / 2^129). 288 class SdcaFprint : public OpKernel { 289 public: 290 explicit SdcaFprint(OpKernelConstruction* const context) 291 : OpKernel(context) {} 292 293 void Compute(OpKernelContext* context) override { 294 const Tensor& input = context->input(0); 295 OP_REQUIRES(context, TensorShapeUtils::IsVector(input.shape()), 296 errors::InvalidArgument("Input must be a vector, got shape ", 297 input.shape().DebugString())); 298 Tensor* out; 299 const int64 num_elements = input.NumElements(); 300 OP_REQUIRES_OK(context, context->allocate_output( 301 0, TensorShape({num_elements, 2}), &out)); 302 303 const auto in_values = input.flat<string>(); 304 auto out_values = out->matrix<int64>(); 305 306 for (int64 i = 0; i < num_elements; ++i) { 307 const Fprint128 fprint = Fingerprint128(in_values(i)); 308 // Never return 0 or 1 as the first value of the hash to allow these to 309 // safely be used as sentinel values (e.g. dense hash table empty key). 310 out_values(i, 0) = TF_PREDICT_TRUE(fprint.low64 >= 2) 311 ? fprint.low64 312 : fprint.low64 + ~static_cast<uint64>(1); 313 out_values(i, 1) = fprint.high64; 314 } 315 } 316 }; 317 REGISTER_KERNEL_BUILDER(Name("SdcaFprint").Device(DEVICE_CPU), SdcaFprint); 318 319 } // namespace tensorflow 320