Home | History | Annotate | Download | only in tf2xla
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_RESOURCE_H_
     17 #define TENSORFLOW_COMPILER_TF2XLA_XLA_RESOURCE_H_
     18 
#include <map>
#include <memory>
#include <set>

#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status.h"
     26 
     27 namespace tensorflow {
     28 
     29 // Represents a resource, such as a Variable or TensorArray.
     30 class XlaResource {
     31  public:
     32   enum Kind {
     33     kInvalid,
     34     kVariable,
     35     kTensorArray,
     36     kStack,
     37   };
     38 
     39   XlaResource(Kind kind, int arg_num, string name, DataType type,
     40               TensorShape shape,
     41               const xla::ComputationDataHandle& initial_value,
     42               int64 tensor_array_size,
     43               const std::set<string>& tensor_array_gradients);
     44 
     45   XlaResource(const XlaResource&) = delete;
     46   XlaResource(XlaResource&&) = delete;
     47   XlaResource& operator=(const XlaResource&) = delete;
     48   XlaResource& operator=(XlaResource&&) = delete;
     49 
     50   Kind kind() const { return kind_; }
     51 
     52   // If this resource is visible externally to the computation, what was its
     53   // argument number?
     54   // < 0 means "not visible externally".
     55   int arg_num() const { return arg_num_; }
     56 
     57   // A descriptive name for the resource, used in error messages.
     58   const string& name() const { return name_; }
     59 
     60   // Current type and value of the resource. Uninitialized resources are
     61   // represented by a default (zero) handle and type DT_INVALID.
     62   // While the type of a resource is notionally fixed during execution, when
     63   // a resource is first initialized we do not yet know its type, so we keep
     64   // track of its type dynamically.
     65   DataType type() const { return type_; }
     66 
     67   // Shape of the resource. For an uninitialized resource, this is ignored.
     68   // For a Variable, this is the shape of the value. For a TensorArray or Stack
     69   // this is the shape of each entry in the TensorArray/Stack.
     70   const TensorShape& shape() const { return shape_; }
     71 
     72   const xla::ComputationDataHandle& value() const { return value_; }
     73 
     74   // Value of the resource at computation entry. Used to detect which
     75   // variables have new values that need to be written back.
     76   const xla::ComputationDataHandle& initial_value() const {
     77     return initial_value_;
     78   }
     79 
     80   // A variable is initialized if it has a value.
     81   bool initialized() const { return value_.handle() > 0; }
     82 
     83   // Sets the type and shape of the resource. The type and shape of a resource
     84   // must not change once the variable has been initialized.
     85   Status SetTypeAndShape(DataType type, const TensorShape& shape);
     86 
     87   // Sets the current value of the resource. Returns an error if the type is not
     88   // set to a valid value.
     89   Status SetValue(const xla::ComputationDataHandle& value);
     90 
     91   // Sets the current value of the resource to an all-zero value.
     92   Status SetZeroValue(xla::ComputationBuilder* builder);
     93 
     94   // Looks up the gradient for `source`, or creates it if it does not already
     95   // exist. The call target must be an initialized TensorArray resource. A
     96   // TensorArray can have multiple named gradients; see the operator
     97   // documentation for TensorArrayGradV3 for details.
     98   Status GetOrCreateTensorArrayGradient(const string& source,
     99                                         xla::ComputationBuilder* builder,
    100                                         XlaResource** gradient_out);
    101 
    102   // Packs a resource into a single XLA value `pack`, suitable for use as
    103   // an XlaCompiler::Argument. For non-TensorArrays or TensorArrays without
    104   // gradients, sets `*pack` to `value`.
    105   // For TensorArrays with gradients, packs the value and its gradient values in
    106   // a tuple; the gradients values are packed in order by source name.
    107   Status Pack(xla::ComputationDataHandle* pack,
    108               xla::ComputationBuilder* builder) const;
    109 
    110   // Updates the resource with values from `pack`. If `gradient_sources` is
    111   // non-empty, treats `pack` as a tuple that represents a TensorArray and
    112   // its gradients, and unpacks and updates the gradient resources.
    113   // If `reset_initial_values` is true, sets the initial_values as well as the
    114   // values.
    115   // Opposite of Pack().
    116   Status SetFromPack(const std::set<string>& gradient_sources,
    117                      const xla::ComputationDataHandle& pack,
    118                      xla::ComputationBuilder* builder);
    119 
    120   // TensorArray and Stack specific fields
    121 
    122   // 'tensor_array_size' stores the expected size of the TensorArray or Stack.
    123   // We need to store this since sometimes TensorArrays must be initialized
    124   // lazily since we do not know the element shape at construction time.
    125   // Used by both TensorArrays and Stacks.
    126   int64 tensor_array_size() const { return tensor_array_size_; }
    127   void set_tensor_array_size(int64 size) { tensor_array_size_ = size; }
    128 
    129   // 'tensor_array_gradient' is a map from TensorArrayGradV3 'source' attributes
    130   // to an XlaResource containing the gradient TensorArrays. We store a pointer
    131   // here since there should only be one gradient TensorArray per 'source'
    132   // string, irrespective of the number of calls to TensorArrayGrad. The map
    133   // is ordered since values are packed into tuples by Pack() sorted by name
    134   // order.
    135   const std::map<string, std::unique_ptr<XlaResource>>& tensor_array_gradients()
    136       const {
    137     return tensor_array_gradients_;
    138   }
    139 
    140  private:
    141   const Kind kind_;
    142   const int arg_num_;
    143   const string name_;
    144 
    145   DataType type_;
    146   TensorShape shape_;
    147   xla::ComputationDataHandle value_;
    148   xla::ComputationDataHandle initial_value_;
    149 
    150   int64 tensor_array_size_ = -1;
    151 
    152   std::map<string, std::unique_ptr<XlaResource>> tensor_array_gradients_;
    153 };
    154 
    155 }  // namespace tensorflow
    156 
    157 #endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_RESOURCE_H_
    158