Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #define EIGEN_USE_THREADS
     17 
     18 #include "tensorflow/core/kernels/sdca_internal.h"
     19 
#include <algorithm>
#include <cmath>
#include <limits>
#include <random>
     22 
     23 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     24 #include "tensorflow/core/lib/math/math_util.h"
     25 #include "tensorflow/core/lib/random/simple_philox.h"
     26 
     27 namespace tensorflow {
     28 namespace sdca {
     29 
     30 using UnalignedFloatVector = TTypes<const float>::UnalignedConstVec;
     31 using UnalignedInt64Vector = TTypes<const int64>::UnalignedConstVec;
     32 
     33 void FeatureWeightsDenseStorage::UpdateDenseDeltaWeights(
     34     const Eigen::ThreadPoolDevice& device,
     35     const Example::DenseVector& dense_vector,
     36     const std::vector<double>& normalized_bounded_dual_delta) {
     37   const size_t num_weight_vectors = normalized_bounded_dual_delta.size();
     38   if (num_weight_vectors == 1) {
     39     deltas_.device(device) =
     40         deltas_ + dense_vector.RowAsMatrix() *
     41                       deltas_.constant(normalized_bounded_dual_delta[0]);
     42   } else {
     43     // Transform the dual vector into a column matrix.
     44     const Eigen::TensorMap<Eigen::Tensor<const double, 2, Eigen::RowMajor>>
     45         dual_matrix(normalized_bounded_dual_delta.data(), num_weight_vectors,
     46                     1);
     47     const Eigen::array<Eigen::IndexPair<int>, 1> product_dims = {
     48         Eigen::IndexPair<int>(1, 0)};
     49     // This computes delta_w += delta_vector / \lamdba * N.
     50     deltas_.device(device) =
     51         (deltas_.cast<double>() +
     52          dual_matrix.contract(dense_vector.RowAsMatrix().cast<double>(),
     53                               product_dims))
     54             .cast<float>();
     55   }
     56 }
     57 
     58 void FeatureWeightsSparseStorage::UpdateSparseDeltaWeights(
     59     const Eigen::ThreadPoolDevice& device,
     60     const Example::SparseFeatures& sparse_features,
     61     const std::vector<double>& normalized_bounded_dual_delta) {
     62   for (int64 k = 0; k < sparse_features.indices->size(); ++k) {
     63     const double feature_value =
     64         sparse_features.values == nullptr ? 1.0 : (*sparse_features.values)(k);
     65     auto it = indices_to_id_.find((*sparse_features.indices)(k));
     66     for (size_t l = 0; l < normalized_bounded_dual_delta.size(); ++l) {
     67       deltas_(l, it->second) +=
     68           feature_value * normalized_bounded_dual_delta[l];
     69     }
     70   }
     71 }
     72 
     73 void ModelWeights::UpdateDeltaWeights(
     74     const Eigen::ThreadPoolDevice& device, const Example& example,
     75     const std::vector<double>& normalized_bounded_dual_delta) {
     76   // Sparse weights.
     77   for (size_t j = 0; j < sparse_weights_.size(); ++j) {
     78     sparse_weights_[j].UpdateSparseDeltaWeights(
     79         device, example.sparse_features_[j], normalized_bounded_dual_delta);
     80   }
     81 
     82   // Dense weights.
     83   for (size_t j = 0; j < dense_weights_.size(); ++j) {
     84     dense_weights_[j].UpdateDenseDeltaWeights(
     85         device, *example.dense_vectors_[j], normalized_bounded_dual_delta);
     86   }
     87 }
     88 
     89 Status ModelWeights::Initialize(OpKernelContext* const context) {
     90   OpInputList sparse_indices_inputs;
     91   TF_RETURN_IF_ERROR(
     92       context->input_list("sparse_indices", &sparse_indices_inputs));
     93   OpInputList sparse_weights_inputs;
     94   TF_RETURN_IF_ERROR(
     95       context->input_list("sparse_weights", &sparse_weights_inputs));
     96   OpInputList dense_weights_inputs;
     97   TF_RETURN_IF_ERROR(
     98       context->input_list("dense_weights", &dense_weights_inputs));
     99 
    100   OpOutputList sparse_weights_outputs;
    101   TF_RETURN_IF_ERROR(context->output_list("out_delta_sparse_weights",
    102                                           &sparse_weights_outputs));
    103 
    104   OpOutputList dense_weights_outputs;
    105   TF_RETURN_IF_ERROR(
    106       context->output_list("out_delta_dense_weights", &dense_weights_outputs));
    107 
    108   for (int i = 0; i < sparse_weights_inputs.size(); ++i) {
    109     Tensor* delta_t;
    110     TF_RETURN_IF_ERROR(sparse_weights_outputs.allocate(
    111         i, sparse_weights_inputs[i].shape(), &delta_t));
    112     // Convert the input vector to a row matrix in internal representation.
    113     auto deltas = delta_t->shaped<float, 2>({1, delta_t->NumElements()});
    114     deltas.setZero();
    115     sparse_weights_.emplace_back(FeatureWeightsSparseStorage{
    116         sparse_indices_inputs[i].flat<int64>(),
    117         sparse_weights_inputs[i].shaped<float, 2>(
    118             {1, sparse_weights_inputs[i].NumElements()}),
    119         deltas});
    120   }
    121 
    122   // Reads in the weights, and allocates and initializes the delta weights.
    123   const auto initialize_weights =
    124       [&](const OpInputList& weight_inputs, OpOutputList* const weight_outputs,
    125           std::vector<FeatureWeightsDenseStorage>* const feature_weights) {
    126         for (int i = 0; i < weight_inputs.size(); ++i) {
    127           Tensor* delta_t;
    128           TF_RETURN_IF_ERROR(
    129               weight_outputs->allocate(i, weight_inputs[i].shape(), &delta_t));
    130           // Convert the input vector to a row matrix in internal
    131           // representation.
    132           auto deltas = delta_t->shaped<float, 2>({1, delta_t->NumElements()});
    133           deltas.setZero();
    134           feature_weights->emplace_back(FeatureWeightsDenseStorage{
    135               weight_inputs[i].shaped<float, 2>(
    136                   {1, weight_inputs[i].NumElements()}),
    137               deltas});
    138         }
    139         return Status::OK();
    140       };
    141 
    142   return initialize_weights(dense_weights_inputs, &dense_weights_outputs,
    143                             &dense_weights_);
    144 }
    145 
    146 // Computes the example statistics for given example, and model. Defined here
    147 // as we need definition of ModelWeights and Regularizations.
    148 const ExampleStatistics Example::ComputeWxAndWeightedExampleNorm(
    149     const int num_loss_partitions, const ModelWeights& model_weights,
    150     const Regularizations& regularization, const int num_weight_vectors) const {
    151   ExampleStatistics result(num_weight_vectors);
    152 
    153   result.normalized_squared_norm =
    154       squared_norm_ / regularization.symmetric_l2();
    155 
    156   // Compute w \dot x and prev_w \dot x.
    157   // This is for sparse features contribution to the logit.
    158   for (size_t j = 0; j < sparse_features_.size(); ++j) {
    159     const Example::SparseFeatures& sparse_features = sparse_features_[j];
    160     const FeatureWeightsSparseStorage& sparse_weights =
    161         model_weights.sparse_weights()[j];
    162 
    163     for (int64 k = 0; k < sparse_features.indices->size(); ++k) {
    164       const int64 feature_index = (*sparse_features.indices)(k);
    165       const double feature_value = sparse_features.values == nullptr
    166                                        ? 1.0
    167                                        : (*sparse_features.values)(k);
    168       for (int l = 0; l < num_weight_vectors; ++l) {
    169         const float sparse_weight = sparse_weights.nominals(l, feature_index);
    170         const double feature_weight =
    171             sparse_weight +
    172             sparse_weights.deltas(l, feature_index) * num_loss_partitions;
    173         result.prev_wx[l] +=
    174             feature_value * regularization.Shrink(sparse_weight);
    175         result.wx[l] += feature_value * regularization.Shrink(feature_weight);
    176       }
    177     }
    178   }
    179 
    180   // Compute w \dot x and prev_w \dot x.
    181   // This is for dense features contribution to the logit.
    182   for (size_t j = 0; j < dense_vectors_.size(); ++j) {
    183     const Example::DenseVector& dense_vector = *dense_vectors_[j];
    184     const FeatureWeightsDenseStorage& dense_weights =
    185         model_weights.dense_weights()[j];
    186 
    187     const Eigen::Tensor<float, 2, Eigen::RowMajor> feature_weights =
    188         dense_weights.nominals() +
    189         dense_weights.deltas() *
    190             dense_weights.deltas().constant(num_loss_partitions);
    191     if (num_weight_vectors == 1) {
    192       const Eigen::Tensor<float, 0, Eigen::RowMajor> prev_prediction =
    193           (dense_vector.Row() *
    194            regularization.EigenShrinkVector(
    195                Eigen::TensorMap<Eigen::Tensor<const float, 1, Eigen::RowMajor>>(
    196                    dense_weights.nominals().data(),
    197                    dense_weights.nominals().dimension(1))))
    198               .sum();
    199       const Eigen::Tensor<float, 0, Eigen::RowMajor> prediction =
    200           (dense_vector.Row() *
    201            regularization.EigenShrinkVector(
    202                Eigen::TensorMap<Eigen::Tensor<const float, 1, Eigen::RowMajor>>(
    203                    feature_weights.data(), feature_weights.dimension(1))))
    204               .sum();
    205       result.prev_wx[0] += prev_prediction();
    206       result.wx[0] += prediction();
    207     } else {
    208       const Eigen::array<Eigen::IndexPair<int>, 1> product_dims = {
    209           Eigen::IndexPair<int>(1, 1)};
    210       const Eigen::Tensor<float, 2, Eigen::RowMajor> prev_prediction =
    211           regularization.EigenShrinkMatrix(dense_weights.nominals())
    212               .contract(dense_vector.RowAsMatrix(), product_dims);
    213       const Eigen::Tensor<float, 2, Eigen::RowMajor> prediction =
    214           regularization.EigenShrinkMatrix(feature_weights)
    215               .contract(dense_vector.RowAsMatrix(), product_dims);
    216       // The result of "tensor contraction" (multiplication)  in the code
    217       // above is of dimension num_weight_vectors * 1.
    218       for (int l = 0; l < num_weight_vectors; ++l) {
    219         result.prev_wx[l] += prev_prediction(l, 0);
    220         result.wx[l] += prediction(l, 0);
    221       }
    222     }
    223   }
    224 
    225   return result;
    226 }
    227 
    228 // Examples contains all the training examples that SDCA uses for a mini-batch.
    229 Status Examples::SampleAdaptativeProbabilities(
    230     const int num_loss_partitions, const Regularizations& regularization,
    231     const ModelWeights& model_weights,
    232     const TTypes<float>::Matrix example_state_data,
    233     const std::unique_ptr<DualLossUpdater>& loss_updater,
    234     const int num_weight_vectors) {
    235   if (num_weight_vectors != 1) {
    236     return errors::InvalidArgument(
    237         "Adaptive SDCA only works with binary SDCA, "
    238         "where num_weight_vectors should be 1.");
    239   }
    240   // Compute the probabilities
    241   for (int example_id = 0; example_id < num_examples(); ++example_id) {
    242     const Example& example = examples_[example_id];
    243     const double example_weight = example.example_weight();
    244     float label = example.example_label();
    245     const Status conversion_status = loss_updater->ConvertLabel(&label);
    246     const ExampleStatistics example_statistics =
    247         example.ComputeWxAndWeightedExampleNorm(num_loss_partitions,
    248                                                 model_weights, regularization,
    249                                                 num_weight_vectors);
    250     const double kappa = example_state_data(example_id, 0) +
    251                          loss_updater->PrimalLossDerivative(
    252                              example_statistics.wx[0], label, example_weight);
    253     probabilities_[example_id] = example_weight *
    254                                  sqrt(examples_[example_id].squared_norm_ +
    255                                       regularization.symmetric_l2() *
    256                                           loss_updater->SmoothnessConstant()) *
    257                                  std::abs(kappa);
    258   }
    259 
    260   // Sample the index
    261   random::DistributionSampler sampler(probabilities_);
    262   GuardedPhiloxRandom generator;
    263   generator.Init(0, 0);
    264   auto local_gen = generator.ReserveSamples32(num_examples());
    265   random::SimplePhilox random(&local_gen);
    266   std::random_device rd;
    267   std::mt19937 gen(rd());
    268   std::uniform_real_distribution<> dis(0, 1);
    269 
    270   // We use a decay of 10: the probability of an example is divided by 10
    271   // once that example is picked. A good approximation of that is to only
    272   // keep a picked example with probability (1 / 10) ^ k where k is the
    273   // number of times we already picked that example. We add a num_retries
    274   // to avoid taking too long to sample. We then fill the sampled_index with
    275   // unseen examples sorted by probabilities.
    276   int id = 0;
    277   int num_retries = 0;
    278   while (id < num_examples() && num_retries < num_examples()) {
    279     int picked_id = sampler.Sample(&random);
    280     if (dis(gen) > MathUtil::IPow(0.1, sampled_count_[picked_id])) {
    281       num_retries++;
    282       continue;
    283     }
    284     sampled_count_[picked_id]++;
    285     sampled_index_[id++] = picked_id;
    286   }
    287 
    288   std::vector<std::pair<int, float>> examples_not_seen;
    289   examples_not_seen.reserve(num_examples());
    290   for (int i = 0; i < num_examples(); ++i) {
    291     if (sampled_count_[i] == 0)
    292       examples_not_seen.emplace_back(sampled_index_[i], probabilities_[i]);
    293   }
    294   std::sort(
    295       examples_not_seen.begin(), examples_not_seen.end(),
    296       [](const std::pair<int, float>& lhs, const std::pair<int, float>& rhs) {
    297         return lhs.second > rhs.second;
    298       });
    299   for (int i = id; i < num_examples(); ++i) {
    300     sampled_count_[i] = examples_not_seen[i - id].first;
    301   }
    302   return Status::OK();
    303 }
    304 
    305 // TODO(sibyl-Aix6ihai): Refactor/shorten this function.
    306 Status Examples::Initialize(OpKernelContext* const context,
    307                             const ModelWeights& weights,
    308                             const int num_sparse_features,
    309                             const int num_sparse_features_with_values,
    310                             const int num_dense_features) {
    311   num_features_ = num_sparse_features + num_dense_features;
    312 
    313   OpInputList sparse_example_indices_inputs;
    314   TF_RETURN_IF_ERROR(context->input_list("sparse_example_indices",
    315                                          &sparse_example_indices_inputs));
    316   OpInputList sparse_feature_indices_inputs;
    317   TF_RETURN_IF_ERROR(context->input_list("sparse_feature_indices",
    318                                          &sparse_feature_indices_inputs));
    319   OpInputList sparse_feature_values_inputs;
    320   if (num_sparse_features_with_values > 0) {
    321     TF_RETURN_IF_ERROR(context->input_list("sparse_feature_values",
    322                                            &sparse_feature_values_inputs));
    323   }
    324 
    325   const Tensor* example_weights_t;
    326   TF_RETURN_IF_ERROR(context->input("example_weights", &example_weights_t));
    327   auto example_weights = example_weights_t->flat<float>();
    328 
    329   if (example_weights.size() >= std::numeric_limits<int>::max()) {
    330     return errors::InvalidArgument(strings::Printf(
    331         "Too many examples in a mini-batch: %zu > %d", example_weights.size(),
    332         std::numeric_limits<int>::max()));
    333   }
    334 
    335   // The static_cast here is safe since num_examples can be at max an int.
    336   const int num_examples = static_cast<int>(example_weights.size());
    337   const Tensor* example_labels_t;
    338   TF_RETURN_IF_ERROR(context->input("example_labels", &example_labels_t));
    339   auto example_labels = example_labels_t->flat<float>();
    340 
    341   OpInputList dense_features_inputs;
    342   TF_RETURN_IF_ERROR(
    343       context->input_list("dense_features", &dense_features_inputs));
    344 
    345   examples_.clear();
    346   examples_.resize(num_examples);
    347   probabilities_.resize(num_examples);
    348   sampled_index_.resize(num_examples);
    349   sampled_count_.resize(num_examples);
    350   for (int example_id = 0; example_id < num_examples; ++example_id) {
    351     Example* const example = &examples_[example_id];
    352     example->sparse_features_.resize(num_sparse_features);
    353     example->dense_vectors_.resize(num_dense_features);
    354     example->example_weight_ = example_weights(example_id);
    355     example->example_label_ = example_labels(example_id);
    356   }
    357   const DeviceBase::CpuWorkerThreads& worker_threads =
    358       *context->device()->tensorflow_cpu_worker_threads();
    359   TF_RETURN_IF_ERROR(CreateSparseFeatureRepresentation(
    360       worker_threads, num_examples, num_sparse_features, weights,
    361       sparse_example_indices_inputs, sparse_feature_indices_inputs,
    362       sparse_feature_values_inputs, &examples_));
    363   TF_RETURN_IF_ERROR(CreateDenseFeatureRepresentation(
    364       worker_threads, num_examples, num_dense_features, weights,
    365       dense_features_inputs, &examples_));
    366   ComputeSquaredNormPerExample(worker_threads, num_examples,
    367                                num_sparse_features, num_dense_features,
    368                                &examples_);
    369   return Status::OK();
    370 }
    371 
    372 Status Examples::CreateSparseFeatureRepresentation(
    373     const DeviceBase::CpuWorkerThreads& worker_threads, const int num_examples,
    374     const int num_sparse_features, const ModelWeights& weights,
    375     const OpInputList& sparse_example_indices_inputs,
    376     const OpInputList& sparse_feature_indices_inputs,
    377     const OpInputList& sparse_feature_values_inputs,
    378     std::vector<Example>* const examples) {
    379   mutex mu;
    380   Status result GUARDED_BY(mu);
    381   auto parse_partition = [&](const int64 begin, const int64 end) {
    382     // The static_cast here is safe since begin and end can be at most
    383     // num_examples which is an int.
    384     for (int i = static_cast<int>(begin); i < end; ++i) {
    385       auto example_indices =
    386           sparse_example_indices_inputs[i].template flat<int64>();
    387       auto feature_indices =
    388           sparse_feature_indices_inputs[i].template flat<int64>();
    389 
    390       // Parse features for each example. Features for a particular example
    391       // are at the offsets (start_id, end_id]
    392       int start_id = -1;
    393       int end_id = 0;
    394       for (int example_id = 0; example_id < num_examples; ++example_id) {
    395         start_id = end_id;
    396         while (end_id < example_indices.size() &&
    397                example_indices(end_id) == example_id) {
    398           ++end_id;
    399         }
    400         Example::SparseFeatures* const sparse_features =
    401             &(*examples)[example_id].sparse_features_[i];
    402         if (start_id < example_indices.size() &&
    403             example_indices(start_id) == example_id) {
    404           sparse_features->indices.reset(new UnalignedInt64Vector(
    405               &(feature_indices(start_id)), end_id - start_id));
    406           if (sparse_feature_values_inputs.size() > i) {
    407             auto feature_weights =
    408                 sparse_feature_values_inputs[i].flat<float>();
    409             sparse_features->values.reset(new UnalignedFloatVector(
    410                 &(feature_weights(start_id)), end_id - start_id));
    411           }
    412           // If features are non empty.
    413           if (end_id - start_id > 0) {
    414             // TODO(sibyl-Aix6ihai): Write this efficiently using vectorized
    415             // operations from eigen.
    416             for (int64 k = 0; k < sparse_features->indices->size(); ++k) {
    417               const int64 feature_index = (*sparse_features->indices)(k);
    418               if (!weights.SparseIndexValid(i, feature_index)) {
    419                 mutex_lock l(mu);
    420                 result = errors::InvalidArgument(
    421                     "Found sparse feature indices out of valid range: ",
    422                     (*sparse_features->indices)(k));
    423                 return;
    424               }
    425             }
    426           }
    427         } else {
    428           // Add a Tensor that has size 0.
    429           sparse_features->indices.reset(
    430               new UnalignedInt64Vector(&(feature_indices(0)), 0));
    431           // If values exist for this feature group.
    432           if (sparse_feature_values_inputs.size() > i) {
    433             auto feature_weights =
    434                 sparse_feature_values_inputs[i].flat<float>();
    435             sparse_features->values.reset(
    436                 new UnalignedFloatVector(&(feature_weights(0)), 0));
    437           }
    438         }
    439       }
    440     }
    441   };
    442   // For each column, the cost of parsing it is O(num_examples). We use
    443   // num_examples here, as empirically Shard() creates the right amount of
    444   // threads based on the problem size.
    445   // TODO(sibyl-Aix6ihai): Tune this as a function of dataset size.
    446   const int64 kCostPerUnit = num_examples;
    447   Shard(worker_threads.num_threads, worker_threads.workers, num_sparse_features,
    448         kCostPerUnit, parse_partition);
    449   return result;
    450 }
    451 
    452 Status Examples::CreateDenseFeatureRepresentation(
    453     const DeviceBase::CpuWorkerThreads& worker_threads, const int num_examples,
    454     const int num_dense_features, const ModelWeights& weights,
    455     const OpInputList& dense_features_inputs,
    456     std::vector<Example>* const examples) {
    457   mutex mu;
    458   Status result GUARDED_BY(mu);
    459   auto parse_partition = [&](const int64 begin, const int64 end) {
    460     // The static_cast here is safe since begin and end can be at most
    461     // num_examples which is an int.
    462     for (int i = static_cast<int>(begin); i < end; ++i) {
    463       auto dense_features = dense_features_inputs[i].template matrix<float>();
    464       for (int example_id = 0; example_id < num_examples; ++example_id) {
    465         (*examples)[example_id].dense_vectors_[i].reset(
    466             new Example::DenseVector{dense_features, example_id});
    467       }
    468       if (!weights.DenseIndexValid(i, dense_features.dimension(1) - 1)) {
    469         mutex_lock l(mu);
    470         result = errors::InvalidArgument(
    471             "More dense features than we have parameters for: ",
    472             dense_features.dimension(1));
    473         return;
    474       }
    475     }
    476   };
    477   // TODO(sibyl-Aix6ihai): Tune this as a function of dataset size.
    478   const int64 kCostPerUnit = num_examples;
    479   Shard(worker_threads.num_threads, worker_threads.workers, num_dense_features,
    480         kCostPerUnit, parse_partition);
    481   return result;
    482 }
    483 
    484 void Examples::ComputeSquaredNormPerExample(
    485     const DeviceBase::CpuWorkerThreads& worker_threads, const int num_examples,
    486     const int num_sparse_features, const int num_dense_features,
    487     std::vector<Example>* const examples) {
    488   // Compute norm of examples.
    489   auto compute_example_norm = [&](const int64 begin, const int64 end) {
    490     // The static_cast here is safe since begin and end can be at most
    491     // num_examples which is an int.
    492     for (int example_id = static_cast<int>(begin); example_id < end;
    493          ++example_id) {
    494       double squared_norm = 0;
    495       Example* const example = &(*examples)[example_id];
    496       for (int j = 0; j < num_sparse_features; ++j) {
    497         const Example::SparseFeatures& sparse_features =
    498             example->sparse_features_[j];
    499         if (sparse_features.values) {
    500           const Eigen::Tensor<float, 0, Eigen::RowMajor> sn =
    501               sparse_features.values->square().sum();
    502           squared_norm += sn();
    503         } else {
    504           squared_norm += sparse_features.indices->size();
    505         }
    506       }
    507       for (int j = 0; j < num_dense_features; ++j) {
    508         const Eigen::Tensor<float, 0, Eigen::RowMajor> sn =
    509             example->dense_vectors_[j]->Row().square().sum();
    510         squared_norm += sn();
    511       }
    512       example->squared_norm_ = squared_norm;
    513     }
    514   };
    515   // TODO(sibyl-Aix6ihai): Compute the cost optimally.
    516   const int64 kCostPerUnit = num_dense_features + num_sparse_features;
    517   Shard(worker_threads.num_threads, worker_threads.workers, num_examples,
    518         kCostPerUnit, compute_example_norm);
    519 }
    520 
    521 }  // namespace sdca
    522 }  // namespace tensorflow
    523