Home | History | Annotate | Download | only in xla
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_COMPILER_XLA_ARRAY_H_
     17 #define TENSORFLOW_COMPILER_XLA_ARRAY_H_
     18 
     19 #include <algorithm>
     20 #include <array>
     21 #include <functional>
     22 #include <initializer_list>
     23 #include <iterator>
     24 #include <memory>
     25 #include <numeric>
     26 #include <random>
     27 #include <type_traits>
     28 #include <vector>
     29 
     30 #include "absl/strings/str_cat.h"
     31 #include "absl/strings/str_join.h"
     32 #include "absl/types/span.h"
     33 #include "tensorflow/compiler/xla/status.h"
     34 #include "tensorflow/compiler/xla/types.h"
     35 #include "tensorflow/core/lib/core/bits.h"
     36 #include "tensorflow/core/platform/logging.h"
     37 #include "tensorflow/core/platform/macros.h"
     38 #include "tensorflow/core/platform/types.h"
     39 
     40 namespace xla {
     41 
     42 namespace array_impl {
     43 
     44 // conjunction
     45 //
     46 // Performs a compile-time logical AND operation on the passed types (which
     47 // must have  `::value` members convertible to `bool`. Short-circuits if it
     48 // encounters any `false` members (and does not compare the `::value` members
     49 // of any remaining arguments).
     50 //
     51 // This metafunction is designed to be a drop-in replacement for the C++17
     52 // `std::conjunction` metafunction.
     53 template <typename... Ts>
     54 struct conjunction;
     55 
     56 template <typename T, typename... Ts>
     57 struct conjunction<T, Ts...>
     58     : std::conditional<T::value, conjunction<Ts...>, T>::type {};
     59 
     60 template <>
     61 struct conjunction<> : std::true_type {};
     62 
     63 // A type trait that is valid when all elements in a parameter pack are of
     64 // integral type.
     65 template <typename... T>
     66 using pack_is_integral = conjunction<std::is_integral<T>...>;
     67 
     68 // Compares three same-sized vectors elementwise. For each item in `values`,
     69 // returns false if any of values[i] is outside the half-open range [starts[i],
     70 // ends[i]).
     71 template <typename C1, typename C2, typename C3>
     72 bool all_inside_range(const C1& values, const C2& range_starts,
     73                       const C3& range_ends) {
     74   for (size_t i = 0, e = values.size(); i < e; ++i) {
     75     if (values[i] < range_starts[i] || values[i] >= range_ends[i]) {
     76       return false;
     77     }
     78   }
     79   return true;
     80 }
     81 
     82 }  // namespace array_impl
     83 
     84 // General N dimensional array class with arbitrary value type.
     85 template <typename T>
     86 class Array {
     87  public:
     88   // Type inference can have a hard time parsing very deep initializer list
     89   // nests, especially if one or more dimensions is one as the compiler just
     90   // sees a single-element integer initializer. These typedefs allow casting
     91   // explicitly with less typing.
     92   using InitializerList1D = std::initializer_list<T>;
     93   using InitializerList2D = std::initializer_list<InitializerList1D>;
     94   using InitializerList3D = std::initializer_list<InitializerList2D>;
     95   using InitializerList4D = std::initializer_list<InitializerList3D>;
     96 
     97   using value_type = T;
     98 
     99   // Creates a new array with the specified dimensions.
    100   explicit Array(absl::Span<const int64> sizes) : Array(sizes, T()) {}
    101 
    102   // Creates a new array with the specified dimensions and specified value for
    103   // every cell.
    104   Array(absl::Span<const int64> sizes, T value)
    105       : sizes_(sizes.begin(), sizes.end()), values_(new T[num_elements()]) {
    106     Fill(value);
    107   }
    108 
    109   // Creates a 2D array from the given nested initializer list. The outer
    110   // initializer list is the first dimension, the inner is the second dimension.
    111   // For example, {{1, 2, 3}, {4, 5, 6}} results in an array with n1=2 and n2=3.
    112   Array(InitializerList2D values)
    113       : Array(ToInt64Vector({values.size(), values.begin()->size()})) {
    114     int64 idx = 0;
    115     for (const auto& it1 : values) {
    116       for (const auto& it2 : it1) {
    117         values_[idx] = it2;
    118         ++idx;
    119       }
    120     }
    121     CHECK(idx == num_elements());
    122   }
    123 
    124   // Creates a 1D array of a floating-point type (half, bfloat16, float,
    125   // or double) from an initializer list of float values.
    126   template <typename T2, typename = typename std::enable_if<
    127                              (std::is_same<T, Eigen::half>::value ||
    128                               std::is_same<T, bfloat16>::value ||
    129                               std::is_same<T, float>::value ||
    130                               std::is_same<T, double>::value) &&
    131                              std::is_same<T2, float>::value>::type>
    132   Array(std::initializer_list<T2> values)
    133       : Array(ToInt64Vector({values.size()})) {
    134     int64 idx = 0;
    135     for (const auto& it1 : values) {
    136       values_[idx] = static_cast<T>(it1);
    137       ++idx;
    138     }
    139     CHECK(idx == num_elements());
    140   }
    141 
    142   // Creates a 2D array of a floating-point type (half, bfloat16, float,
    143   // or double) from an initializer list of float values.
    144   template <typename T2, typename = typename std::enable_if<
    145                              (std::is_same<T, Eigen::half>::value ||
    146                               std::is_same<T, bfloat16>::value ||
    147                               std::is_same<T, float>::value ||
    148                               std::is_same<T, double>::value) &&
    149                              std::is_same<T2, float>::value>::type>
    150   Array(std::initializer_list<std::initializer_list<T2>> values)
    151       : Array(ToInt64Vector({values.size(), values.begin()->size()})) {
    152     int64 idx = 0;
    153     for (const auto& it1 : values) {
    154       for (const auto& it2 : it1) {
    155         values_[idx] = static_cast<T>(it2);
    156         ++idx;
    157       }
    158     }
    159     CHECK(idx == num_elements());
    160   }
    161 
    162   // Creates a 3D array from the given nested initializer list. The outer
    163   // initializer list is the first dimension, and so on.
    164   Array(InitializerList3D values)
    165       : Array(ToInt64Vector({values.size(), values.begin()->size(),
    166                              values.begin()->begin()->size()})) {
    167     int64 idx = 0;
    168     for (const auto& it1 : values) {
    169       for (const auto& it2 : it1) {
    170         for (const auto& it3 : it2) {
    171           values_[idx] = it3;
    172           ++idx;
    173         }
    174       }
    175     }
    176     CHECK(idx == num_elements());
    177   }
    178 
    179   // Creates a 3D array of a floating-point type (half, bfloat16, float,
    180   // or double) from an initializer list of float values.
    181   template <typename T2, typename = typename std::enable_if<
    182                              (std::is_same<T, Eigen::half>::value ||
    183                               std::is_same<T, bfloat16>::value ||
    184                               std::is_same<T, float>::value ||
    185                               std::is_same<T, double>::value) &&
    186                              std::is_same<T2, float>::value>::type>
    187   Array(std::initializer_list<std::initializer_list<std::initializer_list<T2>>>
    188             values)
    189       : Array(ToInt64Vector({values.size(), values.begin()->size(),
    190                              values.begin()->begin()->size()})) {
    191     int64 idx = 0;
    192     for (const auto& it1 : values) {
    193       for (const auto& it2 : it1) {
    194         for (const auto& it3 : it2) {
    195           values_[idx] = static_cast<T>(it3);
    196           ++idx;
    197         }
    198       }
    199     }
    200     CHECK(idx == num_elements());
    201   }
    202 
    203   // Creates a 4D array from the given nested initializer list. The outer
    204   // initializer list is the first dimension, and so on.
    205   Array(InitializerList4D values)
    206       : Array(ToInt64Vector({values.size(), values.begin()->size(),
    207                              values.begin()->begin()->size(),
    208                              values.begin()->begin()->begin()->size()})) {
    209     int64 idx = 0;
    210     for (const auto& it1 : values) {
    211       for (const auto& it2 : it1) {
    212         for (const auto& it3 : it2) {
    213           for (const auto& it4 : it3) {
    214             values_[idx] = it4;
    215             ++idx;
    216           }
    217         }
    218       }
    219     }
    220     CHECK(idx == num_elements());
    221   }
    222 
    223   // Creates a 4D array of a floating-point type (half, bfloat16, float,
    224   // or double) from an initializer list of float values.
    225   template <typename T2, typename = typename std::enable_if<
    226                              (std::is_same<T, Eigen::half>::value ||
    227                               std::is_same<T, bfloat16>::value ||
    228                               std::is_same<T, float>::value ||
    229                               std::is_same<T, double>::value) &&
    230                              std::is_same<T2, float>::value>::type>
    231   Array(std::initializer_list<
    232         std::initializer_list<std::initializer_list<std::initializer_list<T2>>>>
    233             values)
    234       : Array(ToInt64Vector({values.size(), values.begin()->size(),
    235                              values.begin()->begin()->size(),
    236                              values.begin()->begin()->begin()->size()})) {
    237     int64 idx = 0;
    238     for (const auto& it1 : values) {
    239       for (const auto& it2 : it1) {
    240         for (const auto& it3 : it2) {
    241           for (const auto& it4 : it3) {
    242             values_[idx] = static_cast<T>(it4);
    243             ++idx;
    244           }
    245         }
    246       }
    247     }
    248     CHECK(idx == num_elements());
    249   }
    250 
    251   Array(const Array<T>& other)
    252       : sizes_(other.sizes_), values_(new T[num_elements()]) {
    253     std::copy(&other.values_[0], &other.values_[0] + num_elements(),
    254               &values_[0]);
    255   }
    256 
    257   Array<T>& operator=(const Array<T>& other) {
    258     sizes_ = other.sizes_;
    259     values_.reset(new T[num_elements()]);
    260     std::copy(&other.values_[0], &other.values_[0] + num_elements(),
    261               &values_[0]);
    262     return *this;
    263   }
    264 
    265   // Fills the array with the specified value.
    266   void Fill(const T& value) {
    267     std::fill(&values_[0], &values_[0] + num_elements(), value);
    268   }
    269 
    270   // Fills the array with sequentially increasing values.
    271   void FillIota(const T& value) {
    272     std::iota(&values_[0], &values_[0] + num_elements(), value);
    273   }
    274 
    275   // Fills the array with a repeating sequence:
    276   //   [value, value + 1, ..., value + length - 1, value, ... ]
    277   void FillRepeatedIota(const T& value, int64 length) {
    278     for (int64 i = 0; i < num_elements(); i += length) {
    279       std::iota(&values_[i], &values_[std::min(i + length, num_elements())],
    280                 value);
    281     }
    282   }
    283 
    284   // Fills the array with the sequence i*multiplier for i=0,1,...
    285   void FillWithMultiples(const T& multiplier) {
    286     for (int64 i = 0; i < num_elements(); ++i) {
    287       values_[i] = static_cast<T>(i) * multiplier;
    288     }
    289   }
    290 
    291   // Fills the array with random normal variables with the specified mean.
    292   void FillRandom(const T& stddev, const double mean = 0.0,
    293                   const int seed = 12345) {
    294     std::mt19937 g(seed);
    295     std::normal_distribution<double> distribution(mean,
    296                                                   static_cast<double>(stddev));
    297     for (int64 i = 0; i < num_elements(); ++i) {
    298       values_[i] = static_cast<T>(distribution(g));
    299     }
    300   }
    301 
    302   // Sets all the values in the array to values specified in the container.
    303   template <typename Container = std::initializer_list<T>>
    304   void SetValues(const Container& container) {
    305     CHECK_EQ(std::distance(std::begin(container), std::end(container)),
    306              num_elements());
    307     std::copy(std::begin(container), std::end(container), &values_[0]);
    308   }
    309 
    310   // Invokes a callback with the (indices, value_ptr) for each cell in the
    311   // array.
    312   void Each(std::function<void(absl::Span<const int64>, T*)> f) {
    313     std::vector<int64> index(sizes_.size());
    314     for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
    315       f(index, &values_[i]);
    316     }
    317   }
    318 
    319   // Invokes a callback with the (indices, value) for each cell in the array.
    320   void Each(std::function<void(absl::Span<const int64>, T)> f) const {
    321     std::vector<int64> index(sizes_.size());
    322     for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
    323       f(index, values_[i]);
    324     }
    325   }
    326 
    327   // Invokes a callback with the (indices, value_ptr) for each cell in the
    328   // array. If a callback returns a non-OK status, returns that else returns
    329   // Status::OK().
    330   Status EachStatus(std::function<Status(absl::Span<const int64>, T*)> f) {
    331     std::vector<int64> index(sizes_.size());
    332     for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
    333       Status s = f(index, &values_[i]);
    334       if (!s.ok()) {
    335         return s;
    336       }
    337     }
    338     return Status::OK();
    339   }
    340 
    341   // Invokes a callback with the (indices, value) for each cell in the array.
    342   // If a callback returns a non-OK status, returns that else returns
    343   // Status::OK().
    344   Status EachStatus(std::function<Status(absl::Span<const int64>, T)> f) const {
    345     std::vector<int64> index(sizes_.size());
    346     for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
    347       Status s = f(index, values_[i]);
    348       if (!s.ok()) {
    349         return s;
    350       }
    351     }
    352     return Status::OK();
    353   }
    354 
    355   // Returns the value at the cell specified by the indexes. The number of
    356   // arguments have to match with the number of dimensions for the array.
    357   //
    358   // The type trait is required to avoid this overload participating too
    359   // eagerly; a parameter pack can take zero or more elements, so we must
    360   // restrict this to only parameter packs that are all of integral type.
    361   template <typename... Dims>
    362   typename std::enable_if<array_impl::pack_is_integral<Dims...>::value,
    363                           const T&>::type
    364   operator()(Dims... dims) const {
    365     // We are using a std::array to avoid having to allocate memory in this
    366     // function for performance reasons.
    367     std::array<int64, sizeof...(dims)> indexes{{static_cast<int64>(dims)...}};
    368     return values_[calculate_index(indexes)];
    369   }
    370 
    371   // Returns the value at the cell specified by the indexes. The number of
    372   // arguments have to match with the number of dimensions for the array.
    373   template <typename... Dims>
    374   typename std::enable_if<array_impl::pack_is_integral<Dims...>::value,
    375                           T&>::type
    376   operator()(Dims... dims) {
    377     // We are using a std::array to avoid having to allocate memory in this
    378     // function for performance reasons.
    379     std::array<int64, sizeof...(dims)> indexes{{static_cast<int64>(dims)...}};
    380     return values_[calculate_index(indexes)];
    381   }
    382 
    383   // Returns the value at the cell specified by the indexes. The number of
    384   // arguments have to match with the number of dimensions for the array.
    385   const T& operator()(absl::Span<const int64> indexes) const {
    386     return values_[calculate_index(indexes)];
    387   }
    388 
    389   // Returns the value at the cell specified by the indexes. The number of
    390   // arguments have to match with the number of dimensions for the array.
    391   T& operator()(absl::Span<const int64> indexes) {
    392     return values_[calculate_index(indexes)];
    393   }
    394 
    395   // Low-level accessor for stuff like memcmp, handle with care. Returns pointer
    396   // to the underlying storage of the array (similarly to std::vector::data()).
    397   T* data() const {
    398     // TODO(tberghammer): Get rid of the const_cast. Currently it is needed
    399     // because the Eigen backend needs a non-const pointers even for reading
    400     // from the array.
    401     return const_cast<Array*>(this)->values_.get();
    402   }
    403 
    404   // Returns the size of the dimension at the given index.
    405   int64 dim(int64 n) const {
    406     CHECK(n < sizes_.size());
    407     return sizes_[n];
    408   }
    409 
    410   // Returns a vector containing the dimensions of the array.
    411   const std::vector<int64>& dimensions() const { return sizes_; }
    412 
    413   int64 num_dimensions() const { return sizes_.size(); }
    414 
    415   // Returns the total number of elements in the array.
    416   int64 num_elements() const {
    417     return std::accumulate(sizes_.begin(), sizes_.end(), 1LL,
    418                            std::multiplies<int64>());
    419   }
    420 
    421   const T* begin() const { return &values_[0]; }
    422   T* begin() { return &values_[0]; }
    423   const T* end() const { return &values_[num_elements()]; }
    424   T* end() { return &values_[num_elements()]; }
    425 
    426   bool operator==(const Array<T>& other) const {
    427     if (sizes_.size() != other.sizes_.size()) {
    428       return false;
    429     }
    430     for (int64 i = 0; i < sizes_.size(); ++i) {
    431       if (sizes_[i] != other.sizes_[i]) {
    432         return false;
    433       }
    434     }
    435     for (int64 i = 0; i < num_elements(); ++i) {
    436       if (values_[i] != other.values_[i]) {
    437         return false;
    438       }
    439     }
    440     return true;
    441   }
    442 
    443   bool operator!=(const Array<T>& other) const { return !(*this == other); }
    444 
    445   // Performs the equivalent of a slice operation on this array.
    446   Array<T> Slice(absl::Span<const int64> starts,
    447                  absl::Span<const int64> limits) const {
    448     CHECK_EQ(starts.size(), num_dimensions());
    449     CHECK_EQ(limits.size(), num_dimensions());
    450 
    451     std::vector<int64> sizes;
    452     std::transform(starts.begin(), starts.end(), limits.begin(),
    453                    std::back_inserter(sizes),
    454                    [](int64 start, int64 limit) { return limit - start; });
    455     Array<T> result(sizes);
    456 
    457     std::vector<int64> index(sizes_.size());
    458     int64 slice_i = 0;
    459     for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
    460       if (array_impl::all_inside_range(index, starts, limits)) {
    461         // Even though the bounds of result are different to our bounds, we're
    462         // iterating in the same order. So we can simply write successive linear
    463         // indices instead of recalculating a multi-dimensional index.
    464         result.values_[slice_i++] = values_[i];
    465       }
    466     }
    467     return result;
    468   }
    469 
    470   // Performs the equivalent of a DynamicUpdateSlice in-place on this array.
    471   void UpdateSlice(const Array<T>& from,
    472                    absl::Span<const int64> start_indices) {
    473     CHECK_EQ(from.num_dimensions(), num_dimensions());
    474     std::vector<int64> limit_indices;
    475     std::transform(start_indices.begin(), start_indices.end(),
    476                    from.dimensions().begin(), std::back_inserter(limit_indices),
    477                    std::plus<int64>{});
    478     std::vector<int64> index(sizes_.size());
    479     int64 from_i = 0;
    480     for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) {
    481       if (array_impl::all_inside_range(index, start_indices, limit_indices)) {
    482         // Even though the bounds of from are different to our bounds, we're
    483         // iterating in the same order. So we can simply write successive linear
    484         // indices instead of recalculating a multi-dimensional index.
    485         values_[i] = from.values_[from_i++];
    486       }
    487     }
    488   }
    489 
    490   // Performs an in-place reshape, modifying the dimensions but not the
    491   // underlying data.
    492   void Reshape(absl::Span<const int64> new_dimensions) {
    493     int64 old_num_elements = num_elements();
    494     sizes_ = std::vector<int64>(new_dimensions.begin(), new_dimensions.end());
    495     CHECK_EQ(num_elements(), old_num_elements);
    496   }
    497 
    498   // Returns a string representation of the array suitable for debugging.
    499   string ToString() const {
    500     std::vector<string> pieces;
    501     std::vector<int64> index(sizes_.size());
    502     do {
    503       // Emit leading spaces and opening square brackets
    504       if (index.back() == 0) {
    505         for (int64 i = sizes_.size() - 1; i >= 0; --i) {
    506           if (i == 0 || index[i - 1] != 0) {
    507             for (int64 j = 0; j < sizes_.size(); ++j) {
    508               pieces.push_back(j < i ? " " : "[");
    509             }
    510             break;
    511           }
    512         }
    513       }
    514 
    515       pieces.push_back(absl::StrCat(values_[calculate_index(index)]));
    516 
    517       // Emit comma if it isn't the last element
    518       if (index.back() != sizes_.back() - 1) {
    519         pieces.push_back(", ");
    520       }
    521 
    522       // Emit closing square brackets
    523       for (int64 i = sizes_.size() - 1; i >= 0; --i) {
    524         if (index[i] != sizes_[i] - 1) {
    525           break;
    526         }
    527         pieces.push_back("]");
    528         if (i != 0 && index[i - 1] != sizes_[i - 1] - 1) {
    529           pieces.push_back(",\n");
    530         }
    531       }
    532     } while (next_index(&index));
    533     return absl::StrJoin(pieces, "");
    534   }
    535 
    536  private:
    537   // Converts an initializer_list of type U to a vector of type int64. Used by
    538   // the initializer list based constructors to convert the size type into int64
    539   // to be passed to the size based constructor.
    540   template <typename U>
    541   static std::vector<int64> ToInt64Vector(
    542       const std::initializer_list<U>& data) {
    543     return std::vector<int64>(data.begin(), data.end());
    544   }
    545 
    546   // Returns the linear index from the list of per-dimension indexes. Function
    547   // is templated so can be used with an std::array from operator() to avoid
    548   // memory allocation.
    549   template <typename U>
    550   int64 calculate_index(const U& indexes) const {
    551     CHECK_EQ(sizes_.size(), indexes.size());
    552     int64 index = 0;
    553     for (int64 i = 0; i < sizes_.size(); ++i) {
    554       index *= sizes_[i];
    555       index += indexes[i];
    556     }
    557     return index;
    558   }
    559 
    560   // Advances the specified set of indexes and returns true if we haven't
    561   // wrapped around (i.e. result isn't {0, 0, ...}).
    562   bool next_index(std::vector<int64>* index) const {
    563     CHECK_EQ(index->size(), sizes_.size());
    564     for (int64 i = sizes_.size() - 1; i >= 0; --i) {
    565       (*index)[i]++;
    566       if ((*index)[i] < sizes_[i]) {
    567         return true;
    568       }
    569       (*index)[i] = 0;
    570     }
    571     return false;
    572   }
    573 
    574   std::vector<int64> sizes_;
    575   std::unique_ptr<T[]> values_;
    576 };
    577 
    578 }  // namespace xla
    579 
    580 #endif  // TENSORFLOW_COMPILER_XLA_ARRAY_H_
    581