/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cstddef>
#include <functional>
#include <map>
#include <mutex>
#include <numeric>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"

namespace tensorflow {
namespace {

// Partial Ordering Comparator for Tensor keys containing scalar int64's
struct KeyTensorLess {
  bool operator()(const Tensor& lhs, const Tensor& rhs) const {
    return std::less<int64>{}(lhs.scalar<int64>()(), rhs.scalar<int64>()());
  }
};

// Key Equality operator for Tensor keys containing scalar int64's
struct KeyTensorEqual {
  bool operator()(const Tensor& lhs, const Tensor& rhs) const {
    return std::equal_to<int64>{}(lhs.scalar<int64>()(), rhs.scalar<int64>()());
  }
};

// Hash for Tensor keys containing scalar int64's
struct KeyTensorHash {
  std::size_t operator()(const Tensor& key) const {
    return std::hash<int64>{}(key.scalar<int64>()());
  }
};

// Primary template.
template <bool Ordered, typename Data>
struct MapTraits;

// Partial specialization for ordered.
template <typename Data>
struct MapTraits<true, Data> {
  using KeyType = Tensor;
  using DataType = Data;
  using MapType = std::map<KeyType, Data, KeyTensorLess>;
};

// Partial specialization for unordered.
template <typename Data>
struct MapTraits<false, Data> {
  using KeyType = Tensor;
  using DataType = Data;
  using MapType =
      std::unordered_map<KeyType, Data, KeyTensorHash, KeyTensorEqual>;
};

// Wrapper around map/unordered_map.
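// StagingMap is a key/value staging area resource: each value is a tuple of
// tensors that can be inserted whole or piecewise (by index) and later read
// back with get (copy), pop (move) or popitem (move the first element).
// Optional capacity and memory_limit bounds make put() block until space is
// available; get/pop block until the requested key is present and popitem
// blocks until the map is non-empty. The Ordered template parameter selects
// std::map (keys sorted by their scalar int64 value) versus std::unordered_map.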
template <bool Ordered>
class StagingMap : public ResourceBase {
 public:
  // Public typedefs
  using Tuple = std::vector<Tensor>;
  using OptionalTensor = gtl::optional<Tensor>;
  using OptionalTuple = std::vector<OptionalTensor>;

  using MapType = typename MapTraits<Ordered, OptionalTuple>::MapType;
  using KeyType = typename MapTraits<Ordered, OptionalTuple>::KeyType;

  using IncompleteType = typename MapTraits<false, OptionalTuple>::MapType;

 private:
  // Private variables
  DataTypeVector dtypes_ GUARDED_BY(mu_);
  std::size_t capacity_ GUARDED_BY(mu_);
  std::size_t memory_limit_ GUARDED_BY(mu_);
  std::size_t current_bytes_ GUARDED_BY(mu_);
  tensorflow::mutex mu_;
  tensorflow::condition_variable not_empty_;
  tensorflow::condition_variable full_;
  IncompleteType incomplete_ GUARDED_BY(mu_);
  MapType map_ GUARDED_BY(mu_);

 private:
  // private methods

  // If map is configured for bounded capacity, notify
  // waiting inserters that space is now available
  void notify_inserters_if_bounded() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (has_capacity() || has_memory_limit()) {
      // Notify all inserters. The removal of an element
      // may make memory available for many inserters
      // to insert new elements
      full_.notify_all();
    }
  }

  // Notify all removers waiting to extract values
  // that data is now available
  void notify_removers() {
    // Notify all removers. This is because they are
    // waiting for specific keys to appear in the map
    // so we don't know which one to wake up.
    not_empty_.notify_all();
  }

  bool has_capacity() const EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return capacity_ > 0;
  }

  bool has_memory_limit() const EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return memory_limit_ > 0;
  }

  bool would_exceed_memory_limit(std::size_t bytes) const
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return has_memory_limit() && bytes + current_bytes_ > memory_limit_;
  }

  bool is_capacity_full() const EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    return has_capacity() && map_.size() >= capacity_;
  }

  // Get number of bytes in the tuple
  std::size_t get_tuple_bytes(const Tuple& tuple) {
    return std::accumulate(tuple.begin(), tuple.end(),
                           static_cast<std::size_t>(0),
                           [](const std::size_t& lhs, const Tensor& rhs) {
                             return lhs + rhs.TotalBytes();
                           });
  }

  // Get number of bytes in the incomplete tuple
  std::size_t get_tuple_bytes(const OptionalTuple& tuple) {
    return std::accumulate(
        tuple.begin(), tuple.end(), static_cast<std::size_t>(0),
        [](const std::size_t& lhs, const OptionalTensor& rhs) {
          return lhs + (rhs.has_value() ? rhs.value().TotalBytes() : 0);
        });
  }

  // Check that the index is within bounds
  Status check_index(const Tensor& key, std::size_t index)
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (index >= dtypes_.size()) {
      return Status(errors::InvalidArgument(
          "Index '", index, "' for key '", key.scalar<int64>()(),
          "' was out of bounds '", dtypes_.size(), "'."));
    }

    return Status::OK();
  }

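  // Gather the tensors at the requested indices from *map_tuple into *output.
  // When copy is true the stored entries are left in place (used by get());
  // when false the entries are cleared after being taken (used by pop() and
  // popitem()).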
  Status copy_or_move_tensors(OptionalTuple* map_tuple, const Tensor& key,
                              const Tensor& indices, Tuple* output,
                              bool copy = false) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    auto findices = indices.flat<int>();

    // Return values at specified indices
    for (std::size_t i = 0; i < findices.dimension(0); ++i) {
      std::size_t index = findices(i);

      TF_RETURN_IF_ERROR(check_index(key, index));

      // Insist on a value present at the specified index
      if (!(*map_tuple)[index].has_value()) {
        return Status(errors::InvalidArgument(
            "Tensor at index '", index, "' for key '", key.scalar<int64>()(),
            "' has already been removed."));
      }

      // Copy the contained tensor into the output tuple
      output->push_back((*map_tuple)[index].value());

      // Clear out the entry if we're not copying (moving)
      if (!copy) {
        (*map_tuple)[index].reset();
      }
    }

    return Status::OK();
  }

  // Check that the optional value at the specified index
  // is uninitialized
  Status check_index_uninitialized(const Tensor& key, std::size_t index,
                                   const OptionalTuple& tuple)
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (tuple[index].has_value()) {
      return Status(errors::InvalidArgument(
          "The tensor for index '", index, "' for key '", key.scalar<int64>()(),
          "' was already initialized '", dtypes_.size(), "'."));
    }

    return Status::OK();
  }

  // Check that the indices are strictly ordered
  Status check_index_ordering(const Tensor& indices) {
    auto findices = indices.flat<int>();

    // Use a signed loop bound so that an empty indices tensor
    // does not underflow the comparison
    for (std::ptrdiff_t i = 0; i + 1 < findices.dimension(0); ++i) {
      if (findices(i) < findices(i + 1)) {
        continue;
      }

      return Status(
          errors::InvalidArgument("Indices are not strictly ordered"));
    }

    return Status::OK();
  }

  // Check that the requested bytes fit within the memory limit
  Status check_memory_limit(std::size_t bytes) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (has_memory_limit() && bytes > memory_limit_) {
      return Status(errors::ResourceExhausted(
          "Attempted to insert tensors with combined size of '", bytes,
          "' bytes into Staging Area with a memory limit of '", memory_limit_,
          "'."));
    }

    return Status::OK();
  }

  // Insert incomplete data into the incomplete map. Once every element of
  // the tuple has been provided, the completed tuple is moved into the
  // main map.
  Status put_incomplete(const KeyType& key, const Tensor& indices,
                        OptionalTuple* tuple, tensorflow::mutex_lock* lock)
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    auto findices = indices.flat<int>();

    // Search for the key in our incomplete set
    auto it = incomplete_.find(key);

    // Check that the tuple fits within the memory limit
    std::size_t tuple_bytes = get_tuple_bytes(*tuple);
    TF_RETURN_IF_ERROR(check_memory_limit(tuple_bytes));

    // Wait until we don't exceed the memory limit
    while (would_exceed_memory_limit(tuple_bytes)) {
      full_.wait(*lock);
    }

    // This key isn't present in the incomplete set
    // Create OptionalTuple and insert
    if (it == incomplete_.end()) {
      OptionalTuple empty(dtypes_.size());

      // Initialize empty tuple with given data
      for (std::size_t i = 0; i < findices.dimension(0); ++i) {
        std::size_t index = findices(i);
        TF_RETURN_IF_ERROR(check_index(key, index));

        // Assign tuple at this index
        empty[index] = std::move((*tuple)[i]);
      }

      // Insert into incomplete map
      incomplete_.insert({key, std::move(empty)});

      // Increment size
      current_bytes_ += tuple_bytes;
    }
    // Found an entry in the incomplete index
    // Update with given data and insert complete entries
    // into the main map
    else {
      // Reference existing incomplete tuple
      OptionalTuple& present = it->second;

      // Assign given data
      for (std::size_t i = 0; i < findices.dimension(0); ++i) {
        std::size_t index = findices(i);
        TF_RETURN_IF_ERROR(check_index(key, index));
        TF_RETURN_IF_ERROR(check_index_uninitialized(key, index, present));

        // Assign tuple at this index
        present[index] = std::move((*tuple)[i]);
      }

      // Increment size
      current_bytes_ += tuple_bytes;

      // Do we have values at all tuple elements?
      bool complete =
          std::all_of(present.begin(), present.end(),
                      [](const OptionalTensor& v) { return v.has_value(); });

      // If so, put the tuple in the actual map
      if (complete) {
        OptionalTuple insert_tuple = std::move(it->second);

        // Remove from incomplete
        incomplete_.erase(it);

        TF_RETURN_IF_ERROR(put_complete(key, &insert_tuple));
      }
    }

    return Status::OK();
  }

  // Does the insertion into the actual staging area
  Status put_complete(const KeyType& key, OptionalTuple* tuple)
      EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    // Insert key and tuples into the map
    map_.insert({key, std::move(*tuple)});

    notify_removers();

    return Status::OK();
  }

 public:
  // public methods
  explicit StagingMap(const DataTypeVector& dtypes, std::size_t capacity,
                      std::size_t memory_limit)
      : dtypes_(dtypes),
        capacity_(capacity),
        memory_limit_(memory_limit),
        current_bytes_(0) {}

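  // Insert the given tuple under key. A partial tuple (the number of indices
  // differs from the number of dtypes) is routed to put_incomplete(); a full
  // tuple blocks until the capacity and memory limits allow insertion into
  // the main map.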
  Status put(KeyType* key, const Tensor* indices, OptionalTuple* tuple) {
    tensorflow::mutex_lock lock(mu_);

    // Sanity check the indices
    TF_RETURN_IF_ERROR(check_index_ordering(*indices));

    // Handle incomplete inserts
    if (indices->NumElements() != dtypes_.size()) {
      return put_incomplete(*key, *indices, tuple, &lock);
    }

    std::size_t tuple_bytes = get_tuple_bytes(*tuple);
    // Check that tuple_bytes fits within the memory limit
    TF_RETURN_IF_ERROR(check_memory_limit(tuple_bytes));

    // Wait until there's space for insertion.
    while (would_exceed_memory_limit(tuple_bytes) || is_capacity_full()) {
      full_.wait(lock);
    }

    // Do the put operation
    TF_RETURN_IF_ERROR(put_complete(*key, tuple));

    // Update the current size
    current_bytes_ += tuple_bytes;

    return Status::OK();
  }

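  // Copy the tensors at the given indices for the given key into *tuple,
  // leaving the stored values in place. Blocks until the key is present.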
  Status get(const KeyType* key, const Tensor* indices, Tuple* tuple) {
    tensorflow::mutex_lock lock(mu_);

    // Sanity check the indices
    TF_RETURN_IF_ERROR(check_index_ordering(*indices));

    typename MapType::iterator it;

    // Wait until the element with the requested key is present
    while ((it = map_.find(*key)) == map_.end()) {
      not_empty_.wait(lock);
    }

    TF_RETURN_IF_ERROR(
        copy_or_move_tensors(&it->second, *key, *indices, tuple, true));

    // Update bytes in the Staging Area
    current_bytes_ -= get_tuple_bytes(*tuple);

    return Status::OK();
  }

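  // Move the tensors at the given indices out of the entry for the given key
  // into *tuple; the entry is erased once all of its values have been
  // consumed. Blocks until the key is present, then wakes blocked inserters
  // if the map is bounded.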
  Status pop(const KeyType* key, const Tensor* indices, Tuple* tuple) {
    tensorflow::mutex_lock lock(mu_);

    // Sanity check the indices
    TF_RETURN_IF_ERROR(check_index_ordering(*indices));

    typename MapType::iterator it;

    // Wait until the element with the requested key is present
    while ((it = map_.find(*key)) == map_.end()) {
      not_empty_.wait(lock);
    }

    TF_RETURN_IF_ERROR(
        copy_or_move_tensors(&it->second, *key, *indices, tuple));

    // Remove entry if all the values have been consumed
    if (!std::any_of(it->second.begin(), it->second.end(),
                     std::mem_fn(&OptionalTensor::has_value))) {
      map_.erase(it);
    }

    // Update bytes in the Staging Area
    current_bytes_ -= get_tuple_bytes(*tuple);

    notify_inserters_if_bounded();

    return Status::OK();
  }

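  // Like pop(), but removes values from the first element of the map (the
  // smallest key when Ordered is true) and returns its key through *key.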
  Status popitem(KeyType* key, const Tensor* indices, Tuple* tuple) {
    tensorflow::mutex_lock lock(mu_);

    // Sanity check the indices
    TF_RETURN_IF_ERROR(check_index_ordering(*indices));

    // Wait until map is not empty
    while (this->map_.empty()) {
      not_empty_.wait(lock);
    }

    // Move from the first element and erase it
    auto it = map_.begin();

    // Assign the key first so that error messages inside
    // copy_or_move_tensors refer to a valid key tensor
    *key = it->first;

    TF_RETURN_IF_ERROR(
        copy_or_move_tensors(&it->second, *key, *indices, tuple));

    // Remove entry if all the values have been consumed
    if (!std::any_of(it->second.begin(), it->second.end(),
                     std::mem_fn(&OptionalTensor::has_value))) {
      map_.erase(it);
    }

    // Update bytes in the Staging Area
    current_bytes_ -= get_tuple_bytes(*tuple);

    notify_inserters_if_bounded();

    return Status::OK();
  }

  Status clear() {
    tensorflow::mutex_lock lock(mu_);
    map_.clear();
    incomplete_.clear();
    current_bytes_ = 0;

    notify_inserters_if_bounded();

    return Status::OK();
  }

  std::size_t incomplete_size() {
    tensorflow::mutex_lock lock(mu_);
    return incomplete_.size();
  }

  std::size_t size() {
    tensorflow::mutex_lock lock(mu_);
    return map_.size();
  }

  string DebugString() override { return "StagingMap"; }
};

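// Looks up the StagingMap resource shared by the staging ops, creating it on
// first use from the op's "dtypes", "capacity" and "memory_limit" attributes.
// The caller receives a reference and must release it (the op kernels below
// use core::ScopedUnref).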
template <bool Ordered>
Status GetStagingMap(OpKernelContext* ctx, const NodeDef& ndef,
                     StagingMap<Ordered>** map) {
  auto rm = ctx->resource_manager();
  ContainerInfo cinfo;

  // Lambda for creating the Staging Area
  auto create_fn = [&ndef](StagingMap<Ordered>** ret) -> Status {
    DataTypeVector dtypes;
    int64 capacity;
    int64 memory_limit;
    TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "dtypes", &dtypes));
    TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "capacity", &capacity));
    TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "memory_limit", &memory_limit));
    *ret = new StagingMap<Ordered>(dtypes, capacity, memory_limit);
    return Status::OK();
  };

  TF_RETURN_IF_ERROR(cinfo.Init(rm, ndef, true /* use name() */));
  TF_RETURN_IF_ERROR(rm->LookupOrCreate<StagingMap<Ordered>>(
      cinfo.container(), cinfo.name(), map, create_fn));
  return Status::OK();
}

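// Stages a (key, tuple) pair: copies the key tensor, gathers the "values"
// inputs into a tuple and inserts them into the shared StagingMap via put().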
template <bool Ordered>
class MapStageOp : public OpKernel {
 public:
  explicit MapStageOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);
    typename StagingMap<Ordered>::OptionalTuple tuple;

    const Tensor* key_tensor;
    const Tensor* indices_tensor;
    OpInputList values_tensor;

    OP_REQUIRES_OK(ctx, ctx->input("key", &key_tensor));
    OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor));
    OP_REQUIRES_OK(ctx, ctx->input_list("values", &values_tensor));

    // Create copy for insertion into Staging Area
    Tensor key(*key_tensor);

    // Create the tuple to store
    for (std::size_t i = 0; i < values_tensor.size(); ++i) {
      tuple.push_back(values_tensor[i]);
    }

    // Store the tuple in the map
    OP_REQUIRES_OK(ctx, map->put(&key, indices_tensor, &tuple));
  }
};

REGISTER_KERNEL_BUILDER(Name("MapStage").Device(DEVICE_CPU), MapStageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapStage").Device(DEVICE_CPU),
                        MapStageOp<true>);

#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(
    Name("MapStage").HostMemory("key").HostMemory("indices").Device(DEVICE_GPU),
    MapStageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapStage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_GPU),
                        MapStageOp<true>);
#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("MapStage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_SYCL),
                        MapStageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapStage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_SYCL),
                        MapStageOp<true>);
#endif  // TENSORFLOW_USE_SYCL

template <bool Ordered>
class MapUnstageOp : public OpKernel {
 public:
  explicit MapUnstageOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error.  As such cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);
    typename StagingMap<Ordered>::Tuple tuple;

    const Tensor* key_tensor;
    const Tensor* indices_tensor;
    OpInputList values_tensor;

    OP_REQUIRES_OK(ctx, ctx->input("key", &key_tensor));
    OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor));
    OP_REQUIRES_OK(ctx, map->pop(key_tensor, indices_tensor, &tuple));

    OP_REQUIRES(
        ctx, tuple.size() == indices_tensor->NumElements(),
        errors::InvalidArgument("output/indices size mismatch: ", tuple.size(),
                                " vs. ", indices_tensor->NumElements()));

    for (std::size_t i = 0; i < tuple.size(); ++i) {
      ctx->set_output(i, tuple[i]);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("MapUnstage").Device(DEVICE_CPU),
                        MapUnstageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstage").Device(DEVICE_CPU),
                        MapUnstageOp<true>);

#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("MapUnstage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_GPU),
                        MapUnstageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_GPU),
                        MapUnstageOp<true>);
#endif  // GOOGLE_CUDA
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("MapUnstage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_SYCL),
                        MapUnstageOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstage")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_SYCL),
                        MapUnstageOp<true>);
#endif  // TENSORFLOW_USE_SYCL

template <bool Ordered>
class MapPeekOp : public OpKernel {
 public:
  explicit MapPeekOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error.  As such cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);
    typename StagingMap<Ordered>::Tuple tuple;

    const Tensor* key_tensor;
    const Tensor* indices_tensor;
    OpInputList values_tensor;

    OP_REQUIRES_OK(ctx, ctx->input("key", &key_tensor));
    OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor));
    OP_REQUIRES_OK(ctx, map->get(key_tensor, indices_tensor, &tuple));

    OP_REQUIRES(
        ctx, tuple.size() == indices_tensor->NumElements(),
        errors::InvalidArgument("output/indices size mismatch: ", tuple.size(),
                                " vs. ", indices_tensor->NumElements()));

    for (std::size_t i = 0; i < tuple.size(); ++i) {
      ctx->set_output(i, tuple[i]);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("MapPeek").Device(DEVICE_CPU), MapPeekOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapPeek").Device(DEVICE_CPU),
                        MapPeekOp<true>);

#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(
    Name("MapPeek").HostMemory("key").HostMemory("indices").Device(DEVICE_GPU),
    MapPeekOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapPeek")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_GPU),
                        MapPeekOp<true>);
#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(
    Name("MapPeek").HostMemory("key").HostMemory("indices").Device(DEVICE_SYCL),
    MapPeekOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapPeek")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_SYCL),
                        MapPeekOp<true>);
#endif  // TENSORFLOW_USE_SYCL

template <bool Ordered>
class MapUnstageNoKeyOp : public OpKernel {
 public:
  explicit MapUnstageNoKeyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error.  As such cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);

    // Pop a random (key, value) off the map
    typename StagingMap<Ordered>::KeyType key;
    typename StagingMap<Ordered>::Tuple tuple;

    const Tensor* indices_tensor;

    OP_REQUIRES_OK(ctx, ctx->input("indices", &indices_tensor));
    OP_REQUIRES_OK(ctx, map->popitem(&key, indices_tensor, &tuple));

    // Set the key as the first output
    ctx->set_output(0, key);

    // Set the rest of the outputs to the tuple Tensors
    OP_REQUIRES(
        ctx, tuple.size() == indices_tensor->NumElements(),
        errors::InvalidArgument("output/indices size mismatch: ", tuple.size(),
                                " vs. ", indices_tensor->NumElements()));

    for (std::size_t i = 0; i < tuple.size(); ++i) {
      ctx->set_output(i + 1, tuple[i]);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("MapUnstageNoKey").Device(DEVICE_CPU),
                        MapUnstageNoKeyOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstageNoKey").Device(DEVICE_CPU),
                        MapUnstageNoKeyOp<true>);

#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("MapUnstageNoKey")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_GPU),
                        MapUnstageNoKeyOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstageNoKey")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_GPU),
                        MapUnstageNoKeyOp<true>);
#endif  // GOOGLE_CUDA

#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("MapUnstageNoKey")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_SYCL),
                        MapUnstageNoKeyOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapUnstageNoKey")
                            .HostMemory("key")
                            .HostMemory("indices")
                            .Device(DEVICE_SYCL),
                        MapUnstageNoKeyOp<true>);
#endif  // TENSORFLOW_USE_SYCL

template <bool Ordered>
class MapSizeOp : public OpKernel {
 public:
  explicit MapSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);

    // Allocate size output tensor
    Tensor* size = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &size));

    // Set it to the actual size
    size->scalar<int32>().setConstant(map->size());
  }
};

REGISTER_KERNEL_BUILDER(Name("MapSize").Device(DEVICE_CPU), MapSizeOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapSize").Device(DEVICE_CPU),
                        MapSizeOp<true>);

#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("MapSize").Device(DEVICE_GPU).HostMemory("size"),
                        MapSizeOp<false>);
REGISTER_KERNEL_BUILDER(
    Name("OrderedMapSize").Device(DEVICE_GPU).HostMemory("size"),
    MapSizeOp<true>);
#endif  // GOOGLE_CUDA
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("MapSize").Device(DEVICE_SYCL).HostMemory("size"),
                        MapSizeOp<false>);
REGISTER_KERNEL_BUILDER(
    Name("OrderedMapSize").Device(DEVICE_SYCL).HostMemory("size"),
    MapSizeOp<true>);
#endif  // TENSORFLOW_USE_SYCL

template <bool Ordered>
class MapIncompleteSizeOp : public OpKernel {
 public:
  explicit MapIncompleteSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);

    // Allocate size output tensor
    Tensor* size = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &size));

    // Set it to the actual size
    size->scalar<int32>().setConstant(map->incomplete_size());
  }
};

REGISTER_KERNEL_BUILDER(Name("MapIncompleteSize").Device(DEVICE_CPU),
                        MapIncompleteSizeOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapIncompleteSize").Device(DEVICE_CPU),
                        MapIncompleteSizeOp<true>);

#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(
    Name("MapIncompleteSize").Device(DEVICE_GPU).HostMemory("size"),
    MapIncompleteSizeOp<false>);
REGISTER_KERNEL_BUILDER(
    Name("OrderedMapIncompleteSize").Device(DEVICE_GPU).HostMemory("size"),
    MapIncompleteSizeOp<true>);
#endif  // GOOGLE_CUDA
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(
    Name("MapIncompleteSize").Device(DEVICE_SYCL).HostMemory("size"),
    MapIncompleteSizeOp<false>);
REGISTER_KERNEL_BUILDER(
    Name("OrderedMapIncompleteSize").Device(DEVICE_SYCL).HostMemory("size"),
    MapIncompleteSizeOp<true>);
#endif  // TENSORFLOW_USE_SYCL

template <bool Ordered>
class MapClearOp : public OpKernel {
 public:
  explicit MapClearOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    StagingMap<Ordered>* map = nullptr;
    OP_REQUIRES_OK(ctx, GetStagingMap(ctx, def(), &map));
    core::ScopedUnref scope(map);

    OP_REQUIRES_OK(ctx, map->clear());
  }
};

REGISTER_KERNEL_BUILDER(Name("MapClear").Device(DEVICE_CPU), MapClearOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapClear").Device(DEVICE_CPU),
                        MapClearOp<true>);

#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(Name("MapClear").Device(DEVICE_GPU), MapClearOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapClear").Device(DEVICE_GPU),
                        MapClearOp<true>);
#endif  // GOOGLE_CUDA
#ifdef TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("MapClear").Device(DEVICE_SYCL),
                        MapClearOp<false>);
REGISTER_KERNEL_BUILDER(Name("OrderedMapClear").Device(DEVICE_SYCL),
                        MapClearOp<true>);
#endif  // TENSORFLOW_USE_SYCL

}  // namespace
}  // namespace tensorflow