/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/optional_ops.h"

#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_op_registry.h"

namespace tensorflow {
namespace data {
namespace {
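
// Variant device-copy function for OptionalVariant. If the optional holds a
// value, each contained tensor that is DMA-capable (or is itself a variant)
// is copied with the registry-provided `copy` callback; all other tensors are
// forwarded unchanged. An empty optional is simply copied by value.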
static Status OptionalDeviceCopy(
    const OptionalVariant& from, OptionalVariant* to,
    const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
  if (from.has_value()) {
    const std::vector<Tensor>& from_values = from.get_values();
    std::vector<Tensor> to_values;
    to_values.reserve(from_values.size());
    for (const Tensor& t : from_values) {
      if (DMAHelper::CanUseDMA(&t) || t.dtype() == DT_VARIANT) {
        // NOTE(skyewm): we're careful to make sure the lifetime of the 'to'
        // Tensor passed to `copy` (i.e. to_values.back()) is the same as the
        // returned 'to' OptionalVariant. This is because `copy` may spawn async
        // callbacks that don't run until after this function returns and access
        // the 'to' Tensor (e.g. BaseGPUDevice::MaybeCopyTensorToGPU).
        to_values.emplace_back(t.dtype());
        TF_RETURN_IF_ERROR(copy(t, &to_values.back()));
      } else {
        to_values.push_back(t);
      }
    }
    *to = OptionalVariant(std::move(to_values));
  } else {
    *to = from;
  }
  return Status::OK();
}
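
// Register the copy function above for all three copy directions, so an
// optional (and the tensors it wraps) can be moved between host and device.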
#define REGISTER_OPTIONAL_COPY(DIRECTION)               \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
      OptionalVariant, DIRECTION, OptionalDeviceCopy)

REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);

REGISTER_UNARY_VARIANT_DECODE_FUNCTION(OptionalVariant,
                                       kOptionalVariantTypeName);

}  // namespace
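
// OptionalNone: emits a scalar variant tensor holding an empty
// OptionalVariant.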
void OptionalNoneOp::Compute(OpKernelContext* ctx) {
  OP_REQUIRES_OK(ctx, WriteOptionalNoneToOutput(ctx, 0));
}
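
// OptionalFromValue: wraps the "components" input tensors in an
// OptionalVariant and emits it as a scalar variant tensor.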
void OptionalFromValueOp::Compute(OpKernelContext* ctx) {
  OpInputList components_input;
  OP_REQUIRES_OK(ctx, ctx->input_list("components", &components_input));
  std::vector<Tensor> components(components_input.begin(),
                                 components_input.end());
  OP_REQUIRES_OK(ctx,
                 WriteOptionalWithValueToOutput(ctx, 0, std::move(components)));
}
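
// OptionalHasValue: emits a scalar bool indicating whether the input optional
// holds a value.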
void OptionalHasValueOp::Compute(OpKernelContext* ctx) {
  const Tensor* optional_input;
  OP_REQUIRES_OK(ctx, ctx->input("optional", &optional_input));
  OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(optional_input->shape()),
              errors::InvalidArgument(
                  "Input to OptionalHasValue must be a scalar tensor "
                  "containing an OptionalVariant object."));
  const OptionalVariant* optional =
      optional_input->scalar<Variant>()().get<OptionalVariant>();
  OP_REQUIRES(
      ctx, optional != nullptr,
      errors::InvalidArgument(
          "Input to OptionalHasValue must be an OptionalVariant object."));
  Tensor* result;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &result));
  result->scalar<bool>()() = optional->has_value();
}
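
// OptionalGetValue: validates the stored components against the expected
// output dtypes and shapes, then forwards them as the op's outputs.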
void OptionalGetValueOp::Compute(OpKernelContext* ctx) {
  const Tensor* optional_input;
  OP_REQUIRES_OK(ctx, ctx->input("optional", &optional_input));
  OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(optional_input->shape()),
              errors::InvalidArgument(
                  "Input to OptionalGetValue must be a scalar tensor "
                  "containing an OptionalVariant object."));
  const OptionalVariant* optional =
      optional_input->scalar<Variant>()().get<OptionalVariant>();
  OP_REQUIRES(
      ctx, optional != nullptr,
      errors::InvalidArgument(
          "Input to OptionalGetValue must be an OptionalVariant object."));
  OP_REQUIRES(
      ctx, optional->has_value(),
      errors::InvalidArgument("The given optional does not have a value."));
  const auto& components = optional->get_values();
  OP_REQUIRES(
      ctx, components.size() == output_types_.size(),
      errors::InvalidArgument("The given optional has ", components.size(),
                              " components, expected ", output_types_.size()));
  for (int i = 0; i < components.size(); ++i) {
    OP_REQUIRES(ctx, components[i].dtype() == output_types_[i],
                errors::InvalidArgument(
                    "The given optional does not match the expected type for "
                    "component ",
                    i, ". Expected: ", DataTypeString(output_types_[i]),
                    ". Actual: ", DataTypeString(components[i].dtype()), "."));
    OP_REQUIRES(ctx, output_shapes_[i].IsCompatibleWith(components[i].shape()),
                errors::InvalidArgument(
                    "The given optional does not match the expected shape "
                    "for component ",
                    i, ". Expected: ", output_shapes_[i].DebugString(),
                    ". Actual: ", components[i].shape().DebugString(), "."));
    ctx->set_output(i, components[i]);
  }
}
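
// Writes an OptionalVariant wrapping `value` to output `output_index` as a
// scalar variant tensor, explicitly allocated in host memory.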
Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, int output_index,
                                      std::vector<Tensor> value) {
  OptionalVariant v(std::move(value));
  Tensor* variant_t;
  AllocatorAttributes cpu_alloc;
  cpu_alloc.set_on_host(true);
  TF_RETURN_IF_ERROR(ctx->allocate_output(output_index, TensorShape({}),
                                          &variant_t, cpu_alloc));
  variant_t->scalar<Variant>()() = v;
  return Status::OK();
}
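
// Writes an empty OptionalVariant to output `output_index` as a scalar
// variant tensor, also allocated in host memory.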
Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index) {
  OptionalVariant v;
  Tensor* variant_t;
  AllocatorAttributes cpu_alloc;
  cpu_alloc.set_on_host(true);
  TF_RETURN_IF_ERROR(ctx->allocate_output(output_index, TensorShape({}),
                                          &variant_t, cpu_alloc));
  variant_t->scalar<Variant>()() = v;
  return Status::OK();
}

namespace {
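
// Each op is registered for both CPU and GPU; the CPU kernels carry the
// higher priority. On GPU, OptionalHasValue pins its boolean output to host
// memory.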
REGISTER_KERNEL_BUILDER(Name("OptionalNone").Device(DEVICE_CPU).Priority(2),
                        OptionalNoneOp);
REGISTER_KERNEL_BUILDER(Name("OptionalNone").Device(DEVICE_GPU).Priority(1),
                        OptionalNoneOp);
REGISTER_KERNEL_BUILDER(
    Name("OptionalFromValue").Device(DEVICE_CPU).Priority(2),
    OptionalFromValueOp);
REGISTER_KERNEL_BUILDER(
    Name("OptionalFromValue").Device(DEVICE_GPU).Priority(1),
    OptionalFromValueOp);

REGISTER_KERNEL_BUILDER(Name("OptionalHasValue").Device(DEVICE_CPU).Priority(2),
                        OptionalHasValueOp);
REGISTER_KERNEL_BUILDER(Name("OptionalHasValue")
                            .Device(DEVICE_GPU)
                            .HostMemory("has_value")
                            .Priority(1),
                        OptionalHasValueOp);
REGISTER_KERNEL_BUILDER(Name("OptionalGetValue").Device(DEVICE_CPU).Priority(2),
                        OptionalGetValueOp);
REGISTER_KERNEL_BUILDER(Name("OptionalGetValue").Device(DEVICE_GPU).Priority(1),
                        OptionalGetValueOp);

}  // namespace

REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
                                         DEVICE_CPU, OptionalVariant,
                                         OptionalZerosLike<CPUDevice>);

REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
                                          OptionalVariant,
                                          OptionalBinaryAdd<CPUDevice>);

}  // namespace data
}  // namespace tensorflow