Home | History | Annotate | Download | only in data
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #ifndef TENSORFLOW_CORE_KERNELS_DATA_OPTIONAL_OPS_H_
     16 #define TENSORFLOW_CORE_KERNELS_DATA_OPTIONAL_OPS_H_
     17 
     18 #include <vector>
     19 
     20 #include "tensorflow/core/framework/tensor.h"
     21 #include "tensorflow/core/framework/variant_tensor_data.h"
     22 #include "tensorflow/core/util/tensor_ops_util.h"
     23 
     24 namespace tensorflow {
     25 namespace data {
     26 
     27 const char kOptionalVariantTypeName[] = "tensorflow::data::Optional";
     28 
     29 // Stores a DT_VARIANT value representing an Optional with the given value
     30 // in the `output_index`^th output of the given kernel execution context.
     31 Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, int output_index,
     32                                       std::vector<Tensor> value);
     33 
     34 // Stores a DT_VARIANT value representing an Optional with no value
     35 // in the `output_index`^th output of the given kernel execution context.
     36 Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index);
     37 
     38 // An `OptionalVariant` can represent either an "actual value" (a tuple of
     39 // tensors) or "none", and may be stored in a DT_VARIANT tensor.
     40 class OptionalVariant {
     41  public:
     42   // Create an `OptionalVariant` with no actual value.
     43   OptionalVariant() : values_(nullptr) {}
     44 
     45   // Create an `OptionalVariant` with the actual value given by the tuple of
     46   // tensors in `values`.
     47   explicit OptionalVariant(std::vector<Tensor> values) {
     48     values_ = std::make_shared<std::vector<Tensor>>(std::move(values));
     49   }
     50 
     51   OptionalVariant(const OptionalVariant& other) : values_(other.values_) {}
     52 
     53   // Returns true if `this` represents an actual value.
     54   bool has_value() const { return values_ != nullptr; }
     55 
     56   // REQUIRES: `this->has_value()` must be true.
     57   const std::vector<Tensor>& get_values() const {
     58     DCHECK(values_) << "Tried to get values from an empty OptionalVariant";
     59     return *values_;
     60   }
     61 
     62   // Implementations of the necessary methods for using `OptionalVariant`
     63   // objects in DT_VARIANT tensors.
     64   string TypeName() const { return kOptionalVariantTypeName; }
     65   void Encode(VariantTensorData* data) const {
     66     data->set_metadata(values_ != nullptr);
     67     if (values_ != nullptr) {
     68       for (const auto& t : *values_) {
     69         *(data->add_tensors()) = t;
     70       }
     71     }
     72   }
     73 
     74   bool Decode(const VariantTensorData& data) {
     75     if (data.type_name() != TypeName()) {
     76       return false;
     77     }
     78     bool has_value = false;
     79     if (!data.get_metadata(&has_value)) {
     80       return false;
     81     }
     82     if (has_value) {
     83       values_ = std::make_shared<std::vector<Tensor>>(data.tensors());
     84     } else {
     85       values_.reset();
     86     }
     87     return true;
     88   }
     89 
     90   string DebugString() const {
     91     if (values_) {
     92       return strings::StrCat("OptionalVariant<", "values: (",
     93                              str_util::Join(*values_, ", ",
     94                                             [](string* s, const Tensor& elem) {
     95                                               *s = elem.DebugString();
     96                                             }),
     97                              ")>");
     98     } else {
     99       return strings::StrCat("OptionalVariant<None>");
    100     }
    101   }
    102 
    103  private:
    104   std::shared_ptr<const std::vector<Tensor>> values_;
    105 };
    106 
    107 template <typename Device>
    108 Status OptionalZerosLike(OpKernelContext* ctx, const OptionalVariant& x,
    109                          OptionalVariant* y) {
    110   if (!x.has_value()) {
    111     *y = x;
    112     return Status::OK();
    113   }
    114   std::vector<Tensor> zero_tensors;
    115   for (const Tensor& tensor : x.get_values()) {
    116     Tensor zero_t;
    117     TF_RETURN_IF_ERROR(ZerosLikeTensor<Device>(ctx, tensor, &zero_t));
    118     zero_tensors.push_back(std::move(zero_t));
    119   }
    120   *y = OptionalVariant(zero_tensors);
    121   return Status::OK();
    122 }
    123 
    124 template <typename Device>
    125 Status OptionalBinaryAdd(OpKernelContext* ctx, const OptionalVariant& a,
    126                          const OptionalVariant& b, OptionalVariant* out) {
    127   // TODO(skyewm): should adding a value to a non-value be a no-op instead?
    128   if (a.has_value() != b.has_value()) {
    129     return errors::InvalidArgument(
    130         "Cannot add optionals because one has a value and the other doesn't.");
    131   }
    132   if (!a.has_value()) {
    133     *out = a;
    134     return Status::OK();
    135   }
    136   if (a.get_values().size() != b.get_values().size()) {
    137     return errors::InvalidArgument(
    138         "Cannot add optionals because they have different numbers of "
    139         "components (",
    140         a.get_values().size(), " vs. ", b.get_values().size(), ").");
    141   }
    142   std::vector<Tensor> out_tensors;
    143   for (int i = 0; i < a.get_values().size(); ++i) {
    144     const Tensor& a_tensor = a.get_values()[i];
    145     const Tensor& b_tensor = b.get_values()[i];
    146     Tensor out_tensor;
    147     TF_RETURN_IF_ERROR(
    148         BinaryAddTensors<Device>(ctx, a_tensor, b_tensor, &out_tensor));
    149     out_tensors.push_back(std::move(out_tensor));
    150   }
    151   *out = OptionalVariant(out_tensors);
    152   return Status::OK();
    153 }
    154 
    155 class OptionalNoneOp : public OpKernel {
    156  public:
    157   explicit OptionalNoneOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
    158 
    159   void Compute(OpKernelContext* ctx) override;
    160 };
    161 
    162 class OptionalFromValueOp : public OpKernel {
    163  public:
    164   explicit OptionalFromValueOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
    165 
    166   void Compute(OpKernelContext* ctx) override;
    167 };
    168 
    169 class OptionalHasValueOp : public OpKernel {
    170  public:
    171   explicit OptionalHasValueOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
    172 
    173   void Compute(OpKernelContext* ctx) override;
    174 };
    175 
    176 class OptionalGetValueOp : public OpKernel {
    177  public:
    178   explicit OptionalGetValueOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    179     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
    180     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
    181     OP_REQUIRES(
    182         ctx, output_shapes_.size() == output_types_.size(),
    183         errors::InvalidArgument(
    184             "output_types and output_shapes must be same length, got:\n",
    185             "output_types: ", output_types_.size(), "\n",
    186             "output_shapes: ", output_shapes_.size()));
    187   }
    188 
    189   void Compute(OpKernelContext* ctx) override;
    190 
    191  private:
    192   DataTypeVector output_types_;
    193   std::vector<PartialTensorShape> output_shapes_;
    194 };
    195 
    196 }  // namespace data
    197 }  // namespace tensorflow
    198 
    199 #endif  // TENSORFLOW_CORE_KERNELS_DATA_OPTIONAL_OPS_H_
    200