Home | History | Annotate | Download | only in tflite
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
#include "tensorflow/contrib/lite/toco/tflite/types.h"

#include <cstdint>
#include <cstring>
#include <limits>
#include <type_traits>
#include <vector>
     16 
     17 namespace toco {
     18 
     19 namespace tflite {
     20 
     21 namespace {
     22 template <ArrayDataType T>
     23 DataBuffer::FlatBufferOffset CopyBuffer(
     24     const Array& array, flatbuffers::FlatBufferBuilder* builder) {
     25   using NativeT = ::toco::DataType<T>;
     26   const auto& src_data = array.GetBuffer<T>().data;
     27   const uint8_t* dst_data = reinterpret_cast<const uint8_t*>(src_data.data());
     28   auto size = src_data.size() * sizeof(NativeT);
     29   return builder->CreateVector(dst_data, size);
     30 }
     31 
     32 template <ArrayDataType T>
     33 void CopyBuffer(const ::tflite::Buffer& buffer, Array* array) {
     34   using NativeT = ::toco::DataType<T>;
     35   auto* src_buffer = buffer.data();
     36   const NativeT* src_data =
     37       reinterpret_cast<const NativeT*>(src_buffer->data());
     38   int num_items = src_buffer->size() / sizeof(NativeT);
     39 
     40   std::vector<NativeT>* dst_data = &array->GetMutableBuffer<T>().data;
     41   for (int i = 0; i < num_items; ++i) {
     42     dst_data->push_back(*src_data);
     43     ++src_data;
     44   }
     45 }
     46 }  // namespace
     47 
     48 ::tflite::TensorType DataType::Serialize(ArrayDataType array_data_type) {
     49   switch (array_data_type) {
     50     case ArrayDataType::kFloat:
     51       return ::tflite::TensorType_FLOAT32;
     52     case ArrayDataType::kInt32:
     53       return ::tflite::TensorType_INT32;
     54     case ArrayDataType::kInt64:
     55       return ::tflite::TensorType_INT64;
     56     case ArrayDataType::kUint8:
     57       return ::tflite::TensorType_UINT8;
     58     case ArrayDataType::kString:
     59       return ::tflite::TensorType_STRING;
     60     default:
     61       // FLOAT32 is filled for unknown data types.
     62       // TODO(ycling): Implement type inference in TF Lite interpreter.
     63       return ::tflite::TensorType_FLOAT32;
     64   }
     65 }
     66 
     67 ArrayDataType DataType::Deserialize(int tensor_type) {
     68   switch (::tflite::TensorType(tensor_type)) {
     69     case ::tflite::TensorType_FLOAT32:
     70       return ArrayDataType::kFloat;
     71     case ::tflite::TensorType_INT32:
     72       return ArrayDataType::kInt32;
     73     case ::tflite::TensorType_INT64:
     74       return ArrayDataType::kInt64;
     75     case ::tflite::TensorType_STRING:
     76       return ArrayDataType::kString;
     77     case ::tflite::TensorType_UINT8:
     78       return ArrayDataType::kUint8;
     79     default:
     80       LOG(FATAL) << "Unhandled tensor type '" << tensor_type << "'.";
     81   }
     82 }
     83 
     84 flatbuffers::Offset<flatbuffers::Vector<uint8_t>> DataBuffer::Serialize(
     85     const Array& array, flatbuffers::FlatBufferBuilder* builder) {
     86   if (!array.buffer) return 0;  // an empty buffer, usually an output.
     87 
     88   switch (array.data_type) {
     89     case ArrayDataType::kFloat:
     90       return CopyBuffer<ArrayDataType::kFloat>(array, builder);
     91     case ArrayDataType::kInt32:
     92       return CopyBuffer<ArrayDataType::kInt32>(array, builder);
     93     case ArrayDataType::kString:
     94       return CopyBuffer<ArrayDataType::kString>(array, builder);
     95     case ArrayDataType::kUint8:
     96       return CopyBuffer<ArrayDataType::kUint8>(array, builder);
     97     default:
     98       LOG(FATAL) << "Unhandled array data type.";
     99   }
    100 }
    101 
    102 void DataBuffer::Deserialize(const ::tflite::Tensor& tensor,
    103                              const ::tflite::Buffer& buffer, Array* array) {
    104   if (tensor.buffer() == 0) return;      // an empty buffer, usually an output.
    105   if (buffer.data() == nullptr) return;  // a non-defined buffer.
    106 
    107   switch (tensor.type()) {
    108     case ::tflite::TensorType_FLOAT32:
    109       return CopyBuffer<ArrayDataType::kFloat>(buffer, array);
    110     case ::tflite::TensorType_INT32:
    111       return CopyBuffer<ArrayDataType::kInt32>(buffer, array);
    112     case ::tflite::TensorType_INT64:
    113       return CopyBuffer<ArrayDataType::kInt64>(buffer, array);
    114     case ::tflite::TensorType_STRING:
    115       return CopyBuffer<ArrayDataType::kString>(buffer, array);
    116     case ::tflite::TensorType_UINT8:
    117       return CopyBuffer<ArrayDataType::kUint8>(buffer, array);
    118     default:
    119       LOG(FATAL) << "Unhandled tensor type.";
    120   }
    121 }
    122 
    123 ::tflite::Padding Padding::Serialize(PaddingType padding_type) {
    124   switch (padding_type) {
    125     case PaddingType::kSame:
    126       return ::tflite::Padding_SAME;
    127     case PaddingType::kValid:
    128       return ::tflite::Padding_VALID;
    129     default:
    130       LOG(FATAL) << "Unhandled padding type.";
    131   }
    132 }
    133 
    134 PaddingType Padding::Deserialize(int padding) {
    135   switch (::tflite::Padding(padding)) {
    136     case ::tflite::Padding_SAME:
    137       return PaddingType::kSame;
    138     case ::tflite::Padding_VALID:
    139       return PaddingType::kValid;
    140     default:
    141       LOG(FATAL) << "Unhandled padding.";
    142   }
    143 }
    144 
    145 ::tflite::ActivationFunctionType ActivationFunction::Serialize(
    146     FusedActivationFunctionType faf_type) {
    147   switch (faf_type) {
    148     case FusedActivationFunctionType::kNone:
    149       return ::tflite::ActivationFunctionType_NONE;
    150     case FusedActivationFunctionType::kRelu:
    151       return ::tflite::ActivationFunctionType_RELU;
    152     case FusedActivationFunctionType::kRelu6:
    153       return ::tflite::ActivationFunctionType_RELU6;
    154     case FusedActivationFunctionType::kRelu1:
    155       return ::tflite::ActivationFunctionType_RELU_N1_TO_1;
    156     default:
    157       LOG(FATAL) << "Unhandled fused activation function type.";
    158   }
    159 }
    160 
    161 FusedActivationFunctionType ActivationFunction::Deserialize(
    162     int activation_function) {
    163   switch (::tflite::ActivationFunctionType(activation_function)) {
    164     case ::tflite::ActivationFunctionType_NONE:
    165       return FusedActivationFunctionType::kNone;
    166     case ::tflite::ActivationFunctionType_RELU:
    167       return FusedActivationFunctionType::kRelu;
    168     case ::tflite::ActivationFunctionType_RELU6:
    169       return FusedActivationFunctionType::kRelu6;
    170     case ::tflite::ActivationFunctionType_RELU_N1_TO_1:
    171       return FusedActivationFunctionType::kRelu1;
    172     default:
    173       LOG(FATAL) << "Unhandled fused activation function type.";
    174   }
    175 }
    176 
    177 }  // namespace tflite
    178 
    179 }  // namespace toco
    180