Home | History | Annotate | Download | only in cloud
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_FAKE_H_
     17 #define TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_FAKE_H_
     18 
#include <algorithm>
#include <cstring>
#include <fstream>
#include <iterator>
#include <map>
#include <string>
#include <vector>
#include <curl/curl.h>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/cloud/curl_http_request.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
     33 
     34 namespace tensorflow {
     35 
     36 /// Fake HttpRequest for testing.
     37 class FakeHttpRequest : public CurlHttpRequest {
     38  public:
     39   /// Return the response for the given request.
     40   FakeHttpRequest(const string& request, const string& response)
     41       : FakeHttpRequest(request, response, Status::OK(), nullptr, {}, 200) {}
     42 
     43   /// Return the response with headers for the given request.
     44   FakeHttpRequest(const string& request, const string& response,
     45                   const std::map<string, string>& response_headers)
     46       : FakeHttpRequest(request, response, Status::OK(), nullptr,
     47                         response_headers, 200) {}
     48 
     49   /// \brief Return the response for the request and capture the POST body.
     50   ///
     51   /// Post body is not expected to be a part of the 'request' parameter.
     52   FakeHttpRequest(const string& request, const string& response,
     53                   string* captured_post_body)
     54       : FakeHttpRequest(request, response, Status::OK(), captured_post_body, {},
     55                         200) {}
     56 
     57   /// \brief Return the response and the status for the given request.
     58   FakeHttpRequest(const string& request, const string& response,
     59                   Status response_status, uint64 response_code)
     60       : FakeHttpRequest(request, response, response_status, nullptr, {},
     61                         response_code) {}
     62 
     63   /// \brief Return the response and the status for the given request
     64   ///  and capture the POST body.
     65   ///
     66   /// Post body is not expected to be a part of the 'request' parameter.
     67   FakeHttpRequest(const string& request, const string& response,
     68                   Status response_status, string* captured_post_body,
     69                   const std::map<string, string>& response_headers,
     70                   uint64 response_code)
     71       : expected_request_(request),
     72         response_(response),
     73         response_status_(response_status),
     74         captured_post_body_(captured_post_body),
     75         response_headers_(response_headers),
     76         response_code_(response_code) {}
     77 
     78   void SetUri(const string& uri) override {
     79     actual_uri_ += "Uri: " + uri + "\n";
     80   }
     81   void SetRange(uint64 start, uint64 end) override {
     82     actual_request_ += strings::StrCat("Range: ", start, "-", end, "\n");
     83   }
     84   void AddHeader(const string& name, const string& value) override {
     85     actual_request_ += "Header " + name + ": " + value + "\n";
     86   }
     87   void AddAuthBearerHeader(const string& auth_token) override {
     88     actual_request_ += "Auth Token: " + auth_token + "\n";
     89   }
     90   void SetDeleteRequest() override { actual_request_ += "Delete: yes\n"; }
     91   Status SetPutFromFile(const string& body_filepath, size_t offset) override {
     92     std::ifstream stream(body_filepath);
     93     const string& content = string(std::istreambuf_iterator<char>(stream),
     94                                    std::istreambuf_iterator<char>())
     95                                 .substr(offset);
     96     actual_request_ += "Put body: " + content + "\n";
     97     return Status::OK();
     98   }
     99   void SetPostFromBuffer(const char* buffer, size_t size) override {
    100     if (captured_post_body_) {
    101       *captured_post_body_ = string(buffer, size);
    102     } else {
    103       actual_request_ +=
    104           strings::StrCat("Post body: ", StringPiece(buffer, size), "\n");
    105     }
    106   }
    107   void SetPutEmptyBody() override { actual_request_ += "Put: yes\n"; }
    108   void SetPostEmptyBody() override {
    109     if (captured_post_body_) {
    110       *captured_post_body_ = "<empty>";
    111     } else {
    112       actual_request_ += "Post: yes\n";
    113     }
    114   }
    115   void SetResultBuffer(std::vector<char>* buffer) override {
    116     buffer->clear();
    117     buffer_ = buffer;
    118   }
    119   void SetResultBufferDirect(char* buffer, size_t size) override {
    120     direct_result_buffer_ = buffer;
    121     direct_result_buffer_size_ = size;
    122   }
    123   size_t GetResultBufferDirectBytesTransferred() override {
    124     return direct_result_bytes_transferred_;
    125   }
    126   Status Send() override {
    127     EXPECT_EQ(expected_request_, actual_request())
    128         << "Unexpected HTTP request.";
    129     if (buffer_) {
    130       buffer_->insert(buffer_->begin(), response_.data(),
    131                       response_.data() + response_.size());
    132     } else if (direct_result_buffer_ != nullptr) {
    133       size_t bytes_to_copy =
    134           std::min<size_t>(direct_result_buffer_size_, response_.size());
    135       memcpy(direct_result_buffer_, response_.data(), bytes_to_copy);
    136       direct_result_bytes_transferred_ += bytes_to_copy;
    137     }
    138     return response_status_;
    139   }
    140 
    141   // This function just does a simple replacing of "/" with "%2F" instead of
    142   // full url encoding.
    143   string EscapeString(const string& str) override {
    144     const string victim = "/";
    145     const string encoded = "%2F";
    146 
    147     string copy_str = str;
    148     std::string::size_type n = 0;
    149     while ((n = copy_str.find(victim, n)) != std::string::npos) {
    150       copy_str.replace(n, victim.size(), encoded);
    151       n += encoded.size();
    152     }
    153     return copy_str;
    154   }
    155 
    156   string GetResponseHeader(const string& name) const override {
    157     const auto header = response_headers_.find(name);
    158     return header != response_headers_.end() ? header->second : "";
    159   }
    160 
    161   virtual uint64 GetResponseCode() const override { return response_code_; }
    162 
    163   void SetTimeouts(uint32 connection, uint32 inactivity,
    164                    uint32 total) override {
    165     actual_request_ += strings::StrCat("Timeouts: ", connection, " ",
    166                                        inactivity, " ", total, "\n");
    167   }
    168 
    169  private:
    170   string actual_request() const {
    171     string s;
    172     s.append(actual_uri_);
    173     s.append(actual_request_);
    174     return s;
    175   }
    176 
    177   std::vector<char>* buffer_ = nullptr;
    178   char* direct_result_buffer_ = nullptr;
    179   size_t direct_result_buffer_size_ = 0;
    180   size_t direct_result_bytes_transferred_ = 0;
    181   string expected_request_;
    182   string actual_uri_;
    183   string actual_request_;
    184   string response_;
    185   Status response_status_;
    186   string* captured_post_body_ = nullptr;
    187   std::map<string, string> response_headers_;
    188   uint64 response_code_ = 0;
    189 };
    190 
    191 /// Fake HttpRequest factory for testing.
    192 class FakeHttpRequestFactory : public HttpRequest::Factory {
    193  public:
    194   FakeHttpRequestFactory(const std::vector<HttpRequest*>* requests)
    195       : requests_(requests) {}
    196 
    197   ~FakeHttpRequestFactory() {
    198     EXPECT_EQ(current_index_, requests_->size())
    199         << "Not all expected requests were made.";
    200   }
    201 
    202   HttpRequest* Create() override {
    203     EXPECT_LT(current_index_, requests_->size())
    204         << "Too many calls of HttpRequest factory.";
    205     return (*requests_)[current_index_++];
    206   }
    207 
    208  private:
    209   const std::vector<HttpRequest*>* requests_;
    210   int current_index_ = 0;
    211 };
    212 
    213 }  // namespace tensorflow
    214 
    215 #endif  // TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_FAKE_H_
    216