Home | History | Annotate | Download | only in windows
      1 /* Copyright 2015 Google Inc. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <Shlwapi.h>
     17 #include <Windows.h>
     18 #include <direct.h>
     19 #include <errno.h>
     20 #include <fcntl.h>
     21 #include <io.h>
     22 #undef StrCat
     23 #include <stdio.h>
     24 #include <sys/stat.h>
     25 #include <sys/types.h>
     26 #include <time.h>
     27 
     28 #include "tensorflow/core/lib/core/error_codes.pb.h"
     29 #include "tensorflow/core/lib/strings/strcat.h"
     30 #include "tensorflow/core/platform/env.h"
     31 #include "tensorflow/core/platform/logging.h"
     32 #include "tensorflow/core/platform/posix/error.h"
     33 #include "tensorflow/core/platform/windows/error.h"
     34 #include "tensorflow/core/platform/windows/windows_file_system.h"
     35 
     36 // TODO(mrry): Prevent this Windows.h #define from leaking out of our headers.
     37 #undef DeleteFile
     38 
     39 namespace tensorflow {
     40 
     41 namespace {
     42 
     43 // RAII helpers for HANDLEs
     44 const auto CloseHandleFunc = [](HANDLE h) { ::CloseHandle(h); };
     45 typedef std::unique_ptr<void, decltype(CloseHandleFunc)> UniqueCloseHandlePtr;
     46 
     47 inline Status IOErrorFromWindowsError(const string& context, DWORD err) {
     48   return IOError(
     49       context + string(" : ") + internal::GetWindowsErrorMessage(err), err);
     50 }
     51 
     52 // PLEASE NOTE: hfile is expected to be an async handle
     53 // (i.e. opened with FILE_FLAG_OVERLAPPED)
     54 SSIZE_T pread(HANDLE hfile, char* src, size_t num_bytes, uint64_t offset) {
     55   assert(num_bytes <= std::numeric_limits<DWORD>::max());
     56   OVERLAPPED overlapped = {0};
     57   ULARGE_INTEGER offset_union;
     58   offset_union.QuadPart = offset;
     59 
     60   overlapped.Offset = offset_union.LowPart;
     61   overlapped.OffsetHigh = offset_union.HighPart;
     62   overlapped.hEvent = ::CreateEvent(NULL, TRUE, FALSE, NULL);
     63 
     64   if (NULL == overlapped.hEvent) {
     65     return -1;
     66   }
     67 
     68   SSIZE_T result = 0;
     69 
     70   unsigned long bytes_read = 0;
     71   DWORD last_error = ERROR_SUCCESS;
     72 
     73   BOOL read_result = ::ReadFile(hfile, src, static_cast<DWORD>(num_bytes),
     74                                 &bytes_read, &overlapped);
     75   if (TRUE == read_result) {
     76     result = bytes_read;
     77   } else if ((FALSE == read_result) &&
     78              ((last_error = GetLastError()) != ERROR_IO_PENDING)) {
     79     result = (last_error == ERROR_HANDLE_EOF) ? 0 : -1;
     80   } else {
     81     if (ERROR_IO_PENDING ==
     82         last_error) {  // Otherwise bytes_read already has the result.
     83       BOOL overlapped_result =
     84           ::GetOverlappedResult(hfile, &overlapped, &bytes_read, TRUE);
     85       if (FALSE == overlapped_result) {
     86         result = (::GetLastError() == ERROR_HANDLE_EOF) ? 0 : -1;
     87       } else {
     88         result = bytes_read;
     89       }
     90     }
     91   }
     92 
     93   ::CloseHandle(overlapped.hEvent);
     94 
     95   return result;
     96 }
     97 
     98 // read() based random-access
     99 class WindowsRandomAccessFile : public RandomAccessFile {
    100  private:
    101   string filename_;
    102   HANDLE hfile_;
    103 
    104  public:
    105   WindowsRandomAccessFile(const string& fname, HANDLE hfile)
    106       : filename_(fname), hfile_(hfile) {}
    107   ~WindowsRandomAccessFile() override {
    108     if (hfile_ != NULL && hfile_ != INVALID_HANDLE_VALUE) {
    109       ::CloseHandle(hfile_);
    110     }
    111   }
    112 
    113   Status Read(uint64 offset, size_t n, StringPiece* result,
    114               char* scratch) const override {
    115     Status s;
    116     char* dst = scratch;
    117     while (n > 0 && s.ok()) {
    118       SSIZE_T r = pread(hfile_, dst, n, offset);
    119       if (r > 0) {
    120         offset += r;
    121         dst += r;
    122         n -= r;
    123       } else if (r == 0) {
    124         s = Status(error::OUT_OF_RANGE, "Read fewer bytes than requested");
    125       } else if (errno == EINTR || errno == EAGAIN) {
    126         // Retry
    127       } else {
    128         s = IOError(filename_, errno);
    129       }
    130     }
    131     *result = StringPiece(scratch, dst - scratch);
    132     return s;
    133   }
    134 };
    135 
    136 class WindowsWritableFile : public WritableFile {
    137  private:
    138   string filename_;
    139   HANDLE hfile_;
    140 
    141  public:
    142   WindowsWritableFile(const string& fname, HANDLE hFile)
    143       : filename_(fname), hfile_(hFile) {}
    144 
    145   ~WindowsWritableFile() override {
    146     if (hfile_ != NULL && hfile_ != INVALID_HANDLE_VALUE) {
    147       WindowsWritableFile::Close();
    148     }
    149   }
    150 
    151   Status Append(const StringPiece& data) override {
    152     DWORD bytes_written = 0;
    153     DWORD data_size = static_cast<DWORD>(data.size());
    154     BOOL write_result =
    155         ::WriteFile(hfile_, data.data(), data_size, &bytes_written, NULL);
    156     if (FALSE == write_result) {
    157       return IOErrorFromWindowsError("Failed to WriteFile: " + filename_,
    158                                      ::GetLastError());
    159     }
    160 
    161     assert(size_t(bytes_written) == data.size());
    162     return Status::OK();
    163   }
    164 
    165   Status Close() override {
    166     assert(INVALID_HANDLE_VALUE != hfile_);
    167 
    168     Status result = Flush();
    169     if (!result.ok()) {
    170       return result;
    171     }
    172 
    173     if (FALSE == ::CloseHandle(hfile_)) {
    174       return IOErrorFromWindowsError("CloseHandle failed for: " + filename_,
    175                                      ::GetLastError());
    176     }
    177 
    178     hfile_ = INVALID_HANDLE_VALUE;
    179     return Status::OK();
    180   }
    181 
    182   Status Flush() override {
    183     if (FALSE == ::FlushFileBuffers(hfile_)) {
    184       return IOErrorFromWindowsError(
    185           "FlushFileBuffers failed for: " + filename_, ::GetLastError());
    186     }
    187     return Status::OK();
    188   }
    189 
    190   Status Sync() override { return Flush(); }
    191 };
    192 
    193 class WinReadOnlyMemoryRegion : public ReadOnlyMemoryRegion {
    194  private:
    195   const std::string filename_;
    196   HANDLE hfile_;
    197   HANDLE hmap_;
    198 
    199   const void* const address_;
    200   const uint64 length_;
    201 
    202  public:
    203   WinReadOnlyMemoryRegion(const std::string& filename, HANDLE hfile,
    204                           HANDLE hmap, const void* address, uint64 length)
    205       : filename_(filename),
    206         hfile_(hfile),
    207         hmap_(hmap),
    208         address_(address),
    209         length_(length) {}
    210 
    211   ~WinReadOnlyMemoryRegion() {
    212     BOOL ret = ::UnmapViewOfFile(address_);
    213     assert(ret);
    214 
    215     ret = ::CloseHandle(hmap_);
    216     assert(ret);
    217 
    218     ret = ::CloseHandle(hfile_);
    219     assert(ret);
    220   }
    221 
    222   const void* data() override { return address_; }
    223   uint64 length() override { return length_; }
    224 };
    225 
    226 }  // namespace
    227 
    228 Status WindowsFileSystem::NewRandomAccessFile(
    229     const string& fname, std::unique_ptr<RandomAccessFile>* result) {
    230   string translated_fname = TranslateName(fname);
    231   std::wstring ws_translated_fname = Utf8ToWideChar(translated_fname);
    232   result->reset();
    233 
    234   // Open the file for read-only random access
    235   // Open in async mode which makes Windows allow more parallelism even
    236   // if we need to do sync I/O on top of it.
    237   DWORD file_flags = FILE_ATTRIBUTE_READONLY | FILE_FLAG_OVERLAPPED;
    238   // Shared access is necessary for tests to pass
    239   // almost all tests would work with a possible exception of fault_injection.
    240   DWORD share_mode = FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE;
    241 
    242   HANDLE hfile =
    243       ::CreateFileW(ws_translated_fname.c_str(), GENERIC_READ, share_mode, NULL,
    244                     OPEN_EXISTING, file_flags, NULL);
    245 
    246   if (INVALID_HANDLE_VALUE == hfile) {
    247     string context = "NewRandomAccessFile failed to Create/Open: " + fname;
    248     return IOErrorFromWindowsError(context, ::GetLastError());
    249   }
    250 
    251   result->reset(new WindowsRandomAccessFile(translated_fname, hfile));
    252   return Status::OK();
    253 }
    254 
    255 Status WindowsFileSystem::NewWritableFile(
    256     const string& fname, std::unique_ptr<WritableFile>* result) {
    257   string translated_fname = TranslateName(fname);
    258   std::wstring ws_translated_fname = Utf8ToWideChar(translated_fname);
    259   result->reset();
    260 
    261   DWORD share_mode = FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE;
    262   HANDLE hfile =
    263       ::CreateFileW(ws_translated_fname.c_str(), GENERIC_WRITE, share_mode,
    264                     NULL, CREATE_ALWAYS, FILE_ATTRIBUTE_NORMAL, NULL);
    265 
    266   if (INVALID_HANDLE_VALUE == hfile) {
    267     string context = "Failed to create a NewWriteableFile: " + fname;
    268     return IOErrorFromWindowsError(context, ::GetLastError());
    269   }
    270 
    271   result->reset(new WindowsWritableFile(translated_fname, hfile));
    272   return Status::OK();
    273 }
    274 
    275 Status WindowsFileSystem::NewAppendableFile(
    276     const string& fname, std::unique_ptr<WritableFile>* result) {
    277   string translated_fname = TranslateName(fname);
    278   std::wstring ws_translated_fname = Utf8ToWideChar(translated_fname);
    279   result->reset();
    280 
    281   DWORD share_mode = FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE;
    282   HANDLE hfile =
    283       ::CreateFileW(ws_translated_fname.c_str(), GENERIC_WRITE, share_mode,
    284                     NULL, OPEN_ALWAYS, FILE_ATTRIBUTE_NORMAL, NULL);
    285 
    286   if (INVALID_HANDLE_VALUE == hfile) {
    287     string context = "Failed to create a NewAppendableFile: " + fname;
    288     return IOErrorFromWindowsError(context, ::GetLastError());
    289   }
    290 
    291   UniqueCloseHandlePtr file_guard(hfile, CloseHandleFunc);
    292 
    293   DWORD file_ptr = ::SetFilePointer(hfile, NULL, NULL, FILE_END);
    294   if (INVALID_SET_FILE_POINTER == file_ptr) {
    295     string context = "Failed to create a NewAppendableFile: " + fname;
    296     return IOErrorFromWindowsError(context, ::GetLastError());
    297   }
    298 
    299   result->reset(new WindowsWritableFile(translated_fname, hfile));
    300   file_guard.release();
    301 
    302   return Status::OK();
    303 }
    304 
    305 Status WindowsFileSystem::NewReadOnlyMemoryRegionFromFile(
    306     const string& fname, std::unique_ptr<ReadOnlyMemoryRegion>* result) {
    307   string translated_fname = TranslateName(fname);
    308   std::wstring ws_translated_fname = Utf8ToWideChar(translated_fname);
    309   result->reset();
    310   Status s = Status::OK();
    311 
    312   // Open the file for read-only
    313   DWORD file_flags = FILE_ATTRIBUTE_READONLY;
    314 
    315   // Open in async mode which makes Windows allow more parallelism even
    316   // if we need to do sync I/O on top of it.
    317   file_flags |= FILE_FLAG_OVERLAPPED;
    318 
    319   DWORD share_mode = FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE;
    320   HANDLE hfile =
    321       ::CreateFileW(ws_translated_fname.c_str(), GENERIC_READ, share_mode, NULL,
    322                     OPEN_EXISTING, file_flags, NULL);
    323 
    324   if (INVALID_HANDLE_VALUE == hfile) {
    325     return IOErrorFromWindowsError(
    326         "NewReadOnlyMemoryRegionFromFile failed to Create/Open: " + fname,
    327         ::GetLastError());
    328   }
    329 
    330   UniqueCloseHandlePtr file_guard(hfile, CloseHandleFunc);
    331 
    332   // Use mmap when virtual address-space is plentiful.
    333   uint64_t file_size;
    334   s = GetFileSize(translated_fname, &file_size);
    335   if (s.ok()) {
    336     // Will not map empty files
    337     if (file_size == 0) {
    338       return IOError(
    339           "NewReadOnlyMemoryRegionFromFile failed to map empty file: " + fname,
    340           EINVAL);
    341     }
    342 
    343     HANDLE hmap = ::CreateFileMappingA(hfile, NULL, PAGE_READONLY,
    344                                        0,  // Whole file at its present length
    345                                        0,
    346                                        NULL);  // Mapping name
    347 
    348     if (!hmap) {
    349       string context =
    350           "Failed to create file mapping for "
    351           "NewReadOnlyMemoryRegionFromFile: " +
    352           fname;
    353       return IOErrorFromWindowsError(context, ::GetLastError());
    354     }
    355 
    356     UniqueCloseHandlePtr map_guard(hmap, CloseHandleFunc);
    357 
    358     const void* mapped_region =
    359         ::MapViewOfFileEx(hmap, FILE_MAP_READ,
    360                           0,  // High DWORD of access start
    361                           0,  // Low DWORD
    362                           file_size,
    363                           NULL);  // Let the OS choose the mapping
    364 
    365     if (!mapped_region) {
    366       string context =
    367           "Failed to MapViewOfFile for "
    368           "NewReadOnlyMemoryRegionFromFile: " +
    369           fname;
    370       return IOErrorFromWindowsError(context, ::GetLastError());
    371     }
    372 
    373     result->reset(new WinReadOnlyMemoryRegion(fname, hfile, hmap, mapped_region,
    374                                               file_size));
    375 
    376     map_guard.release();
    377     file_guard.release();
    378   }
    379 
    380   return s;
    381 }
    382 
    383 Status WindowsFileSystem::FileExists(const string& fname) {
    384   constexpr int kOk = 0;
    385   if (_access(TranslateName(fname).c_str(), kOk) == 0) {
    386     return Status::OK();
    387   }
    388   return errors::NotFound(fname, " not found");
    389 }
    390 
    391 Status WindowsFileSystem::GetChildren(const string& dir,
    392                                       std::vector<string>* result) {
    393   string translated_dir = TranslateName(dir);
    394   std::wstring ws_translated_dir = Utf8ToWideChar(translated_dir);
    395   result->clear();
    396 
    397   std::wstring pattern = ws_translated_dir;
    398   if (!pattern.empty() && pattern.back() != '\\' && pattern.back() != '/') {
    399     pattern += L"\\*";
    400   } else {
    401     pattern += L'*';
    402   }
    403 
    404   WIN32_FIND_DATAW find_data;
    405   HANDLE find_handle = ::FindFirstFileW(pattern.c_str(), &find_data);
    406   if (find_handle == INVALID_HANDLE_VALUE) {
    407     string context = "FindFirstFile failed for: " + translated_dir;
    408     return IOErrorFromWindowsError(context, ::GetLastError());
    409   }
    410 
    411   do {
    412     string file_name = WideCharToUtf8(find_data.cFileName);
    413     const StringPiece basename = file_name;
    414     if (basename != "." && basename != "..") {
    415       result->push_back(file_name);
    416     }
    417   } while (::FindNextFileW(find_handle, &find_data));
    418 
    419   if (!::FindClose(find_handle)) {
    420     string context = "FindClose failed for: " + translated_dir;
    421     return IOErrorFromWindowsError(context, ::GetLastError());
    422   }
    423 
    424   return Status::OK();
    425 }
    426 
    427 Status WindowsFileSystem::DeleteFile(const string& fname) {
    428   Status result;
    429   std::wstring file_name = Utf8ToWideChar(fname);
    430   if (_wunlink(file_name.c_str()) != 0) {
    431     result = IOError("Failed to delete a file: " + fname, errno);
    432   }
    433   return result;
    434 }
    435 
    436 Status WindowsFileSystem::CreateDir(const string& name) {
    437   Status result;
    438   std::wstring ws_name = Utf8ToWideChar(name);
    439   if (_wmkdir(ws_name.c_str()) != 0) {
    440     result = IOError("Failed to create a directory: " + name, errno);
    441   }
    442   return result;
    443 }
    444 
    445 Status WindowsFileSystem::DeleteDir(const string& name) {
    446   Status result;
    447   std::wstring ws_name = Utf8ToWideChar(name);
    448   if (_wrmdir(ws_name.c_str()) != 0) {
    449     result = IOError("Failed to remove a directory: " + name, errno);
    450   }
    451   return result;
    452 }
    453 
    454 Status WindowsFileSystem::GetFileSize(const string& fname, uint64* size) {
    455   string translated_fname = TranslateName(fname);
    456   std::wstring ws_translated_dir = Utf8ToWideChar(translated_fname);
    457   Status result;
    458   WIN32_FILE_ATTRIBUTE_DATA attrs;
    459   if (TRUE == ::GetFileAttributesExW(ws_translated_dir.c_str(),
    460                                      GetFileExInfoStandard, &attrs)) {
    461     ULARGE_INTEGER file_size;
    462     file_size.HighPart = attrs.nFileSizeHigh;
    463     file_size.LowPart = attrs.nFileSizeLow;
    464     *size = file_size.QuadPart;
    465   } else {
    466     string context = "Can not get size for: " + fname;
    467     result = IOErrorFromWindowsError(context, ::GetLastError());
    468   }
    469   return result;
    470 }
    471 
    472 Status WindowsFileSystem::RenameFile(const string& src, const string& target) {
    473   Status result;
    474   // rename() is not capable of replacing the existing file as on Linux
    475   // so use OS API directly
    476   std::wstring ws_translated_src = Utf8ToWideChar(TranslateName(src));
    477   std::wstring ws_translated_target = Utf8ToWideChar(TranslateName(target));
    478   if (!::MoveFileExW(ws_translated_src.c_str(), ws_translated_target.c_str(),
    479                      MOVEFILE_REPLACE_EXISTING)) {
    480     string context(strings::StrCat("Failed to rename: ", src, " to: ", target));
    481     result = IOErrorFromWindowsError(context, ::GetLastError());
    482   }
    483   return result;
    484 }
    485 
    486 Status WindowsFileSystem::GetMatchingPaths(const string& pattern,
    487                                            std::vector<string>* results) {
    488   // NOTE(mrry): The existing implementation of FileSystem::GetMatchingPaths()
    489   // does not handle Windows paths containing backslashes correctly. Since
    490   // Windows APIs will accept forward and backslashes equivalently, we
    491   // convert the pattern to use forward slashes exclusively. Note that this
    492   // is not ideal, since the API expects backslash as an escape character,
    493   // but no code appears to rely on this behavior.
    494   string converted_pattern(pattern);
    495   std::replace(converted_pattern.begin(), converted_pattern.end(), '\\', '/');
    496   TF_RETURN_IF_ERROR(FileSystem::GetMatchingPaths(converted_pattern, results));
    497   for (string& result : *results) {
    498     std::replace(result.begin(), result.end(), '/', '\\');
    499   }
    500   return Status::OK();
    501 }
    502 
    503 Status WindowsFileSystem::Stat(const string& fname, FileStatistics* stat) {
    504   Status result;
    505   struct _stat sbuf;
    506   std::wstring ws_translated_fname = Utf8ToWideChar(TranslateName(fname));
    507   if (_wstat(ws_translated_fname.c_str(), &sbuf) != 0) {
    508     result = IOError(fname, errno);
    509   } else {
    510     stat->mtime_nsec = sbuf.st_mtime * 1e9;
    511     stat->length = sbuf.st_size;
    512     stat->is_directory = PathIsDirectoryW(ws_translated_fname.c_str());
    513   }
    514   return result;
    515 }
    516 
    517 }  // namespace tensorflow
    518