Home | History | Annotate | Download | only in xla
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
#include "tensorflow/compiler/xla/util.h"

#include <stdarg.h>
#include <cmath>
#include <limits>
#include <numeric>

#include "absl/container/inlined_vector.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stacktrace.h"
     32 
     33 namespace xla {
     34 
     35 Status WithLogBacktrace(const Status& status) {
     36   CHECK(!status.ok());
     37   VLOG(1) << status.ToString();
     38   VLOG(1) << tensorflow::CurrentStackTrace();
     39   return status;
     40 }
     41 
     42 ScopedLoggingTimer::ScopedLoggingTimer(const string& label, bool enabled)
     43     : enabled(enabled), label(label) {
     44   if (enabled) {
     45     start_micros = tensorflow::Env::Default()->NowMicros();
     46   }
     47 }
     48 
     49 ScopedLoggingTimer::~ScopedLoggingTimer() {
     50   if (enabled) {
     51     uint64 end_micros = tensorflow::Env::Default()->NowMicros();
     52     double secs = (end_micros - start_micros) / 1000000.0;
     53 
     54     LOG(INFO) << label << " time: "
     55               << tensorflow::strings::HumanReadableElapsedTime(secs);
     56   }
     57 }
     58 
     59 Status AddStatus(Status prior, absl::string_view context) {
     60   CHECK(!prior.ok());
     61   return Status{prior.code(),
     62                 absl::StrCat(context, ": ", prior.error_message())};
     63 }
     64 
     65 Status AppendStatus(Status prior, absl::string_view context) {
     66   CHECK(!prior.ok());
     67   return Status{prior.code(),
     68                 absl::StrCat(prior.error_message(), ": ", context)};
     69 }
     70 
     71 string Reindent(absl::string_view original,
     72                 const absl::string_view indentation) {
     73   std::vector<string> pieces =
     74       absl::StrSplit(absl::string_view(original.data(), original.size()), '\n');
     75   return absl::StrJoin(pieces, "\n", [indentation](string* out, string s) {
     76     absl::StrAppend(out, indentation, absl::StripAsciiWhitespace(s));
     77   });
     78 }
     79 
     80 bool IsPermutation(absl::Span<const int64> permutation, int64 rank) {
     81   if (rank != permutation.size()) {
     82     return false;
     83   }
     84   absl::InlinedVector<int64, 8> trivial_permutation(rank);
     85   absl::c_iota(trivial_permutation, 0);
     86   return absl::c_is_permutation(permutation, trivial_permutation);
     87 }
     88 
     89 std::vector<int64> InversePermutation(
     90     absl::Span<const int64> input_permutation) {
     91   DCHECK(IsPermutation(input_permutation, input_permutation.size()));
     92   std::vector<int64> output_permutation(input_permutation.size(), -1);
     93   for (size_t i = 0; i < input_permutation.size(); ++i) {
     94     output_permutation[input_permutation[i]] = i;
     95   }
     96   return output_permutation;
     97 }
     98 
     99 std::vector<int64> ComposePermutations(absl::Span<const int64> p1,
    100                                        absl::Span<const int64> p2) {
    101   CHECK_EQ(p1.size(), p2.size());
    102   std::vector<int64> output;
    103   for (size_t i = 0; i < p1.size(); ++i) {
    104     output.push_back(p1[p2[i]]);
    105   }
    106   return output;
    107 }
    108 
    109 bool IsIdentityPermutation(absl::Span<const int64> permutation) {
    110   for (int64 i = 0; i < permutation.size(); ++i) {
    111     if (permutation[i] != i) {
    112       return false;
    113     }
    114   }
    115   return true;
    116 }
    117 
    118 PaddingConfig MakeNoPaddingConfig(int64 rank) {
    119   PaddingConfig padding_config;
    120   for (int64 dnum = 0; dnum < rank; ++dnum) {
    121     auto dimension = padding_config.add_dimensions();
    122     dimension->set_edge_padding_low(0);
    123     dimension->set_edge_padding_high(0);
    124     dimension->set_interior_padding(0);
    125   }
    126   return padding_config;
    127 }
    128 
    129 PaddingConfig MakeEdgePaddingConfig(
    130     absl::Span<const std::pair<int64, int64>> padding) {
    131   PaddingConfig padding_config;
    132   for (const std::pair<int64, int64>& dim : padding) {
    133     auto dimension = padding_config.add_dimensions();
    134     dimension->set_edge_padding_low(dim.first);
    135     dimension->set_edge_padding_high(dim.second);
    136     dimension->set_interior_padding(0);
    137   }
    138   return padding_config;
    139 }
    140 
    141 bool HasInteriorPadding(const PaddingConfig& config) {
    142   for (const auto& dim : config.dimensions()) {
    143     if (dim.interior_padding() != 0) {
    144       return true;
    145     }
    146   }
    147   return false;
    148 }
    149 
    150 namespace {
    151 string HumanReadableNumOps(double flops, double nanoseconds,
    152                            absl::string_view op_prefix) {
    153   if (nanoseconds == 0) {
    154     return absl::StrCat("NaN ", op_prefix, "OP/s");
    155   }
    156   double nano_flops = flops / nanoseconds;
    157   string throughput = tensorflow::strings::HumanReadableNum(
    158       static_cast<int64>(nano_flops * 1e9));
    159   absl::string_view sp(throughput);
    160   // Use the more common "G(FLOPS)", rather than "B(FLOPS)"
    161   if (absl::EndsWith(sp, "B") ||  // Ends in 'B', ignoring case
    162       absl::EndsWith(sp, "b")) {
    163     *throughput.rbegin() = 'G';
    164   }
    165   throughput += absl::StrCat(op_prefix, "OP/s");
    166   return throughput;
    167 }
    168 }  // namespace
    169 
    170 string HumanReadableNumFlops(double flops, double nanoseconds) {
    171   return HumanReadableNumOps(flops, nanoseconds, "FL");
    172 }
    173 
    174 string HumanReadableNumTranscendentalOps(double trops, double nanoseconds) {
    175   return HumanReadableNumOps(trops, nanoseconds, "TR");
    176 }
    177 
    178 void LogLines(int sev, absl::string_view text, const char* fname, int lineno) {
    179   const int orig_sev = sev;
    180   if (sev == tensorflow::FATAL) {
    181     sev = tensorflow::ERROR;
    182   }
    183 
    184   // Protect calls with a mutex so we don't interleave calls to LogLines from
    185   // multiple threads.
    186   static tensorflow::mutex log_lines_mu(tensorflow::LINKER_INITIALIZED);
    187   tensorflow::mutex_lock lock(log_lines_mu);
    188 
    189   size_t cur = 0;
    190   while (cur < text.size()) {
    191     size_t eol = text.find('\n', cur);
    192     if (eol == absl::string_view::npos) {
    193       eol = text.size();
    194     }
    195     auto msg = text.substr(cur, eol - cur);
    196     tensorflow::internal::LogString(fname, lineno, sev,
    197                                     string(msg.data(), msg.size()));
    198     cur = eol + 1;
    199   }
    200 
    201   if (orig_sev == tensorflow::FATAL) {
    202     tensorflow::internal::LogString(fname, lineno, orig_sev,
    203                                     "Aborting due to errors.");
    204   }
    205 }
    206 
    207 int64 Product(absl::Span<const int64> xs) {
    208   return std::accumulate(xs.begin(), xs.end(), static_cast<int64>(1),
    209                          std::multiplies<int64>());
    210 }
    211 
    212 std::vector<std::pair<int64, int64>> CommonFactors(absl::Span<const int64> a,
    213                                                    absl::Span<const int64> b) {
    214   CHECK_EQ(Product(a), Product(b));
    215   if (0 == Product(a)) {
    216     return {std::make_pair(0, 0), std::make_pair(a.size(), b.size())};
    217   }
    218 
    219   std::vector<std::pair<int64, int64>> bounds;
    220   for (int64 i = 0, j = 0, prior_i = -1, prior_j = -1, partial_size_a = 1,
    221              partial_size_b = 1;
    222        ;) {
    223     if (partial_size_a == partial_size_b && (i > prior_i || j > prior_j)) {
    224       std::tie(prior_i, prior_j) = std::make_pair(i, j);
    225       bounds.emplace_back(i, j);
    226       continue;
    227     }
    228     bool in_bounds_i = i < a.size();
    229     bool in_bounds_j = j < b.size();
    230     if (!(in_bounds_i || in_bounds_j)) {
    231       break;
    232     }
    233     bool next_a =
    234         partial_size_a < partial_size_b ||
    235         (in_bounds_i &&
    236          (!in_bounds_j || (partial_size_a == partial_size_b && a[i] <= b[j])));
    237     bool next_b =
    238         partial_size_b < partial_size_a ||
    239         (in_bounds_j &&
    240          (!in_bounds_i || (partial_size_b == partial_size_a && b[j] <= a[i])));
    241     if (next_a) {
    242       partial_size_a *= a[i];
    243       ++i;
    244     }
    245     if (next_b) {
    246       partial_size_b *= b[j];
    247       ++j;
    248     }
    249   }
    250   return bounds;
    251 }
    252 
    253 string SanitizeFileName(string file_name) {
    254   for (char& c : file_name) {
    255     if (c == '/' || c == '\\' || c == '[' || c == ']' || c == ' ') {
    256       c = '_';
    257     }
    258   }
    259   return file_name;
    260 }
    261 
    262 }  // namespace xla
    263