Home | History | Annotate | Download | only in xla
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include "tensorflow/compiler/xla/literal_util.h"
     17 
     18 #include <algorithm>
     19 #include <cstring>
     20 #include <functional>
     21 #include <limits>
     22 #include <numeric>
     23 #include <vector>
     24 
     25 #include "absl/memory/memory.h"
     26 #include "absl/strings/str_cat.h"
     27 #include "absl/strings/str_join.h"
     28 #include "tensorflow/compiler/xla/index_util.h"
     29 #include "tensorflow/compiler/xla/shape_util.h"
     30 #include "tensorflow/compiler/xla/status_macros.h"
     31 #include "tensorflow/compiler/xla/types.h"
     32 #include "tensorflow/compiler/xla/util.h"
     33 #include "tensorflow/core/lib/core/errors.h"
     34 #include "tensorflow/core/lib/hash/hash.h"
     35 #include "tensorflow/core/platform/logging.h"
     36 #include "tensorflow/core/platform/mem.h"
     37 #include "tensorflow/core/platform/types.h"
     38 
     39 namespace xla {
     40 namespace {
     41 
     42 using absl::StrCat;
     43 
     44 // Return a literal with all arrays of type FromNativeT converted to type
     45 // ToNativeT in the given literal.
     46 template <typename FromNativeT, typename ToNativeT>
     47 Literal ConvertType(LiteralSlice literal) {
     48   // First construct shape of the result.
     49   Shape result_shape(literal.shape());
     50   ShapeUtil::ForEachMutableSubshape(
     51       &result_shape, [](Shape* subshape, const ShapeIndex&) {
     52         if (subshape->element_type() ==
     53             primitive_util::NativeToPrimitiveType<FromNativeT>()) {
     54           subshape->set_element_type(
     55               primitive_util::NativeToPrimitiveType<ToNativeT>());
     56         }
     57       });
     58   Literal result(result_shape);
     59 
     60   // Then copy over the data from 'literal' converting FromNativeT values to
     61   // ToNativeT values as necessary.
     62   ShapeUtil::ForEachSubshape(
     63       literal.shape(),
     64       [&](const Shape& subshape, const ShapeIndex& shape_index) {
     65         if (subshape.IsArray()) {
     66           if (subshape.element_type() ==
     67               primitive_util::NativeToPrimitiveType<FromNativeT>()) {
     68             auto src = literal.data<FromNativeT>(shape_index);
     69             auto dest = result.data<ToNativeT>(shape_index);
     70             for (int64 i = 0; i < src.size(); ++i) {
     71               dest[i] = static_cast<ToNativeT>(src[i]);
     72             }
     73           } else {
     74             TF_CHECK_OK(result.CopyFrom(literal,
     75                                         /*dest_shape_index=*/shape_index,
     76                                         /*src_shape_index=*/shape_index));
     77           }
     78         }
     79       });
     80   return result;
     81 }
     82 
     83 }  // namespace
     84 
     85 /* static */ Literal LiteralUtil::CreateFromDimensions(
     86     PrimitiveType primitive_type, absl::Span<const int64> dimensions) {
     87   return Literal::CreateFromShape(
     88       ShapeUtil::MakeShape(primitive_type, dimensions));
     89 }
     90 
     91 /* static */ Literal LiteralUtil::ConvertBF16ToF32(
     92     const LiteralSlice& bf16_literal) {
     93   return ConvertType<bfloat16, float>(bf16_literal);
     94 }
     95 
     96 /* static */ Literal LiteralUtil::ConvertF32ToBF16(
     97     const LiteralSlice& f32_literal) {
     98   return ConvertType<float, bfloat16>(f32_literal);
     99 }
    100 
    101 /* static */ Literal LiteralUtil::CreateToken() {
    102   return Literal(ShapeUtil::MakeTokenShape());
    103 }
    104 
    105 /* static */ Literal LiteralUtil::Zero(PrimitiveType primitive_type) {
    106   switch (primitive_type) {
    107     case U8:
    108       return LiteralUtil::CreateR0<uint8>(0);
    109     case U16:
    110       return LiteralUtil::CreateR0<uint16>(0);
    111     case U32:
    112       return LiteralUtil::CreateR0<uint32>(0);
    113     case U64:
    114       return LiteralUtil::CreateR0<uint64>(0);
    115     case S8:
    116       return LiteralUtil::CreateR0<int8>(0);
    117     case S16:
    118       return LiteralUtil::CreateR0<int16>(0);
    119     case S32:
    120       return LiteralUtil::CreateR0<int32>(0);
    121     case S64:
    122       return LiteralUtil::CreateR0<int64>(0);
    123     case F16:
    124       return LiteralUtil::CreateR0<half>(static_cast<half>(0.0f));
    125     case BF16:
    126       return LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(0.0f));
    127     case F32:
    128       return LiteralUtil::CreateR0<float>(0);
    129     case F64:
    130       return LiteralUtil::CreateR0<double>(0);
    131     case C64:
    132       return LiteralUtil::CreateR0<complex64>(0);
    133     case C128:
    134       return LiteralUtil::CreateR0<complex128>(0);
    135     case PRED:
    136       return LiteralUtil::CreateR0<bool>(false);
    137     case TUPLE:
    138       LOG(FATAL) << "tuple element type cannot take on value of 0";
    139     case OPAQUE:
    140       LOG(FATAL) << "opaque element type cannot take on value of 0";
    141     default:
    142       LOG(FATAL) << "Unhandled primitive type " << primitive_type;
    143   }
    144 }
    145 
    146 /* static */ Literal LiteralUtil::One(PrimitiveType primitive_type) {
    147   switch (primitive_type) {
    148     case U8:
    149       return LiteralUtil::CreateR0<uint8>(1);
    150     case U32:
    151       return LiteralUtil::CreateR0<uint32>(1);
    152     case U64:
    153       return LiteralUtil::CreateR0<uint64>(1);
    154     case S8:
    155       return LiteralUtil::CreateR0<int8>(1);
    156     case S32:
    157       return LiteralUtil::CreateR0<int32>(1);
    158     case S64:
    159       return LiteralUtil::CreateR0<int64>(1);
    160     case F16:
    161       return LiteralUtil::CreateR0<half>(static_cast<half>(1.0f));
    162     case BF16:
    163       return LiteralUtil::CreateR0<bfloat16>(static_cast<bfloat16>(1.0f));
    164     case F32:
    165       return LiteralUtil::CreateR0<float>(1);
    166     case F64:
    167       return LiteralUtil::CreateR0<double>(1);
    168     case C64:
    169       return LiteralUtil::CreateR0<complex64>(1);
    170     case C128:
    171       return LiteralUtil::CreateR0<complex128>(1);
    172     case PRED:
    173       return LiteralUtil::CreateR0<bool>(true);
    174     case S16:
    175     case U16:
    176       LOG(FATAL) << "u16/s16 literals not yet implemented";
    177     case TUPLE:
    178       LOG(FATAL) << "tuple element type cannot take on value of 1";
    179     case OPAQUE:
    180       LOG(FATAL) << "opaque element type cannot take on value of 1";
    181     default:
    182       LOG(FATAL) << "Unhandled primitive type " << primitive_type;
    183   }
    184 }
    185 
    186 /* static */ Literal LiteralUtil::MinValue(PrimitiveType primitive_type) {
    187   switch (primitive_type) {
    188     case U8:
    189       return LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::min());
    190     case U32:
    191       return LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::min());
    192     case U64:
    193       return LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::min());
    194     case S8:
    195       return LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::min());
    196     case S32:
    197       return LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::min());
    198     case S64:
    199       return LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::min());
    200     case F32:
    201       return LiteralUtil::CreateR0<float>(
    202           -std::numeric_limits<float>::infinity());
    203     case F64:
    204       return LiteralUtil::CreateR0<double>(
    205           -std::numeric_limits<double>::infinity());
    206     case C64:
    207       LOG(FATAL) << "C64 element type has no minimum value";
    208     case C128:
    209       LOG(FATAL) << "C128 element type has no minimum value";
    210     case PRED:
    211       return LiteralUtil::CreateR0<bool>(false);
    212     case S16:
    213     case U16:
    214       LOG(FATAL) << "u16/s16 literals not yet implemented";
    215     case F16:
    216       return LiteralUtil::CreateR0<half>(
    217           static_cast<half>(-std::numeric_limits<float>::infinity()));
    218     case BF16:
    219       return LiteralUtil::CreateR0<bfloat16>(
    220           static_cast<bfloat16>(-std::numeric_limits<float>::infinity()));
    221     case TUPLE:
    222       LOG(FATAL) << "tuple element type has no minimum value";
    223     case OPAQUE:
    224       LOG(FATAL) << "opaque element type has no minimum value";
    225     default:
    226       LOG(FATAL) << "Unhandled primitive type " << primitive_type;
    227   }
    228 }
    229 
    230 /* static */ Literal LiteralUtil::MaxValue(PrimitiveType primitive_type) {
    231   switch (primitive_type) {
    232     case U8:
    233       return LiteralUtil::CreateR0<uint8>(std::numeric_limits<uint8>::max());
    234     case U32:
    235       return LiteralUtil::CreateR0<uint32>(std::numeric_limits<uint32>::max());
    236     case U64:
    237       return LiteralUtil::CreateR0<uint64>(std::numeric_limits<uint64>::max());
    238     case S8:
    239       return LiteralUtil::CreateR0<int8>(std::numeric_limits<int8>::max());
    240     case S32:
    241       return LiteralUtil::CreateR0<int32>(std::numeric_limits<int32>::max());
    242     case S64:
    243       return LiteralUtil::CreateR0<int64>(std::numeric_limits<int64>::max());
    244     case F32:
    245       return LiteralUtil::CreateR0<float>(
    246           std::numeric_limits<float>::infinity());
    247     case F64:
    248       return LiteralUtil::CreateR0<double>(
    249           std::numeric_limits<double>::infinity());
    250     case PRED:
    251       return LiteralUtil::CreateR0<bool>(true);
    252     case S16:
    253     case U16:
    254       LOG(FATAL) << "u16/s16 literals not yet implemented";
    255     case F16:
    256       return LiteralUtil::CreateR0<half>(
    257           static_cast<half>(std::numeric_limits<float>::infinity()));
    258     case BF16:
    259       return LiteralUtil::CreateR0<bfloat16>(
    260           static_cast<bfloat16>(std::numeric_limits<float>::infinity()));
    261     case TUPLE:
    262       LOG(FATAL) << "tuple element type has no maximum value";
    263     case OPAQUE:
    264       LOG(FATAL) << "opaque element type has no maximum value";
    265     default:
    266       LOG(FATAL) << "Unhandled primitive type " << primitive_type;
    267   }
    268 }
    269 
    270 /* static */ Literal LiteralUtil::CreateR1(
    271     const tensorflow::core::Bitmap& values) {
    272   Literal literal(
    273       ShapeUtil::MakeShape(PRED, {static_cast<int64>(values.bits())}));
    274   literal.PopulateR1(values);
    275   return literal;
    276 }
    277 
    278 /* static */ Literal LiteralUtil::CreateR1U8(absl::string_view value) {
    279   Literal literal(ShapeUtil::MakeShape(U8, {static_cast<int64>(value.size())}));
    280   for (int i = 0; i < value.size(); ++i) {
    281     literal.Set<uint8>({i}, value[i]);
    282   }
    283   return literal;
    284 }
    285 
    286 /* static */ Literal LiteralUtil::CreateR2F32Linspace(float from, float to,
    287                                                       int64 rows, int64 cols) {
    288   auto value = MakeLinspaceArray2D(from, to, rows, cols);
    289   return CreateR2FromArray2D(*value);
    290 }
    291 
    292 /* static */ Literal LiteralUtil::ReshapeSlice(
    293     absl::Span<const int64> new_dimensions,
    294     absl::Span<const int64> minor_to_major, const LiteralSlice& literal) {
    295   int64 new_num_elements = 1;
    296   for (int64 i = 0; i < new_dimensions.size(); ++i) {
    297     new_num_elements *= new_dimensions[i];
    298   }
    299   CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements);
    300   CHECK_EQ(new_dimensions.size(), minor_to_major.size());
    301 
    302   Literal new_literal(
    303       ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions));
    304 
    305   // Create a new shape with the given minor-to-major layout. This shape is used
    306   // solely for converting linear address to multi-dimensional addresses when
    307   // writing elements to the new literal.
    308   Shape shape_with_layout = new_literal.shape();
    309   *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
    310 
    311   // Copy data into new literal, element-by-element.
    312   for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) {
    313     std::vector<int64> from_multi_index =
    314         IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i);
    315     std::vector<int64> to_multi_index =
    316         IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i);
    317     switch (literal.shape().element_type()) {
    318       case PRED:
    319         new_literal.Set<bool>(to_multi_index,
    320                               literal.Get<bool>(from_multi_index));
    321         break;
    322       case U8:
    323         new_literal.Set<uint8>(to_multi_index,
    324                                literal.Get<uint8>(from_multi_index));
    325         break;
    326       case U32:
    327         new_literal.Set<uint32>(to_multi_index,
    328                                 literal.Get<uint32>(from_multi_index));
    329         break;
    330       case S32:
    331         new_literal.Set<int32>(to_multi_index,
    332                                literal.Get<int32>(from_multi_index));
    333         break;
    334       case U64:
    335         new_literal.Set<uint64>(to_multi_index,
    336                                 literal.Get<uint64>(from_multi_index));
    337         break;
    338       case S64:
    339         new_literal.Set<int64>(to_multi_index,
    340                                literal.Get<int64>(from_multi_index));
    341         break;
    342       case F32:
    343         new_literal.Set<float>(to_multi_index,
    344                                literal.Get<float>(from_multi_index));
    345         break;
    346       case F64:
    347         new_literal.Set<double>(to_multi_index,
    348                                 literal.Get<double>(from_multi_index));
    349         break;
    350       case C64:
    351         new_literal.Set<complex64>(to_multi_index,
    352                                    literal.Get<complex64>(from_multi_index));
    353         break;
    354       case C128:
    355         new_literal.Set<complex128>(to_multi_index,
    356                                     literal.Get<complex128>(from_multi_index));
    357         break;
    358       default:
    359         LOG(FATAL) << "Unhandled primitive element type: "
    360                    << PrimitiveType_Name(literal.shape().element_type());
    361     }
    362   }
    363 
    364   return new_literal;
    365 }
    366 
    367 /* static */ Literal LiteralUtil::GetFirstScalarLiteral(
    368     const LiteralSlice& literal) {
    369   CHECK(literal.shape().IsArray());
    370   CHECK_GT(ShapeUtil::ElementsIn(literal.shape()), 0);
    371   switch (literal.shape().element_type()) {
    372     case PRED:
    373       return LiteralUtil::CreateR0<bool>(literal.GetFirstElement<bool>());
    374     // 8 bit types.
    375     case S8:
    376       return LiteralUtil::CreateR0<int8>(literal.GetFirstElement<int8>());
    377     case U8:
    378       return LiteralUtil::CreateR0<uint8>(literal.GetFirstElement<uint8>());
    379     // 16 bit types.
    380     case BF16:
    381       return LiteralUtil::CreateR0<bfloat16>(
    382           literal.GetFirstElement<bfloat16>());
    383     case F16:
    384       return LiteralUtil::CreateR0<half>(literal.GetFirstElement<half>());
    385     case S16:
    386       return LiteralUtil::CreateR0<int16>(literal.GetFirstElement<int16>());
    387     case U16:
    388       return LiteralUtil::CreateR0<uint16>(literal.GetFirstElement<uint16>());
    389     // 32 bit types.
    390     case F32:
    391       return LiteralUtil::CreateR0<float>(literal.GetFirstElement<float>());
    392     case S32:
    393       return LiteralUtil::CreateR0<int32>(literal.GetFirstElement<int32>());
    394     case U32:
    395       return LiteralUtil::CreateR0<uint32>(literal.GetFirstElement<uint32>());
    396     // 64 bit types.
    397     case C64:
    398       return LiteralUtil::CreateR0<complex64>(
    399           literal.GetFirstElement<complex64>());
    400     case F64:
    401       return LiteralUtil::CreateR0<double>(literal.GetFirstElement<double>());
    402     case S64:
    403       return LiteralUtil::CreateR0<int64>(literal.GetFirstElement<int64>());
    404     case U64:
    405       return LiteralUtil::CreateR0<uint64>(literal.GetFirstElement<uint64>());
    406 
    407     case C128:
    408       return LiteralUtil::CreateR0<complex128>(
    409           literal.GetFirstElement<complex128>());
    410     default:
    411       LOG(FATAL) << "Unhandled primitive type "
    412                  << literal.shape().element_type();
    413   }
    414 }
    415 
    416 /* static */ Literal LiteralUtil::MakeTuple(
    417     absl::Span<const Literal* const> elements) {
    418   std::vector<Shape> element_shapes;
    419   for (const auto* element : elements) {
    420     element_shapes.push_back(element->shape());
    421   }
    422   Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
    423   for (int i = 0; i < elements.size(); ++i) {
    424     TF_CHECK_OK(literal.CopyFrom(*elements[i], /*dest_shape_index=*/{i}));
    425   }
    426   return literal;
    427 }
    428 
    429 /* static */ Literal LiteralUtil::MakeTupleFromSlices(
    430     absl::Span<const LiteralSlice> elements) {
    431   std::vector<Shape> element_shapes;
    432   for (const auto& element : elements) {
    433     element_shapes.push_back(element.shape());
    434   }
    435   Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
    436   for (int i = 0; i < elements.size(); ++i) {
    437     TF_CHECK_OK(literal.CopyFrom(elements[i], /*dest_shape_index=*/{i}));
    438   }
    439   return literal;
    440 }
    441 
    442 /* static */ Literal LiteralUtil::MakeTupleOwned(
    443     std::vector<Literal> elements) {
    444   std::vector<Shape> element_shapes;
    445   element_shapes.reserve(elements.size());
    446   for (const auto& element : elements) {
    447     element_shapes.push_back(element.shape());
    448   }
    449   Literal literal(ShapeUtil::MakeTupleShape(element_shapes));
    450   for (int64 i = 0; i < elements.size(); ++i) {
    451     TF_CHECK_OK(
    452         literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i}));
    453   }
    454   return literal;
    455 }
    456 
    457 /* static */ string LiteralUtil::MultiIndexAsString(
    458     absl::Span<const int64> multi_index) {
    459   return StrCat("{", absl::StrJoin(multi_index, ","), "}");
    460 }
    461 
    462 }  // namespace xla
    463