1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/platform/env.h" 17 18 #include <Shlwapi.h> 19 #include <Windows.h> 20 #include <errno.h> 21 #include <fcntl.h> 22 #include <stdio.h> 23 #include <time.h> 24 #undef LoadLibrary 25 #undef ERROR 26 27 #include <string> 28 #include <thread> 29 #include <vector> 30 31 #include "tensorflow/core/lib/core/error_codes.pb.h" 32 #include "tensorflow/core/platform/load_library.h" 33 #include "tensorflow/core/platform/logging.h" 34 #include "tensorflow/core/platform/windows/windows_file_system.h" 35 36 #pragma comment(lib, "Shlwapi.lib") 37 38 namespace tensorflow { 39 40 namespace { 41 42 class StdThread : public Thread { 43 public: 44 // name and thread_options are both ignored. 45 StdThread(const ThreadOptions& thread_options, const string& name, 46 std::function<void()> fn) 47 : thread_(fn) {} 48 ~StdThread() { thread_.join(); } 49 50 private: 51 std::thread thread_; 52 }; 53 54 class WindowsEnv : public Env { 55 public: 56 WindowsEnv() : GetSystemTimePreciseAsFileTime_(NULL) { 57 // GetSystemTimePreciseAsFileTime function is only available in the latest 58 // versions of Windows. For that reason, we try to look it up in 59 // kernel32.dll at runtime and use an alternative option if the function 60 // is not available. 61 HMODULE module = GetModuleHandleW(L"kernel32.dll"); 62 if (module != NULL) { 63 auto func = (FnGetSystemTimePreciseAsFileTime)GetProcAddress( 64 module, "GetSystemTimePreciseAsFileTime"); 65 GetSystemTimePreciseAsFileTime_ = func; 66 } 67 } 68 69 ~WindowsEnv() override { 70 LOG(FATAL) << "Env::Default() must not be destroyed"; 71 } 72 73 bool MatchPath(const string& path, const string& pattern) override { 74 std::wstring ws_path(WindowsFileSystem::Utf8ToWideChar(path)); 75 std::wstring ws_pattern(WindowsFileSystem::Utf8ToWideChar(pattern)); 76 return PathMatchSpecW(ws_path.c_str(), ws_pattern.c_str()) == TRUE; 77 } 78 79 void SleepForMicroseconds(int64 micros) override { Sleep(micros / 1000); } 80 81 Thread* StartThread(const ThreadOptions& thread_options, const string& name, 82 std::function<void()> fn) override { 83 return new StdThread(thread_options, name, fn); 84 } 85 86 static VOID CALLBACK SchedClosureCallback(PTP_CALLBACK_INSTANCE Instance, 87 PVOID Context, PTP_WORK Work) { 88 CloseThreadpoolWork(Work); 89 std::function<void()>* f = (std::function<void()>*)Context; 90 (*f)(); 91 delete f; 92 } 93 void SchedClosure(std::function<void()> closure) override { 94 PTP_WORK work = CreateThreadpoolWork( 95 SchedClosureCallback, new std::function<void()>(std::move(closure)), 96 nullptr); 97 SubmitThreadpoolWork(work); 98 } 99 100 static VOID CALLBACK SchedClosureAfterCallback(PTP_CALLBACK_INSTANCE Instance, 101 PVOID Context, 102 PTP_TIMER Timer) { 103 CloseThreadpoolTimer(Timer); 104 std::function<void()>* f = (std::function<void()>*)Context; 105 (*f)(); 106 delete f; 107 } 108 109 void SchedClosureAfter(int64 micros, std::function<void()> closure) override { 110 PTP_TIMER timer = CreateThreadpoolTimer( 111 SchedClosureAfterCallback, 112 new std::function<void()>(std::move(closure)), nullptr); 113 // in 100 nanosecond units 114 FILETIME FileDueTime; 115 ULARGE_INTEGER ulDueTime; 116 // Negative indicates the amount of time to wait is relative to the current 117 // time. 118 ulDueTime.QuadPart = (ULONGLONG) - (10 * micros); 119 FileDueTime.dwHighDateTime = ulDueTime.HighPart; 120 FileDueTime.dwLowDateTime = ulDueTime.LowPart; 121 SetThreadpoolTimer(timer, &FileDueTime, 0, 0); 122 } 123 124 Status LoadLibrary(const char* library_filename, void** handle) override { 125 std::string file_name = library_filename; 126 std::replace(file_name.begin(), file_name.end(), '/', '\\'); 127 128 std::wstring ws_file_name(WindowsFileSystem::Utf8ToWideChar(file_name)); 129 130 HMODULE hModule = LoadLibraryExW(ws_file_name.c_str(), NULL, 131 LOAD_WITH_ALTERED_SEARCH_PATH); 132 if (!hModule) { 133 return errors::NotFound(file_name + " not found"); 134 } 135 *handle = hModule; 136 return Status::OK(); 137 } 138 139 Status GetSymbolFromLibrary(void* handle, const char* symbol_name, 140 void** symbol) override { 141 FARPROC found_symbol; 142 143 found_symbol = GetProcAddress((HMODULE)handle, symbol_name); 144 if (found_symbol == NULL) { 145 return errors::NotFound(std::string(symbol_name) + " not found"); 146 } 147 *symbol = (void**)found_symbol; 148 return Status::OK(); 149 } 150 151 string FormatLibraryFileName(const string& name, 152 const string& version) override { 153 string filename; 154 if (version.size() == 0) { 155 filename = name + ".dll"; 156 } else { 157 filename = name + version + ".dll"; 158 } 159 return filename; 160 } 161 162 private: 163 typedef VOID(WINAPI* FnGetSystemTimePreciseAsFileTime)(LPFILETIME); 164 FnGetSystemTimePreciseAsFileTime GetSystemTimePreciseAsFileTime_; 165 }; 166 167 } // namespace 168 169 REGISTER_FILE_SYSTEM("", WindowsFileSystem); 170 REGISTER_FILE_SYSTEM("file", LocalWinFileSystem); 171 172 Env* Env::Default() { 173 static Env* default_env = new WindowsEnv; 174 return default_env; 175 } 176 177 void Env::GetLocalTempDirectories(std::vector<string>* list) { 178 list->clear(); 179 // On windows we'll try to find a directory in this order: 180 // C:/Documents & Settings/whomever/TEMP (or whatever GetTempPath() is) 181 // C:/TMP/ 182 // C:/TEMP/ 183 // C:/WINDOWS/ or C:/WINNT/ 184 // . 185 char tmp[MAX_PATH]; 186 // GetTempPath can fail with either 0 or with a space requirement > bufsize. 187 // See http://msdn.microsoft.com/en-us/library/aa364992(v=vs.85).aspx 188 DWORD n = GetTempPathA(MAX_PATH, tmp); 189 if (n > 0 && n <= MAX_PATH) list->push_back(tmp); 190 list->push_back("C:\\tmp\\"); 191 list->push_back("C:\\temp\\"); 192 } 193 194 } // namespace tensorflow 195