Home | History | Annotate | Download | only in windows
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
#include "tensorflow/core/platform/env.h"

#include <Shlwapi.h>
#include <Windows.h>
#include <errno.h>
#include <fcntl.h>
#include <stdio.h>
#include <time.h>
#undef LoadLibrary
#undef ERROR

#include <algorithm>
#include <functional>
#include <memory>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/platform/load_library.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/windows/windows_file_system.h"
     35 
     36 #pragma comment(lib, "Shlwapi.lib")
     37 
     38 namespace tensorflow {
     39 
     40 namespace {
     41 
     42 class StdThread : public Thread {
     43  public:
     44   // name and thread_options are both ignored.
     45   StdThread(const ThreadOptions& thread_options, const string& name,
     46             std::function<void()> fn)
     47       : thread_(fn) {}
     48   ~StdThread() { thread_.join(); }
     49 
     50  private:
     51   std::thread thread_;
     52 };
     53 
     54 class WindowsEnv : public Env {
     55  public:
     56   WindowsEnv() : GetSystemTimePreciseAsFileTime_(NULL) {
     57     // GetSystemTimePreciseAsFileTime function is only available in the latest
     58     // versions of Windows. For that reason, we try to look it up in
     59     // kernel32.dll at runtime and use an alternative option if the function
     60     // is not available.
     61     HMODULE module = GetModuleHandleW(L"kernel32.dll");
     62     if (module != NULL) {
     63       auto func = (FnGetSystemTimePreciseAsFileTime)GetProcAddress(
     64           module, "GetSystemTimePreciseAsFileTime");
     65       GetSystemTimePreciseAsFileTime_ = func;
     66     }
     67   }
     68 
     69   ~WindowsEnv() override {
     70     LOG(FATAL) << "Env::Default() must not be destroyed";
     71   }
     72 
     73   bool MatchPath(const string& path, const string& pattern) override {
     74     std::wstring ws_path(WindowsFileSystem::Utf8ToWideChar(path));
     75     std::wstring ws_pattern(WindowsFileSystem::Utf8ToWideChar(pattern));
     76     return PathMatchSpecW(ws_path.c_str(), ws_pattern.c_str()) == TRUE;
     77   }
     78 
     79   void SleepForMicroseconds(int64 micros) override { Sleep(micros / 1000); }
     80 
     81   Thread* StartThread(const ThreadOptions& thread_options, const string& name,
     82                       std::function<void()> fn) override {
     83     return new StdThread(thread_options, name, fn);
     84   }
     85 
     86   static VOID CALLBACK SchedClosureCallback(PTP_CALLBACK_INSTANCE Instance,
     87                                             PVOID Context, PTP_WORK Work) {
     88     CloseThreadpoolWork(Work);
     89     std::function<void()>* f = (std::function<void()>*)Context;
     90     (*f)();
     91     delete f;
     92   }
     93   void SchedClosure(std::function<void()> closure) override {
     94     PTP_WORK work = CreateThreadpoolWork(
     95         SchedClosureCallback, new std::function<void()>(std::move(closure)),
     96         nullptr);
     97     SubmitThreadpoolWork(work);
     98   }
     99 
    100   static VOID CALLBACK SchedClosureAfterCallback(PTP_CALLBACK_INSTANCE Instance,
    101                                                  PVOID Context,
    102                                                  PTP_TIMER Timer) {
    103     CloseThreadpoolTimer(Timer);
    104     std::function<void()>* f = (std::function<void()>*)Context;
    105     (*f)();
    106     delete f;
    107   }
    108 
    109   void SchedClosureAfter(int64 micros, std::function<void()> closure) override {
    110     PTP_TIMER timer = CreateThreadpoolTimer(
    111         SchedClosureAfterCallback,
    112         new std::function<void()>(std::move(closure)), nullptr);
    113     // in 100 nanosecond units
    114     FILETIME FileDueTime;
    115     ULARGE_INTEGER ulDueTime;
    116     // Negative indicates the amount of time to wait is relative to the current
    117     // time.
    118     ulDueTime.QuadPart = (ULONGLONG) - (10 * micros);
    119     FileDueTime.dwHighDateTime = ulDueTime.HighPart;
    120     FileDueTime.dwLowDateTime = ulDueTime.LowPart;
    121     SetThreadpoolTimer(timer, &FileDueTime, 0, 0);
    122   }
    123 
    124   Status LoadLibrary(const char* library_filename, void** handle) override {
    125     std::string file_name = library_filename;
    126     std::replace(file_name.begin(), file_name.end(), '/', '\\');
    127 
    128     std::wstring ws_file_name(WindowsFileSystem::Utf8ToWideChar(file_name));
    129 
    130     HMODULE hModule = LoadLibraryExW(ws_file_name.c_str(), NULL,
    131                                      LOAD_WITH_ALTERED_SEARCH_PATH);
    132     if (!hModule) {
    133       return errors::NotFound(file_name + " not found");
    134     }
    135     *handle = hModule;
    136     return Status::OK();
    137   }
    138 
    139   Status GetSymbolFromLibrary(void* handle, const char* symbol_name,
    140                               void** symbol) override {
    141     FARPROC found_symbol;
    142 
    143     found_symbol = GetProcAddress((HMODULE)handle, symbol_name);
    144     if (found_symbol == NULL) {
    145       return errors::NotFound(std::string(symbol_name) + " not found");
    146     }
    147     *symbol = (void**)found_symbol;
    148     return Status::OK();
    149   }
    150 
    151   string FormatLibraryFileName(const string& name,
    152                                const string& version) override {
    153     string filename;
    154     if (version.size() == 0) {
    155       filename = name + ".dll";
    156     } else {
    157       filename = name + version + ".dll";
    158     }
    159     return filename;
    160   }
    161 
    162  private:
    163   typedef VOID(WINAPI* FnGetSystemTimePreciseAsFileTime)(LPFILETIME);
    164   FnGetSystemTimePreciseAsFileTime GetSystemTimePreciseAsFileTime_;
    165 };
    166 
    167 }  // namespace
    168 
// Register the Windows file system implementations: WindowsFileSystem under
// the empty scheme (presumably plain paths without a "scheme://" prefix) and
// LocalWinFileSystem under the "file" scheme.
REGISTER_FILE_SYSTEM("", WindowsFileSystem);
REGISTER_FILE_SYSTEM("file", LocalWinFileSystem);
    171 
    172 Env* Env::Default() {
    173   static Env* default_env = new WindowsEnv;
    174   return default_env;
    175 }
    176 
    177 void Env::GetLocalTempDirectories(std::vector<string>* list) {
    178   list->clear();
    179   // On windows we'll try to find a directory in this order:
    180   //   C:/Documents & Settings/whomever/TEMP (or whatever GetTempPath() is)
    181   //   C:/TMP/
    182   //   C:/TEMP/
    183   //   C:/WINDOWS/ or C:/WINNT/
    184   //   .
    185   char tmp[MAX_PATH];
    186   // GetTempPath can fail with either 0 or with a space requirement > bufsize.
    187   // See http://msdn.microsoft.com/en-us/library/aa364992(v=vs.85).aspx
    188   DWORD n = GetTempPathA(MAX_PATH, tmp);
    189   if (n > 0 && n <= MAX_PATH) list->push_back(tmp);
    190   list->push_back("C:\\tmp\\");
    191   list->push_back("C:\\temp\\");
    192 }
    193 
    194 }  // namespace tensorflow
    195