1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #include "tensorflow/contrib/lite/toco/tflite/types.h" 16 17 namespace toco { 18 19 namespace tflite { 20 21 namespace { 22 template <ArrayDataType T> 23 DataBuffer::FlatBufferOffset CopyBuffer( 24 const Array& array, flatbuffers::FlatBufferBuilder* builder) { 25 using NativeT = ::toco::DataType<T>; 26 const auto& src_data = array.GetBuffer<T>().data; 27 const uint8_t* dst_data = reinterpret_cast<const uint8_t*>(src_data.data()); 28 auto size = src_data.size() * sizeof(NativeT); 29 return builder->CreateVector(dst_data, size); 30 } 31 32 template <ArrayDataType T> 33 void CopyBuffer(const ::tflite::Buffer& buffer, Array* array) { 34 using NativeT = ::toco::DataType<T>; 35 auto* src_buffer = buffer.data(); 36 const NativeT* src_data = 37 reinterpret_cast<const NativeT*>(src_buffer->data()); 38 int num_items = src_buffer->size() / sizeof(NativeT); 39 40 std::vector<NativeT>* dst_data = &array->GetMutableBuffer<T>().data; 41 for (int i = 0; i < num_items; ++i) { 42 dst_data->push_back(*src_data); 43 ++src_data; 44 } 45 } 46 } // namespace 47 48 ::tflite::TensorType DataType::Serialize(ArrayDataType array_data_type) { 49 switch (array_data_type) { 50 case ArrayDataType::kFloat: 51 return ::tflite::TensorType_FLOAT32; 52 case ArrayDataType::kInt32: 53 return ::tflite::TensorType_INT32; 54 case ArrayDataType::kInt64: 55 return ::tflite::TensorType_INT64; 56 case ArrayDataType::kUint8: 57 return ::tflite::TensorType_UINT8; 58 case ArrayDataType::kString: 59 return ::tflite::TensorType_STRING; 60 default: 61 // FLOAT32 is filled for unknown data types. 62 // TODO(ycling): Implement type inference in TF Lite interpreter. 63 return ::tflite::TensorType_FLOAT32; 64 } 65 } 66 67 ArrayDataType DataType::Deserialize(int tensor_type) { 68 switch (::tflite::TensorType(tensor_type)) { 69 case ::tflite::TensorType_FLOAT32: 70 return ArrayDataType::kFloat; 71 case ::tflite::TensorType_INT32: 72 return ArrayDataType::kInt32; 73 case ::tflite::TensorType_INT64: 74 return ArrayDataType::kInt64; 75 case ::tflite::TensorType_STRING: 76 return ArrayDataType::kString; 77 case ::tflite::TensorType_UINT8: 78 return ArrayDataType::kUint8; 79 default: 80 LOG(FATAL) << "Unhandled tensor type '" << tensor_type << "'."; 81 } 82 } 83 84 flatbuffers::Offset<flatbuffers::Vector<uint8_t>> DataBuffer::Serialize( 85 const Array& array, flatbuffers::FlatBufferBuilder* builder) { 86 if (!array.buffer) return 0; // an empty buffer, usually an output. 87 88 switch (array.data_type) { 89 case ArrayDataType::kFloat: 90 return CopyBuffer<ArrayDataType::kFloat>(array, builder); 91 case ArrayDataType::kInt32: 92 return CopyBuffer<ArrayDataType::kInt32>(array, builder); 93 case ArrayDataType::kString: 94 return CopyBuffer<ArrayDataType::kString>(array, builder); 95 case ArrayDataType::kUint8: 96 return CopyBuffer<ArrayDataType::kUint8>(array, builder); 97 default: 98 LOG(FATAL) << "Unhandled array data type."; 99 } 100 } 101 102 void DataBuffer::Deserialize(const ::tflite::Tensor& tensor, 103 const ::tflite::Buffer& buffer, Array* array) { 104 if (tensor.buffer() == 0) return; // an empty buffer, usually an output. 105 if (buffer.data() == nullptr) return; // a non-defined buffer. 106 107 switch (tensor.type()) { 108 case ::tflite::TensorType_FLOAT32: 109 return CopyBuffer<ArrayDataType::kFloat>(buffer, array); 110 case ::tflite::TensorType_INT32: 111 return CopyBuffer<ArrayDataType::kInt32>(buffer, array); 112 case ::tflite::TensorType_INT64: 113 return CopyBuffer<ArrayDataType::kInt64>(buffer, array); 114 case ::tflite::TensorType_STRING: 115 return CopyBuffer<ArrayDataType::kString>(buffer, array); 116 case ::tflite::TensorType_UINT8: 117 return CopyBuffer<ArrayDataType::kUint8>(buffer, array); 118 default: 119 LOG(FATAL) << "Unhandled tensor type."; 120 } 121 } 122 123 ::tflite::Padding Padding::Serialize(PaddingType padding_type) { 124 switch (padding_type) { 125 case PaddingType::kSame: 126 return ::tflite::Padding_SAME; 127 case PaddingType::kValid: 128 return ::tflite::Padding_VALID; 129 default: 130 LOG(FATAL) << "Unhandled padding type."; 131 } 132 } 133 134 PaddingType Padding::Deserialize(int padding) { 135 switch (::tflite::Padding(padding)) { 136 case ::tflite::Padding_SAME: 137 return PaddingType::kSame; 138 case ::tflite::Padding_VALID: 139 return PaddingType::kValid; 140 default: 141 LOG(FATAL) << "Unhandled padding."; 142 } 143 } 144 145 ::tflite::ActivationFunctionType ActivationFunction::Serialize( 146 FusedActivationFunctionType faf_type) { 147 switch (faf_type) { 148 case FusedActivationFunctionType::kNone: 149 return ::tflite::ActivationFunctionType_NONE; 150 case FusedActivationFunctionType::kRelu: 151 return ::tflite::ActivationFunctionType_RELU; 152 case FusedActivationFunctionType::kRelu6: 153 return ::tflite::ActivationFunctionType_RELU6; 154 case FusedActivationFunctionType::kRelu1: 155 return ::tflite::ActivationFunctionType_RELU_N1_TO_1; 156 default: 157 LOG(FATAL) << "Unhandled fused activation function type."; 158 } 159 } 160 161 FusedActivationFunctionType ActivationFunction::Deserialize( 162 int activation_function) { 163 switch (::tflite::ActivationFunctionType(activation_function)) { 164 case ::tflite::ActivationFunctionType_NONE: 165 return FusedActivationFunctionType::kNone; 166 case ::tflite::ActivationFunctionType_RELU: 167 return FusedActivationFunctionType::kRelu; 168 case ::tflite::ActivationFunctionType_RELU6: 169 return FusedActivationFunctionType::kRelu6; 170 case ::tflite::ActivationFunctionType_RELU_N1_TO_1: 171 return FusedActivationFunctionType::kRelu1; 172 default: 173 LOG(FATAL) << "Unhandled fused activation function type."; 174 } 175 } 176 177 } // namespace tflite 178 179 } // namespace toco 180