1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_ 17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_ 18 19 #include <memory> 20 21 #include "grpc++/grpc++.h" 22 #include "grpc++/impl/codegen/proto_utils.h" 23 #include "grpc++/support/byte_buffer.h" 24 #include "tensorflow/core/distributed_runtime/tensor_coding.h" 25 #include "tensorflow/core/lib/core/status.h" 26 #include "tensorflow/core/lib/strings/stringprintf.h" 27 #include "tensorflow/core/platform/mutex.h" 28 #include "tensorflow/core/platform/protobuf.h" 29 30 namespace tensorflow { 31 32 constexpr char kStreamRemovedMessage[] = "Stream removed"; 33 34 // Identify if the given grpc::Status corresponds to an HTTP stream removed 35 // error (see chttp2_transport.cc). 36 // 37 // When auto-reconnecting to a remote TensorFlow worker after it restarts, gRPC 38 // can return an UNKNOWN error code with a "Stream removed" error message. 39 // This should not be treated as an unrecoverable error. 40 // 41 // N.B. This is dependent on the error message from grpc remaining consistent. 42 inline bool IsStreamRemovedError(const ::grpc::Status& s) { 43 return !s.ok() && s.error_code() == ::grpc::StatusCode::UNKNOWN && 44 s.error_message() == kStreamRemovedMessage; 45 } 46 47 inline Status FromGrpcStatus(const ::grpc::Status& s) { 48 if (s.ok()) { 49 return Status::OK(); 50 } else { 51 // Convert "UNKNOWN" stream removed errors into unavailable, to allow 52 // for retry upstream. 53 if (IsStreamRemovedError(s)) { 54 return Status(tensorflow::error::UNAVAILABLE, s.error_message()); 55 } 56 return Status(static_cast<tensorflow::error::Code>(s.error_code()), 57 s.error_message()); 58 } 59 } 60 61 inline ::grpc::Status ToGrpcStatus(const ::tensorflow::Status& s) { 62 if (s.ok()) { 63 return ::grpc::Status::OK; 64 } else { 65 if (s.error_message().size() > 3072 /* 3k bytes */) { 66 // TODO(b/62947679): Remove truncation once the gRPC issue is resolved. 67 string scratch = 68 strings::Printf("%.3072s ... [truncated]", s.error_message().c_str()); 69 LOG(ERROR) << "Truncated error message: " << s; 70 return ::grpc::Status(static_cast<::grpc::StatusCode>(s.code()), scratch); 71 } 72 return ::grpc::Status(static_cast<::grpc::StatusCode>(s.code()), 73 s.error_message()); 74 } 75 } 76 77 typedef std::shared_ptr<::grpc::Channel> SharedGrpcChannelPtr; 78 79 inline string GrpcIdKey() { return "tf-rpc"; } 80 81 // Serialize src and store in *dst. 82 void GrpcMaybeUnparseProto(const protobuf::Message& src, 83 ::grpc::ByteBuffer* dst); 84 85 // Parse contents of src and initialize *dst with them. 86 bool GrpcMaybeParseProto(const ::grpc::ByteBuffer& src, protobuf::Message* dst); 87 88 // Specialization for TensorResponse 89 bool GrpcMaybeParseProto(const ::grpc::ByteBuffer& src, TensorResponse* dst); 90 91 // Copy string src to grpc buffer *dst. 92 void GrpcMaybeUnparseProto(const string& src, ::grpc::ByteBuffer* dst); 93 94 // Copy grpc buffer src to string *dst. 95 bool GrpcMaybeParseProto(const ::grpc::ByteBuffer& src, string* dst); 96 97 // A ZeroCopyInputStream that reads from a grpc::ByteBuffer. 98 class GrpcByteBufferSource : public ::grpc::protobuf::io::ZeroCopyInputStream { 99 public: 100 GrpcByteBufferSource(); 101 bool Init(const ::grpc::ByteBuffer& src); // Can be called multiple times. 102 bool Next(const void** data, int* size) override; 103 void BackUp(int count) override; 104 bool Skip(int count) override; 105 ::grpc::protobuf::int64 ByteCount() const override; 106 107 private: 108 std::vector<::grpc::Slice> slices_; 109 int cur_; // Current slice index. 110 int left_; // Number of bytes in slices_[cur_] left to yield. 111 const char* ptr_; // Address of next byte in slices_[cur_] to yield. 112 ::grpc::protobuf::int64 byte_count_; 113 }; 114 115 } // namespace tensorflow 116 117 #endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_ 118