Home | History | Annotate | Download | only in rpc
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_
     17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_
     18 
     19 #include <memory>
     20 
     21 #include "grpc++/grpc++.h"
     22 #include "grpc++/impl/codegen/proto_utils.h"
     23 #include "grpc++/support/byte_buffer.h"
     24 #include "tensorflow/core/distributed_runtime/tensor_coding.h"
     25 #include "tensorflow/core/lib/core/status.h"
     26 #include "tensorflow/core/lib/strings/stringprintf.h"
     27 #include "tensorflow/core/platform/mutex.h"
     28 #include "tensorflow/core/platform/protobuf.h"
     29 
     30 namespace tensorflow {
     31 
     32 constexpr char kStreamRemovedMessage[] = "Stream removed";
     33 
     34 // Identify if the given grpc::Status corresponds to an HTTP stream removed
     35 // error (see chttp2_transport.cc).
     36 //
     37 // When auto-reconnecting to a remote TensorFlow worker after it restarts, gRPC
     38 // can return an UNKNOWN error code with a "Stream removed" error message.
     39 // This should not be treated as an unrecoverable error.
     40 //
     41 // N.B. This is dependent on the error message from grpc remaining consistent.
     42 inline bool IsStreamRemovedError(const ::grpc::Status& s) {
     43   return !s.ok() && s.error_code() == ::grpc::StatusCode::UNKNOWN &&
     44          s.error_message() == kStreamRemovedMessage;
     45 }
     46 
     47 inline Status FromGrpcStatus(const ::grpc::Status& s) {
     48   if (s.ok()) {
     49     return Status::OK();
     50   } else {
     51     // Convert "UNKNOWN" stream removed errors into unavailable, to allow
     52     // for retry upstream.
     53     if (IsStreamRemovedError(s)) {
     54       return Status(tensorflow::error::UNAVAILABLE, s.error_message());
     55     }
     56     return Status(static_cast<tensorflow::error::Code>(s.error_code()),
     57                   s.error_message());
     58   }
     59 }
     60 
     61 inline ::grpc::Status ToGrpcStatus(const ::tensorflow::Status& s) {
     62   if (s.ok()) {
     63     return ::grpc::Status::OK;
     64   } else {
     65     if (s.error_message().size() > 3072 /* 3k bytes */) {
     66       // TODO(b/62947679): Remove truncation once the gRPC issue is resolved.
     67       string scratch =
     68           strings::Printf("%.3072s ... [truncated]", s.error_message().c_str());
     69       LOG(ERROR) << "Truncated error message: " << s;
     70       return ::grpc::Status(static_cast<::grpc::StatusCode>(s.code()), scratch);
     71     }
     72     return ::grpc::Status(static_cast<::grpc::StatusCode>(s.code()),
     73                           s.error_message());
     74   }
     75 }
     76 
     77 typedef std::shared_ptr<::grpc::Channel> SharedGrpcChannelPtr;
     78 
     79 inline string GrpcIdKey() { return "tf-rpc"; }
     80 
     81 // Serialize src and store in *dst.
     82 void GrpcMaybeUnparseProto(const protobuf::Message& src,
     83                            ::grpc::ByteBuffer* dst);
     84 
     85 // Parse contents of src and initialize *dst with them.
     86 bool GrpcMaybeParseProto(const ::grpc::ByteBuffer& src, protobuf::Message* dst);
     87 
     88 // Specialization for TensorResponse
     89 bool GrpcMaybeParseProto(const ::grpc::ByteBuffer& src, TensorResponse* dst);
     90 
     91 // Copy string src to grpc buffer *dst.
     92 void GrpcMaybeUnparseProto(const string& src, ::grpc::ByteBuffer* dst);
     93 
     94 // Copy grpc buffer src to string *dst.
     95 bool GrpcMaybeParseProto(const ::grpc::ByteBuffer& src, string* dst);
     96 
     97 // A ZeroCopyInputStream that reads from a grpc::ByteBuffer.
     98 class GrpcByteBufferSource : public ::grpc::protobuf::io::ZeroCopyInputStream {
     99  public:
    100   GrpcByteBufferSource();
    101   bool Init(const ::grpc::ByteBuffer& src);  // Can be called multiple times.
    102   bool Next(const void** data, int* size) override;
    103   void BackUp(int count) override;
    104   bool Skip(int count) override;
    105   ::grpc::protobuf::int64 ByteCount() const override;
    106 
    107  private:
    108   std::vector<::grpc::Slice> slices_;
    109   int cur_;          // Current slice index.
    110   int left_;         // Number of bytes in slices_[cur_] left to yield.
    111   const char* ptr_;  // Address of next byte in slices_[cur_] to yield.
    112   ::grpc::protobuf::int64 byte_count_;
    113 };
    114 
    115 }  // namespace tensorflow
    116 
    117 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_
    118