1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.h" 17 #include "tensorflow/core/framework/types.h" 18 #include "tensorflow/core/lib/core/errors.h" 19 20 namespace tensorflow { 21 22 BinaryObjectParser::BinaryObjectParser() : byte_swapper_(ByteSwapper(false)) {} 23 24 Status BinaryObjectParser::Parse(uint8_t** ptr, 25 std::vector<Tensor>* out_tensors, 26 std::vector<int32_t>* types) const { 27 uint8_t object_type_id = ParseByte(ptr); 28 29 // Skip non-leaf nodes. 30 if (object_type_id != WRAPPED_OBJ && object_type_id != COMPLEX_OBJ) 31 types->push_back(object_type_id); 32 33 switch (object_type_id) { 34 case BYTE: { 35 out_tensors->emplace_back(cpu_allocator(), DT_UINT8, TensorShape({})); 36 out_tensors->back().scalar<uint8>()() = ParseByte(ptr); 37 break; 38 } 39 case SHORT: { 40 out_tensors->emplace_back(cpu_allocator(), DT_INT16, TensorShape({})); 41 out_tensors->back().scalar<int16>()() = ParseShort(ptr); 42 break; 43 } 44 case USHORT: { 45 out_tensors->emplace_back(cpu_allocator(), DT_UINT16, TensorShape({})); 46 out_tensors->back().scalar<uint16>()() = ParseUnsignedShort(ptr); 47 break; 48 } 49 case INT: { 50 out_tensors->emplace_back(cpu_allocator(), DT_INT32, TensorShape({})); 51 out_tensors->back().scalar<int32>()() = ParseInt(ptr); 52 break; 53 } 54 case LONG: { 55 out_tensors->emplace_back(cpu_allocator(), DT_INT64, TensorShape({})); 56 out_tensors->back().scalar<int64>()() = ParseLong(ptr); 57 break; 58 } 59 case FLOAT: { 60 out_tensors->emplace_back(cpu_allocator(), DT_FLOAT, TensorShape({})); 61 out_tensors->back().scalar<float>()() = ParseFloat(ptr); 62 break; 63 } 64 case DOUBLE: { 65 out_tensors->emplace_back(cpu_allocator(), DT_DOUBLE, TensorShape({})); 66 out_tensors->back().scalar<double>()() = ParseDouble(ptr); 67 break; 68 } 69 case BOOL: { 70 out_tensors->emplace_back(cpu_allocator(), DT_BOOL, TensorShape({})); 71 out_tensors->back().scalar<bool>()() = ParseBool(ptr); 72 break; 73 } 74 case STRING: { 75 out_tensors->emplace_back(cpu_allocator(), DT_STRING, TensorShape({})); 76 out_tensors->back().scalar<string>()() = ParseString(ptr); 77 break; 78 } 79 case DATE: { 80 out_tensors->emplace_back(cpu_allocator(), DT_INT64, TensorShape({})); 81 out_tensors->back().scalar<int64>()() = ParseLong(ptr); 82 break; 83 } 84 case BYTE_ARR: { 85 int32_t length = ParseInt(ptr); 86 uint8_t* arr = ParseByteArr(ptr, length); 87 out_tensors->emplace_back(cpu_allocator(), DT_UINT8, 88 TensorShape({length})); 89 std::copy_n(arr, length, out_tensors->back().flat<uint8>().data()); 90 break; 91 } 92 case SHORT_ARR: { 93 int32_t length = ParseInt(ptr); 94 int16_t* arr = ParseShortArr(ptr, length); 95 out_tensors->emplace_back(cpu_allocator(), DT_INT16, 96 TensorShape({length})); 97 std::copy_n(arr, length, out_tensors->back().flat<int16>().data()); 98 break; 99 } 100 case USHORT_ARR: { 101 int32_t length = ParseInt(ptr); 102 uint16_t* arr = ParseUnsignedShortArr(ptr, length); 103 out_tensors->emplace_back(cpu_allocator(), DT_UINT16, 104 TensorShape({length})); 105 std::copy_n(arr, length, out_tensors->back().flat<uint16>().data()); 106 break; 107 } 108 case INT_ARR: { 109 int32_t length = ParseInt(ptr); 110 int32_t* arr = ParseIntArr(ptr, length); 111 out_tensors->emplace_back(cpu_allocator(), DT_INT32, 112 TensorShape({length})); 113 std::copy_n(arr, length, out_tensors->back().flat<int32>().data()); 114 break; 115 } 116 case LONG_ARR: { 117 int32_t length = ParseInt(ptr); 118 int64_t* arr = ParseLongArr(ptr, length); 119 out_tensors->emplace_back(cpu_allocator(), DT_INT64, 120 TensorShape({length})); 121 std::copy_n(arr, length, out_tensors->back().flat<int64>().data()); 122 break; 123 } 124 case FLOAT_ARR: { 125 int32_t length = ParseInt(ptr); 126 float* arr = ParseFloatArr(ptr, length); 127 out_tensors->emplace_back(cpu_allocator(), DT_FLOAT, 128 TensorShape({length})); 129 std::copy_n(arr, length, out_tensors->back().flat<float>().data()); 130 break; 131 } 132 case DOUBLE_ARR: { 133 int32_t length = ParseInt(ptr); 134 double* arr = ParseDoubleArr(ptr, length); 135 out_tensors->emplace_back(cpu_allocator(), DT_DOUBLE, 136 TensorShape({length})); 137 std::copy_n(arr, length, out_tensors->back().flat<double>().data()); 138 break; 139 } 140 case BOOL_ARR: { 141 int32_t length = ParseInt(ptr); 142 bool* arr = ParseBoolArr(ptr, length); 143 out_tensors->emplace_back(cpu_allocator(), DT_BOOL, 144 TensorShape({length})); 145 std::copy_n(arr, length, out_tensors->back().flat<bool>().data()); 146 break; 147 } 148 case STRING_ARR: { 149 int32_t length = ParseInt(ptr); 150 out_tensors->emplace_back(cpu_allocator(), DT_STRING, 151 TensorShape({length})); 152 for (int32_t i = 0; i < length; i++) 153 out_tensors->back().vec<string>()(i) = ParseString(ptr); 154 break; 155 } 156 case DATE_ARR: { 157 int32_t length = ParseInt(ptr); 158 int64_t* arr = ParseLongArr(ptr, length); 159 out_tensors->emplace_back(cpu_allocator(), DT_INT64, 160 TensorShape({length})); 161 std::copy_n(arr, length, out_tensors->back().flat<int64>().data()); 162 break; 163 } 164 case WRAPPED_OBJ: { 165 int32_t byte_arr_size = ParseInt(ptr); 166 TF_RETURN_IF_ERROR(Parse(ptr, out_tensors, types)); 167 int32_t offset = ParseInt(ptr); 168 169 break; 170 } 171 case COMPLEX_OBJ: { 172 uint8_t version = ParseByte(ptr); 173 int16_t flags = ParseShort(ptr); 174 int32_t type_id = ParseInt(ptr); 175 int32_t hash_code = ParseInt(ptr); 176 int32_t length = ParseInt(ptr); 177 int32_t schema_id = ParseInt(ptr); 178 int32_t schema_offset = ParseInt(ptr); 179 180 // 24 is size of header just read. 181 uint8_t* end = *ptr + schema_offset - 24; 182 int32_t i = 0; 183 while (*ptr < end) { 184 i++; 185 TF_RETURN_IF_ERROR(Parse(ptr, out_tensors, types)); 186 } 187 188 *ptr += (length - schema_offset); 189 190 break; 191 } 192 default: { 193 return errors::Unknown("Unknowd binary type (type id ", 194 (int)object_type_id, ")"); 195 } 196 } 197 198 return Status::OK(); 199 } 200 201 uint8_t BinaryObjectParser::ParseByte(uint8_t** ptr) const { 202 uint8_t res = **ptr; 203 *ptr += 1; 204 205 return res; 206 } 207 208 int16_t BinaryObjectParser::ParseShort(uint8_t** ptr) const { 209 int16_t* res = *reinterpret_cast<int16_t**>(ptr); 210 byte_swapper_.SwapIfRequiredInt16(res); 211 *ptr += 2; 212 213 return *res; 214 } 215 216 uint16_t BinaryObjectParser::ParseUnsignedShort(uint8_t** ptr) const { 217 uint16_t* res = *reinterpret_cast<uint16_t**>(ptr); 218 byte_swapper_.SwapIfRequiredUnsignedInt16(res); 219 *ptr += 2; 220 221 return *res; 222 } 223 224 int32_t BinaryObjectParser::ParseInt(uint8_t** ptr) const { 225 int32_t* res = *reinterpret_cast<int32_t**>(ptr); 226 byte_swapper_.SwapIfRequiredInt32(res); 227 *ptr += 4; 228 229 return *res; 230 } 231 232 int64_t BinaryObjectParser::ParseLong(uint8_t** ptr) const { 233 int64_t* res = *reinterpret_cast<int64_t**>(ptr); 234 byte_swapper_.SwapIfRequiredInt64(res); 235 *ptr += 8; 236 237 return *res; 238 } 239 240 float BinaryObjectParser::ParseFloat(uint8_t** ptr) const { 241 float* res = *reinterpret_cast<float**>(ptr); 242 byte_swapper_.SwapIfRequiredFloat(res); 243 *ptr += 4; 244 245 return *res; 246 } 247 248 double BinaryObjectParser::ParseDouble(uint8_t** ptr) const { 249 double* res = *reinterpret_cast<double**>(ptr); 250 byte_swapper_.SwapIfRequiredDouble(res); 251 *ptr += 8; 252 253 return *res; 254 } 255 256 bool BinaryObjectParser::ParseBool(uint8_t** ptr) const { 257 bool res = **reinterpret_cast<bool**>(ptr); 258 *ptr += 1; 259 260 return res; 261 } 262 263 string BinaryObjectParser::ParseString(uint8_t** ptr) const { 264 int32_t length = ParseInt(ptr); 265 string res(*reinterpret_cast<char**>(ptr), length); 266 *ptr += length; 267 268 return res; 269 } 270 271 uint8_t* BinaryObjectParser::ParseByteArr(uint8_t** ptr, int length) const { 272 uint8_t* res = *reinterpret_cast<uint8_t**>(ptr); 273 *ptr += length; 274 275 return res; 276 } 277 278 int16_t* BinaryObjectParser::ParseShortArr(uint8_t** ptr, int length) const { 279 int16_t* res = *reinterpret_cast<int16_t**>(ptr); 280 byte_swapper_.SwapIfRequiredInt16Arr(res, length); 281 *ptr += length * 2; 282 283 return res; 284 } 285 286 uint16_t* BinaryObjectParser::ParseUnsignedShortArr(uint8_t** ptr, 287 int length) const { 288 uint16_t* res = *reinterpret_cast<uint16_t**>(ptr); 289 byte_swapper_.SwapIfRequiredUnsignedInt16Arr(res, length); 290 *ptr += length * 2; 291 292 return res; 293 } 294 295 int32_t* BinaryObjectParser::ParseIntArr(uint8_t** ptr, int length) const { 296 int32_t* res = *reinterpret_cast<int32_t**>(ptr); 297 byte_swapper_.SwapIfRequiredInt32Arr(res, length); 298 *ptr += length * 4; 299 300 return res; 301 } 302 303 int64_t* BinaryObjectParser::ParseLongArr(uint8_t** ptr, int length) const { 304 int64_t* res = *reinterpret_cast<int64_t**>(ptr); 305 byte_swapper_.SwapIfRequiredInt64Arr(res, length); 306 *ptr += length * 8; 307 308 return res; 309 } 310 311 float* BinaryObjectParser::ParseFloatArr(uint8_t** ptr, int length) const { 312 float* res = *reinterpret_cast<float**>(ptr); 313 byte_swapper_.SwapIfRequiredFloatArr(res, length); 314 *ptr += length * 4; 315 316 return res; 317 } 318 319 double* BinaryObjectParser::ParseDoubleArr(uint8_t** ptr, int length) const { 320 double* res = *reinterpret_cast<double**>(ptr); 321 byte_swapper_.SwapIfRequiredDoubleArr(res, length); 322 *ptr += length * 8; 323 324 return res; 325 } 326 327 bool* BinaryObjectParser::ParseBoolArr(uint8_t** ptr, int length) const { 328 bool* res = *reinterpret_cast<bool**>(ptr); 329 *ptr += length; 330 331 return res; 332 } 333 334 } // namespace tensorflow 335