/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/lib/util.h"

#include <algorithm>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

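// Returns a zero-filled tensor with the given shape, built by broadcasting a
// scalar zero of the shape's element type to the shape's dimensions.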
xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder,
                                 const xla::Shape& shape) {
  return builder->Broadcast(
      builder->ConstantLiteral(xla::Literal::Zero(shape.element_type())),
      xla::AsInt64Slice(shape.dimensions()));
}

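// Returns a rank-0 constant of the given floating-point (or complex) element
// type holding `value`. LOG(FATAL)s on any other element type.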
xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder,
                                        xla::PrimitiveType type, double value) {
  switch (type) {
    case xla::F16:
      return builder->ConstantR0<xla::half>(static_cast<xla::half>(value));
    case xla::BF16:
      return builder->ConstantR0<bfloat16>(static_cast<bfloat16>(value));
    case xla::F32:
      return builder->ConstantR0<float>(static_cast<float>(value));
    case xla::F64:
      return builder->ConstantR0<double>(value);
    case xla::C64:
      return builder->ConstantR0<xla::complex64>(value);
    default:
      LOG(FATAL) << "unhandled element type " << type;
  }
}

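// Returns a rank-0 constant of the given element type holding `value` cast to
// that type. LOG(FATAL)s for element types that cannot represent an integer
// (PRED, TUPLE, OPAQUE) and for the not-yet-implemented U16/S16 literals.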
xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* builder,
                                          xla::PrimitiveType type,
                                          int64 value) {
  xla::Literal literal;
  switch (type) {
    case xla::U8:
      literal = std::move(*xla::Literal::CreateR0<uint8>(value));
      break;
    case xla::U32:
      literal = std::move(*xla::Literal::CreateR0<uint32>(value));
      break;
    case xla::U64:
      literal = std::move(*xla::Literal::CreateR0<uint64>(value));
      break;
    case xla::S8:
      literal = std::move(*xla::Literal::CreateR0<int8>(value));
      break;
    case xla::S32:
      literal = std::move(*xla::Literal::CreateR0<int32>(value));
      break;
    case xla::S64:
      literal = std::move(*xla::Literal::CreateR0<int64>(value));
      break;
    case xla::F32:
      literal = std::move(*xla::Literal::CreateR0<float>(value));
      break;
    case xla::F64:
      literal = std::move(*xla::Literal::CreateR0<double>(value));
      break;
    case xla::C64:
      literal = std::move(*xla::Literal::CreateR0<complex64>(value));
      break;
    case xla::PRED:
      LOG(FATAL) << "pred element type is not integral";
    case xla::S16:
    case xla::U16:
      LOG(FATAL) << "u16/s16 literals not yet implemented";
    case xla::BF16:
      literal = std::move(
          *xla::Literal::CreateR0<bfloat16>(static_cast<bfloat16>(value)));
      break;
    case xla::F16:
      literal = std::move(
          *xla::Literal::CreateR0<xla::half>(static_cast<xla::half>(value)));
      break;
    case xla::TUPLE:
      LOG(FATAL) << "tuple element type is not integral";
    case xla::OPAQUE:
      LOG(FATAL) << "opaque element type is not integral";
    default:
      LOG(FATAL) << "unhandled element type " << type;
  }
  return builder->ConstantLiteral(literal);
}

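// Slices the minor (trailing) dimensions of `x`, taking the full extent of
// every major dimension. `start` and `end` give the bounds for the minor
// dimensions only; zeros and the major-dimension sizes are prepended to form
// the full slice bounds. For example, for an `x` of shape [4, 5, 6],
// start={1, 2} and end={3, 4} slice only the last two dimensions and produce
// a result of shape [4, 2, 2].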
xla::StatusOr<xla::ComputationDataHandle> SliceInMinorDims(
    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
    gtl::ArraySlice<int64> start, gtl::ArraySlice<int64> end) {
  TF_RET_CHECK(start.size() == end.size());
  int64 n_minor_dims = start.size();

  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));

  const int64 n_dims = xla::ShapeUtil::Rank(*shape);
  TF_RET_CHECK(n_minor_dims <= n_dims);
  gtl::ArraySlice<int64> major_dims(xla::AsInt64Slice(shape->dimensions()),
                                    /*pos=*/0,
                                    /*len=*/n_dims - n_minor_dims);

  // Prepend zeros as the start indices of the major dimensions.
  std::vector<int64> padded_start(n_dims, 0);
  std::copy(start.begin(), start.end(),
            padded_start.begin() + major_dims.size());

  // Prepend the sizes of the major dimensions as their end indices.
  std::vector<int64> padded_end(n_dims);
  std::copy(major_dims.begin(), major_dims.end(), padded_end.begin());
  std::copy(end.begin(), end.end(), padded_end.begin() + major_dims.size());

  std::vector<int64> strides(n_dims, 1);
  return builder->Slice(x, padded_start, padded_end, strides);
}

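// Updates the slice of `x` that starts at indices `start` with the contents
// of `update`, using DynamicUpdateSlice.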
xla::StatusOr<xla::ComputationDataHandle> UpdateSlice(
    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
    const xla::ComputationDataHandle& update, gtl::ArraySlice<int64> start) {
  // TODO(phawkins): make int64 work on all backends, remove the int32 cast.
  std::vector<int32> start_as_int32(start.begin(), start.end());
  return builder->DynamicUpdateSlice(
      x, update, builder->ConstantR1<int32>(start_as_int32));
}

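// Same as UpdateSlice, except that `start` gives indices for the minor
// (trailing) dimensions only; the major dimensions are updated from index 0.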
xla::StatusOr<xla::ComputationDataHandle> UpdateSliceInMinorDims(
    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
    const xla::ComputationDataHandle& update, gtl::ArraySlice<int64> start) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
  const int64 n_dims = xla::ShapeUtil::Rank(*shape);
  const int64 n_minor_dims = start.size();
  TF_RET_CHECK(n_minor_dims <= n_dims);
  std::vector<int64> padded_start(n_dims, 0);
  std::copy(start.begin(), start.end(),
            padded_start.begin() + (n_dims - n_minor_dims));
  return UpdateSlice(builder, x, update, padded_start);
}

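// Swaps the two minor-most dimensions of `x` (a batched matrix transpose),
// leaving all major dimensions unchanged. Requires rank >= 2.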
xla::StatusOr<xla::ComputationDataHandle> TransposeInMinorDims(
    xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> shape, builder->GetShape(x));
  const int64 n_dims = xla::ShapeUtil::Rank(*shape);
  TF_RET_CHECK(n_dims >= 2);
  std::vector<int64> permutation(n_dims);
  std::iota(permutation.begin(), permutation.end(), 0);
  std::swap(permutation[n_dims - 1], permutation[n_dims - 2]);
  return builder->Transpose(x, permutation);
}

}  // namespace tensorflow