1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_KERNELS_DATA_OPTIONAL_OPS_H_ 16 #define TENSORFLOW_CORE_KERNELS_DATA_OPTIONAL_OPS_H_ 17 18 #include <vector> 19 20 #include "tensorflow/core/framework/tensor.h" 21 #include "tensorflow/core/framework/variant_tensor_data.h" 22 #include "tensorflow/core/util/tensor_ops_util.h" 23 24 namespace tensorflow { 25 namespace data { 26 27 const char kOptionalVariantTypeName[] = "tensorflow::data::Optional"; 28 29 // Stores a DT_VARIANT value representing an Optional with the given value 30 // in the `output_index`^th output of the given kernel execution context. 31 Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, int output_index, 32 std::vector<Tensor> value); 33 34 // Stores a DT_VARIANT value representing an Optional with no value 35 // in the `output_index`^th output of the given kernel execution context. 36 Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index); 37 38 // An `OptionalVariant` can represent either an "actual value" (a tuple of 39 // tensors) or "none", and may be stored in a DT_VARIANT tensor. 40 class OptionalVariant { 41 public: 42 // Create an `OptionalVariant` with no actual value. 43 OptionalVariant() : values_(nullptr) {} 44 45 // Create an `OptionalVariant` with the actual value given by the tuple of 46 // tensors in `values`. 47 explicit OptionalVariant(std::vector<Tensor> values) { 48 values_ = std::make_shared<std::vector<Tensor>>(std::move(values)); 49 } 50 51 OptionalVariant(const OptionalVariant& other) : values_(other.values_) {} 52 53 // Returns true if `this` represents an actual value. 54 bool has_value() const { return values_ != nullptr; } 55 56 // REQUIRES: `this->has_value()` must be true. 57 const std::vector<Tensor>& get_values() const { 58 DCHECK(values_) << "Tried to get values from an empty OptionalVariant"; 59 return *values_; 60 } 61 62 // Implementations of the necessary methods for using `OptionalVariant` 63 // objects in DT_VARIANT tensors. 64 string TypeName() const { return kOptionalVariantTypeName; } 65 void Encode(VariantTensorData* data) const { 66 data->set_metadata(values_ != nullptr); 67 if (values_ != nullptr) { 68 for (const auto& t : *values_) { 69 *(data->add_tensors()) = t; 70 } 71 } 72 } 73 74 bool Decode(const VariantTensorData& data) { 75 if (data.type_name() != TypeName()) { 76 return false; 77 } 78 bool has_value = false; 79 if (!data.get_metadata(&has_value)) { 80 return false; 81 } 82 if (has_value) { 83 values_ = std::make_shared<std::vector<Tensor>>(data.tensors()); 84 } else { 85 values_.reset(); 86 } 87 return true; 88 } 89 90 string DebugString() const { 91 if (values_) { 92 return strings::StrCat("OptionalVariant<", "values: (", 93 str_util::Join(*values_, ", ", 94 [](string* s, const Tensor& elem) { 95 *s = elem.DebugString(); 96 }), 97 ")>"); 98 } else { 99 return strings::StrCat("OptionalVariant<None>"); 100 } 101 } 102 103 private: 104 std::shared_ptr<const std::vector<Tensor>> values_; 105 }; 106 107 template <typename Device> 108 Status OptionalZerosLike(OpKernelContext* ctx, const OptionalVariant& x, 109 OptionalVariant* y) { 110 if (!x.has_value()) { 111 *y = x; 112 return Status::OK(); 113 } 114 std::vector<Tensor> zero_tensors; 115 for (const Tensor& tensor : x.get_values()) { 116 Tensor zero_t; 117 TF_RETURN_IF_ERROR(ZerosLikeTensor<Device>(ctx, tensor, &zero_t)); 118 zero_tensors.push_back(std::move(zero_t)); 119 } 120 *y = OptionalVariant(zero_tensors); 121 return Status::OK(); 122 } 123 124 template <typename Device> 125 Status OptionalBinaryAdd(OpKernelContext* ctx, const OptionalVariant& a, 126 const OptionalVariant& b, OptionalVariant* out) { 127 // TODO(skyewm): should adding a value to a non-value be a no-op instead? 128 if (a.has_value() != b.has_value()) { 129 return errors::InvalidArgument( 130 "Cannot add optionals because one has a value and the other doesn't."); 131 } 132 if (!a.has_value()) { 133 *out = a; 134 return Status::OK(); 135 } 136 if (a.get_values().size() != b.get_values().size()) { 137 return errors::InvalidArgument( 138 "Cannot add optionals because they have different numbers of " 139 "components (", 140 a.get_values().size(), " vs. ", b.get_values().size(), ")."); 141 } 142 std::vector<Tensor> out_tensors; 143 for (int i = 0; i < a.get_values().size(); ++i) { 144 const Tensor& a_tensor = a.get_values()[i]; 145 const Tensor& b_tensor = b.get_values()[i]; 146 Tensor out_tensor; 147 TF_RETURN_IF_ERROR( 148 BinaryAddTensors<Device>(ctx, a_tensor, b_tensor, &out_tensor)); 149 out_tensors.push_back(std::move(out_tensor)); 150 } 151 *out = OptionalVariant(out_tensors); 152 return Status::OK(); 153 } 154 155 class OptionalNoneOp : public OpKernel { 156 public: 157 explicit OptionalNoneOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 158 159 void Compute(OpKernelContext* ctx) override; 160 }; 161 162 class OptionalFromValueOp : public OpKernel { 163 public: 164 explicit OptionalFromValueOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 165 166 void Compute(OpKernelContext* ctx) override; 167 }; 168 169 class OptionalHasValueOp : public OpKernel { 170 public: 171 explicit OptionalHasValueOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 172 173 void Compute(OpKernelContext* ctx) override; 174 }; 175 176 class OptionalGetValueOp : public OpKernel { 177 public: 178 explicit OptionalGetValueOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 179 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); 180 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); 181 OP_REQUIRES( 182 ctx, output_shapes_.size() == output_types_.size(), 183 errors::InvalidArgument( 184 "output_types and output_shapes must be same length, got:\n", 185 "output_types: ", output_types_.size(), "\n", 186 "output_shapes: ", output_shapes_.size())); 187 } 188 189 void Compute(OpKernelContext* ctx) override; 190 191 private: 192 DataTypeVector output_types_; 193 std::vector<PartialTensorShape> output_shapes_; 194 }; 195 196 } // namespace data 197 } // namespace tensorflow 198 199 #endif // TENSORFLOW_CORE_KERNELS_DATA_OPTIONAL_OPS_H_ 200