/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/util.h"

#include <stdarg.h>

#include <numeric>

#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stacktrace.h"

namespace xla {

Status WithLogBacktrace(const Status& status) {
  CHECK(!status.ok());
  VLOG(1) << status.ToString();
  VLOG(1) << tensorflow::CurrentStackTrace();
  return status;
}

ScopedLoggingTimer::ScopedLoggingTimer(const string& label, bool enabled)
    : enabled(enabled), label(label) {
  if (enabled) {
    start_micros = tensorflow::Env::Default()->NowMicros();
  }
}

ScopedLoggingTimer::~ScopedLoggingTimer() {
  if (enabled) {
    uint64 end_micros = tensorflow::Env::Default()->NowMicros();
    double secs = (end_micros - start_micros) / 1000000.0;

    LOG(INFO) << label << " time: "
              << tensorflow::strings::HumanReadableElapsedTime(secs);
  }
}

Status AddStatus(Status prior, tensorflow::StringPiece context) {
  CHECK(!prior.ok());
  return Status{prior.code(), tensorflow::strings::StrCat(
                                  context, ": ", prior.error_message())};
}

Status AppendStatus(Status prior, tensorflow::StringPiece context) {
  CHECK(!prior.ok());
  return Status{prior.code(), tensorflow::strings::StrCat(prior.error_message(),
                                                          ": ", context)};
}

// Implementation note: these factory functions cannot be factored into a
// shared helper (without resorting to macros) because each one must
// va_start/va_end its varargs within its own stack frame.

Status InvalidArgumentV(const char* format, va_list args) {
  string message;
  tensorflow::strings::Appendv(&message, format, args);
  return WithLogBacktrace(tensorflow::errors::InvalidArgument(message));
}

Status InvalidArgument(const char* format, ...) {
  va_list args;
  va_start(args, format);
  Status result = InvalidArgumentV(format, args);
  va_end(args);
  return result;
}

Status Unimplemented(const char* format, ...) {
  string message;
  va_list args;
  va_start(args, format);
  tensorflow::strings::Appendv(&message, format, args);
  va_end(args);
  return WithLogBacktrace(tensorflow::errors::Unimplemented(message));
}

Status InternalError(const char* format, ...) {
  string message;
  va_list args;
  va_start(args, format);
  tensorflow::strings::Appendv(&message, format, args);
  va_end(args);
  return WithLogBacktrace(tensorflow::errors::Internal(message));
}

Status FailedPrecondition(const char* format, ...) {
  string message;
  va_list args;
  va_start(args, format);
  tensorflow::strings::Appendv(&message, format, args);
  va_end(args);
  return WithLogBacktrace(tensorflow::errors::FailedPrecondition(message));
}

Status Cancelled(const char* format, ...) {
  string message;
  va_list args;
  va_start(args, format);
  tensorflow::strings::Appendv(&message, format, args);
  va_end(args);
  return WithLogBacktrace(tensorflow::errors::Cancelled(message));
}

Status ResourceExhausted(const char* format, ...) {
  string message;
  va_list args;
  va_start(args, format);
  tensorflow::strings::Appendv(&message, format, args);
  va_end(args);
  return WithLogBacktrace(tensorflow::errors::ResourceExhausted(message));
}

Status NotFound(const char* format, ...) {
  string message;
  va_list args;
  va_start(args, format);
  tensorflow::strings::Appendv(&message, format, args);
  va_end(args);
  return WithLogBacktrace(tensorflow::errors::NotFound(message));
}

Status Unavailable(const char* format, ...) {
  string message;
  va_list args;
  va_start(args, format);
  tensorflow::strings::Appendv(&message, format, args);
  va_end(args);
  return WithLogBacktrace(tensorflow::errors::Unavailable(message));
}

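// Splits `original` on newlines, strips the surrounding whitespace from each
// line, and rejoins the lines with `indentation` prepended to each one.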
string Reindent(tensorflow::StringPiece original,
                const tensorflow::StringPiece indentation) {
  std::vector<string> pieces = tensorflow::str_util::Split(
      tensorflow::StringPiece(original.data(), original.size()), '\n');
  return tensorflow::str_util::Join(
      pieces, "\n", [indentation](string* out, string s) {
        tensorflow::StringPiece piece(s);
        tensorflow::str_util::RemoveWhitespaceContext(&piece);
        tensorflow::strings::StrAppend(out, indentation, piece);
      });
}

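// Note that out-of-range entries are treated as a programming error: each
// index is CHECK'd to lie in [0, rank), so an invalid index crashes rather
// than returning false. A repeated index leaves some output slot at -1 and
// yields false, e.g. IsPermutation({1, 1, 0}, 3) == false.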
bool IsPermutation(tensorflow::gtl::ArraySlice<int64> permutation, int64 rank) {
  if (rank != permutation.size()) {
    return false;
  }
  std::vector<int64> output(permutation.size(), -1);
  for (auto index : permutation) {
    CHECK_GE(index, 0);
    CHECK_LT(index, rank);
    output[index] = 0;
  }
  return std::find(output.begin(), output.end(), -1) == output.end();
}

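// The inverse permutation maps each value back to its position, e.g.
// InversePermutation({2, 0, 1}) == {1, 2, 0}.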
std::vector<int64> InversePermutation(
    tensorflow::gtl::ArraySlice<int64> input_permutation) {
  DCHECK(IsPermutation(input_permutation, input_permutation.size()));
  std::vector<int64> output_permutation(input_permutation.size(), -1);
  for (size_t i = 0; i < input_permutation.size(); ++i) {
    output_permutation[input_permutation[i]] = i;
  }
  return output_permutation;
}

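// The composition convention is output[i] = p1[p2[i]]; composing a
// permutation with its inverse therefore yields the identity, e.g.
// ComposePermutations({1, 2, 0}, {2, 0, 1}) == {0, 1, 2}.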
std::vector<int64> ComposePermutations(tensorflow::gtl::ArraySlice<int64> p1,
                                       tensorflow::gtl::ArraySlice<int64> p2) {
  CHECK_EQ(p1.size(), p2.size());
  std::vector<int64> output;
  for (size_t i = 0; i < p1.size(); ++i) {
    output.push_back(p1[p2[i]]);
  }
  return output;
}

bool IsIdentityPermutation(tensorflow::gtl::ArraySlice<int64> permutation) {
  for (int64 i = 0; i < permutation.size(); ++i) {
    if (permutation[i] != i) {
      return false;
    }
  }
  return true;
}

PaddingConfig MakeNoPaddingConfig(int64 rank) {
  PaddingConfig padding_config;
  for (int64 dnum = 0; dnum < rank; ++dnum) {
    auto dimension = padding_config.add_dimensions();
    dimension->set_edge_padding_low(0);
    dimension->set_edge_padding_high(0);
    dimension->set_interior_padding(0);
  }
  return padding_config;
}

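// Each (low, high) pair pads one dimension on its leading and trailing edges;
// interior padding is always zero here, e.g. MakeEdgePaddingConfig({{1, 2}})
// pads the single dimension with 1 element low and 2 elements high.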
PaddingConfig MakeEdgePaddingConfig(
    tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
  PaddingConfig padding_config;
  for (const std::pair<int64, int64>& dim : padding) {
    auto dimension = padding_config.add_dimensions();
    dimension->set_edge_padding_low(dim.first);
    dimension->set_edge_padding_high(dim.second);
    dimension->set_interior_padding(0);
  }
  return padding_config;
}

bool HasInteriorPadding(const PaddingConfig& config) {
  for (const auto& dim : config.dimensions()) {
    if (dim.interior_padding() != 0) {
      return true;
    }
  }
  return false;
}

namespace {
string HumanReadableNumOps(double flops, double nanoseconds,
                           tensorflow::StringPiece op_prefix) {
  if (nanoseconds == 0) {
    return tensorflow::strings::StrCat("NaN ", op_prefix, "OP/s");
  }
  double nano_flops = flops / nanoseconds;
  string throughput = tensorflow::strings::HumanReadableNum(
      static_cast<int64>(nano_flops * 1e9));
  tensorflow::StringPiece sp(throughput);
  // Use the more common "G(FLOPS)", rather than "B(FLOPS)"
  if (sp.ends_with("B") ||  // Ends in 'B', ignoring case
      sp.ends_with("b")) {
    *throughput.rbegin() = 'G';
  }
  throughput += tensorflow::strings::StrCat(op_prefix, "OP/s");
  return throughput;
}
}  // namespace

string HumanReadableNumFlops(double flops, double nanoseconds) {
  return HumanReadableNumOps(flops, nanoseconds, "FL");
}

string HumanReadableNumTranscendentalOps(double trops, double nanoseconds) {
  return HumanReadableNumOps(trops, nanoseconds, "TR");
}

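// Logs `text` line by line at severity `sev`. FATAL is downgraded to ERROR
// for the individual lines so that every line is emitted before the final
// FATAL message aborts the process.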
void LogLines(int sev, tensorflow::StringPiece text, const char* fname,
              int lineno) {
  const int orig_sev = sev;
  if (sev == tensorflow::FATAL) {
    sev = tensorflow::ERROR;
  }

  // Protect calls with a mutex so we don't interleave calls to LogLines from
  // multiple threads.
  static tensorflow::mutex log_lines_mu(tensorflow::LINKER_INITIALIZED);
  tensorflow::mutex_lock lock(log_lines_mu);

  size_t cur = 0;
  while (cur < text.size()) {
    size_t eol = text.find('\n', cur);
    if (eol == tensorflow::StringPiece::npos) {
      eol = text.size();
    }
    auto msg = text.substr(cur, eol - cur);
    tensorflow::internal::LogString(fname, lineno, sev,
                                    string(msg.data(), msg.size()));
    cur = eol + 1;
  }

  if (orig_sev == tensorflow::FATAL) {
    tensorflow::internal::LogString(fname, lineno, orig_sev,
                                    "Aborting due to errors.");
  }
}

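// An empty slice yields the multiplicative identity: Product({}) == 1.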
int64 Product(tensorflow::gtl::ArraySlice<int64> xs) {
  // Accumulate in int64 (not int) so large products do not get truncated.
  return std::accumulate(xs.begin(), xs.end(), int64{1},
                         std::multiplies<int64>());
}

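// Worked example: for a = {2, 3, 4} and b = {6, 4} the returned bounds are
// {{0, 0}, {2, 1}, {3, 2}}, i.e. a[0:2] (2*3 = 6) lines up with b[0:1] (6)
// and a[2:3] (4) lines up with b[1:2] (4).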
std::vector<std::pair<int64, int64>> CommonFactors(
    tensorflow::gtl::ArraySlice<int64> a,
    tensorflow::gtl::ArraySlice<int64> b) {
  CHECK_EQ(Product(a), Product(b));
  if (0 == Product(a)) {
    return {std::make_pair(0, 0), std::make_pair(a.size(), b.size())};
  }

  std::vector<std::pair<int64, int64>> bounds;
  for (int64 i = 0, j = 0, prior_i = -1, prior_j = -1, partial_size_a = 1,
             partial_size_b = 1;
       ;) {
    if (partial_size_a == partial_size_b && (i > prior_i || j > prior_j)) {
      std::tie(prior_i, prior_j) = std::make_pair(i, j);
      bounds.emplace_back(i, j);
      continue;
    }
    bool in_bounds_i = i < a.size();
    bool in_bounds_j = j < b.size();
    if (!(in_bounds_i || in_bounds_j)) {
      break;
    }
    bool next_a =
        partial_size_a < partial_size_b ||
        (in_bounds_i &&
         (!in_bounds_j || (partial_size_a == partial_size_b && a[i] <= b[j])));
    bool next_b =
        partial_size_b < partial_size_a ||
        (in_bounds_j &&
         (!in_bounds_i || (partial_size_b == partial_size_a && b[j] <= a[i])));
    if (next_a) {
      partial_size_a *= a[i];
      ++i;
    }
    if (next_b) {
      partial_size_b *= b[j];
      ++j;
    }
  }
  return bounds;
}

string SanitizeFileName(string file_name) {
  for (char& c : file_name) {
    if (c == '/' || c == '\\' || c == '[' || c == ']' || c == ' ') {
      c = '_';
    }
  }
  return file_name;
}

}  // namespace xla