Home | History | Annotate | Download | only in distributed_runtime
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_TENSOR_CODING_H_
     17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_TENSOR_CODING_H_
     18 
     19 #include "tensorflow/core/framework/allocator.h"
     20 #include "tensorflow/core/framework/tensor.h"
     21 #include "tensorflow/core/lib/core/status.h"
     22 #include "tensorflow/core/platform/protobuf.h"
     23 #include "tensorflow/core/platform/types.h"
     24 #include "tensorflow/core/protobuf/worker.pb.h"
     25 
     26 namespace tensorflow {
     27 
     28 class Allocator;    // Appears in this header only as a pointer (allocator_).
     29 class DeviceBase;   // Appears in this header only as a pointer (InitAlloc, device_).
     30 class TensorProto;  // Appears in this header only as a pointer (ParseTensorSubmessage).
     31 
     32 // TensorResponse can be used as the destination of an RPC that returns
     33 // a RecvTensorResponse.  It efficiently decodes the incoming data
     34 // into Tensor contents as well as associated metadata.
        //
        // The decoded tensor is exposed through tensor() and the accompanying
        // metadata through metadata(); both accessors return references to
        // members of *this.
        // NOTE(review): the declarations suggest InitAlloc() is meant to be
        // called before ParseFrom()/InitFrom() so that an allocator is
        // available -- confirm against the implementation file.
     35 class TensorResponse {
     36  public:
          // Creates an empty response.  The constructor body is empty, so the
          // members start from their in-class default initializers (null
          // device_/allocator_, on_host_ and already_used_ false).
     37   TensorResponse() {}
     38 
     39   // Reset to initial state.
     40   void Clear();
     41 
     42   // Clear just tensor_ and meta_ members without setting allocation
     43   // related members.
     44   void ClearTensor();
     45 
     46   // Initialize memory allocation related members.
          //   d:  device to associate with this response.
          //   aa: allocator attributes to use.
          // NOTE(review): presumably stores these into device_ / alloc_attrs_ and
          // derives allocator_ / on_host_ from them -- confirm in the .cc file.
     47   void InitAlloc(DeviceBase* d, const AllocatorAttributes& aa);
     48 
     49   // Source provides a way for a particular RPC implementation to provide
     50   // received data to ParseFrom.
          // It is an abstract interface: contents() is pure virtual.
     51   class Source {
     52    public:
            // Virtual so that implementations can be destroyed through a Source*.
            // Only declared here; the definition lives outside this header.
     53     virtual ~Source();
     54 
     55     // Return the stream that contains the data to be parsed.
     56     // Note that this method might be invoked more than once if
     57     // ParseFrom needs to fall back to a more expensive parsing method.
     58     // Every call must return a stream pointing at the beginning of
     59     // the serialized RecvTensorResponse.
     60     //
     61     // Note that a subsequent call to contents() invalidates previous
     62     // results of contents().
     63     //
     64     // Ownership of the returned stream is retained by the Source and
     65     // should not be deleted by the caller.
     66     virtual ::tensorflow::protobuf::io::ZeroCopyInputStream* contents() = 0;
     67   };
     68 
     69   // Parse the RecvTensorResponse encoded in the data yielded by
     70   // source->contents() into *this.
          // Returns a Status describing the outcome of the parse.
          // NOTE(review): source->contents() may be called more than once (see
          // the Source::contents() comment), presumably via ParseFast() and then
          // ParseSlow() -- confirm in the .cc file.
     71   Status ParseFrom(Source* source);
     72 
     73   // Initialize tensor from *response.
     74   // Leaves *response with unspecified contents.
          // Returns a Status describing the outcome.  Because *response is left
          // in an unspecified state, callers should not read it afterwards.
     75   Status InitFrom(RecvTensorResponse* response);
     76 
     77   // Initialize tensor metadata from response and allocate
     78   // uninitialized backing storage for actual contents.
          // Takes the response by const reference and returns nothing, so any
          // failure is not reported through the return value.
     79   void InitPartial(const RecvTensorResponse& response);
     80 
     81   // Return a reference to the parsed tensor.  The tensor will remain
     82   // live only until *this is destroyed or modified.
     83   const Tensor& tensor() const { return tensor_; }
     84 
     85   // Return a reference to the parsed tensor metadata (no contents).
     86   // The result will remain live only until *this is destroyed or
     87   // modified.
     88   const RecvTensorResponse& metadata() const { return meta_; }
     89 
     90  private:
          // Internal parsing helpers; each returns bool.
          // NOTE(review): presumably true means success, and ParseFast() /
          // ParseSlow() are the cheap and "more expensive" parsing methods that
          // the Source::contents() comment alludes to -- confirm in the .cc file.
     91   bool ParseTensorSubmessage(protobuf::io::CodedInputStream* input,
     92                              TensorProto* tensor_meta);
     93   bool ParseFast(Source* source);
     94   bool ParseSlow(Source* source);
     95 
          // NOTE(review): on_host_, device_, alloc_attrs_ and allocator_ are
          // presumably the "allocation related members" that InitAlloc()
          // initializes and ClearTensor() leaves alone -- confirm.
     96   bool on_host_ = false;
     97   DeviceBase* device_ = nullptr;
     98   AllocatorAttributes alloc_attrs_;
     99   Allocator* allocator_ = nullptr;
          // NOTE(review): purpose not visible from this header; presumably
          // tracks whether this object has already received a response -- confirm.
    100   bool already_used_ = false;
          // Parsed tensor; returned by tensor().
    101   Tensor tensor_;
          // Parsed metadata; returned by metadata().
    102   RecvTensorResponse meta_;
    103 };
    104 
    105 }  // namespace tensorflow
    106 
    107 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_TENSOR_CODING_H_
    108