Home | History | Annotate | Download | only in dataset
      1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
#include "tensorflow/contrib/ignite/kernels/dataset/ignite_binary_object_parser.h"

#include <algorithm>
#include <cstring>

#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
     19 
     20 namespace tensorflow {
     21 
     22 BinaryObjectParser::BinaryObjectParser() : byte_swapper_(ByteSwapper(false)) {}
     23 
     24 Status BinaryObjectParser::Parse(uint8_t** ptr,
     25                                  std::vector<Tensor>* out_tensors,
     26                                  std::vector<int32_t>* types) const {
     27   uint8_t object_type_id = ParseByte(ptr);
     28 
     29   // Skip non-leaf nodes.
     30   if (object_type_id != WRAPPED_OBJ && object_type_id != COMPLEX_OBJ)
     31     types->push_back(object_type_id);
     32 
     33   switch (object_type_id) {
     34     case BYTE: {
     35       out_tensors->emplace_back(cpu_allocator(), DT_UINT8, TensorShape({}));
     36       out_tensors->back().scalar<uint8>()() = ParseByte(ptr);
     37       break;
     38     }
     39     case SHORT: {
     40       out_tensors->emplace_back(cpu_allocator(), DT_INT16, TensorShape({}));
     41       out_tensors->back().scalar<int16>()() = ParseShort(ptr);
     42       break;
     43     }
     44     case USHORT: {
     45       out_tensors->emplace_back(cpu_allocator(), DT_UINT16, TensorShape({}));
     46       out_tensors->back().scalar<uint16>()() = ParseUnsignedShort(ptr);
     47       break;
     48     }
     49     case INT: {
     50       out_tensors->emplace_back(cpu_allocator(), DT_INT32, TensorShape({}));
     51       out_tensors->back().scalar<int32>()() = ParseInt(ptr);
     52       break;
     53     }
     54     case LONG: {
     55       out_tensors->emplace_back(cpu_allocator(), DT_INT64, TensorShape({}));
     56       out_tensors->back().scalar<int64>()() = ParseLong(ptr);
     57       break;
     58     }
     59     case FLOAT: {
     60       out_tensors->emplace_back(cpu_allocator(), DT_FLOAT, TensorShape({}));
     61       out_tensors->back().scalar<float>()() = ParseFloat(ptr);
     62       break;
     63     }
     64     case DOUBLE: {
     65       out_tensors->emplace_back(cpu_allocator(), DT_DOUBLE, TensorShape({}));
     66       out_tensors->back().scalar<double>()() = ParseDouble(ptr);
     67       break;
     68     }
     69     case BOOL: {
     70       out_tensors->emplace_back(cpu_allocator(), DT_BOOL, TensorShape({}));
     71       out_tensors->back().scalar<bool>()() = ParseBool(ptr);
     72       break;
     73     }
     74     case STRING: {
     75       out_tensors->emplace_back(cpu_allocator(), DT_STRING, TensorShape({}));
     76       out_tensors->back().scalar<string>()() = ParseString(ptr);
     77       break;
     78     }
     79     case DATE: {
     80       out_tensors->emplace_back(cpu_allocator(), DT_INT64, TensorShape({}));
     81       out_tensors->back().scalar<int64>()() = ParseLong(ptr);
     82       break;
     83     }
     84     case BYTE_ARR: {
     85       int32_t length = ParseInt(ptr);
     86       uint8_t* arr = ParseByteArr(ptr, length);
     87       out_tensors->emplace_back(cpu_allocator(), DT_UINT8,
     88                                 TensorShape({length}));
     89       std::copy_n(arr, length, out_tensors->back().flat<uint8>().data());
     90       break;
     91     }
     92     case SHORT_ARR: {
     93       int32_t length = ParseInt(ptr);
     94       int16_t* arr = ParseShortArr(ptr, length);
     95       out_tensors->emplace_back(cpu_allocator(), DT_INT16,
     96                                 TensorShape({length}));
     97       std::copy_n(arr, length, out_tensors->back().flat<int16>().data());
     98       break;
     99     }
    100     case USHORT_ARR: {
    101       int32_t length = ParseInt(ptr);
    102       uint16_t* arr = ParseUnsignedShortArr(ptr, length);
    103       out_tensors->emplace_back(cpu_allocator(), DT_UINT16,
    104                                 TensorShape({length}));
    105       std::copy_n(arr, length, out_tensors->back().flat<uint16>().data());
    106       break;
    107     }
    108     case INT_ARR: {
    109       int32_t length = ParseInt(ptr);
    110       int32_t* arr = ParseIntArr(ptr, length);
    111       out_tensors->emplace_back(cpu_allocator(), DT_INT32,
    112                                 TensorShape({length}));
    113       std::copy_n(arr, length, out_tensors->back().flat<int32>().data());
    114       break;
    115     }
    116     case LONG_ARR: {
    117       int32_t length = ParseInt(ptr);
    118       int64_t* arr = ParseLongArr(ptr, length);
    119       out_tensors->emplace_back(cpu_allocator(), DT_INT64,
    120                                 TensorShape({length}));
    121       std::copy_n(arr, length, out_tensors->back().flat<int64>().data());
    122       break;
    123     }
    124     case FLOAT_ARR: {
    125       int32_t length = ParseInt(ptr);
    126       float* arr = ParseFloatArr(ptr, length);
    127       out_tensors->emplace_back(cpu_allocator(), DT_FLOAT,
    128                                 TensorShape({length}));
    129       std::copy_n(arr, length, out_tensors->back().flat<float>().data());
    130       break;
    131     }
    132     case DOUBLE_ARR: {
    133       int32_t length = ParseInt(ptr);
    134       double* arr = ParseDoubleArr(ptr, length);
    135       out_tensors->emplace_back(cpu_allocator(), DT_DOUBLE,
    136                                 TensorShape({length}));
    137       std::copy_n(arr, length, out_tensors->back().flat<double>().data());
    138       break;
    139     }
    140     case BOOL_ARR: {
    141       int32_t length = ParseInt(ptr);
    142       bool* arr = ParseBoolArr(ptr, length);
    143       out_tensors->emplace_back(cpu_allocator(), DT_BOOL,
    144                                 TensorShape({length}));
    145       std::copy_n(arr, length, out_tensors->back().flat<bool>().data());
    146       break;
    147     }
    148     case STRING_ARR: {
    149       int32_t length = ParseInt(ptr);
    150       out_tensors->emplace_back(cpu_allocator(), DT_STRING,
    151                                 TensorShape({length}));
    152       for (int32_t i = 0; i < length; i++)
    153         out_tensors->back().vec<string>()(i) = ParseString(ptr);
    154       break;
    155     }
    156     case DATE_ARR: {
    157       int32_t length = ParseInt(ptr);
    158       int64_t* arr = ParseLongArr(ptr, length);
    159       out_tensors->emplace_back(cpu_allocator(), DT_INT64,
    160                                 TensorShape({length}));
    161       std::copy_n(arr, length, out_tensors->back().flat<int64>().data());
    162       break;
    163     }
    164     case WRAPPED_OBJ: {
    165       int32_t byte_arr_size = ParseInt(ptr);
    166       TF_RETURN_IF_ERROR(Parse(ptr, out_tensors, types));
    167       int32_t offset = ParseInt(ptr);
    168 
    169       break;
    170     }
    171     case COMPLEX_OBJ: {
    172       uint8_t version = ParseByte(ptr);
    173       int16_t flags = ParseShort(ptr);
    174       int32_t type_id = ParseInt(ptr);
    175       int32_t hash_code = ParseInt(ptr);
    176       int32_t length = ParseInt(ptr);
    177       int32_t schema_id = ParseInt(ptr);
    178       int32_t schema_offset = ParseInt(ptr);
    179 
    180       // 24 is size of header just read.
    181       uint8_t* end = *ptr + schema_offset - 24;
    182       int32_t i = 0;
    183       while (*ptr < end) {
    184         i++;
    185         TF_RETURN_IF_ERROR(Parse(ptr, out_tensors, types));
    186       }
    187 
    188       *ptr += (length - schema_offset);
    189 
    190       break;
    191     }
    192     default: {
    193       return errors::Unknown("Unknowd binary type (type id ",
    194                              (int)object_type_id, ")");
    195     }
    196   }
    197 
    198   return Status::OK();
    199 }
    200 
    201 uint8_t BinaryObjectParser::ParseByte(uint8_t** ptr) const {
    202   uint8_t res = **ptr;
    203   *ptr += 1;
    204 
    205   return res;
    206 }
    207 
    208 int16_t BinaryObjectParser::ParseShort(uint8_t** ptr) const {
    209   int16_t* res = *reinterpret_cast<int16_t**>(ptr);
    210   byte_swapper_.SwapIfRequiredInt16(res);
    211   *ptr += 2;
    212 
    213   return *res;
    214 }
    215 
    216 uint16_t BinaryObjectParser::ParseUnsignedShort(uint8_t** ptr) const {
    217   uint16_t* res = *reinterpret_cast<uint16_t**>(ptr);
    218   byte_swapper_.SwapIfRequiredUnsignedInt16(res);
    219   *ptr += 2;
    220 
    221   return *res;
    222 }
    223 
    224 int32_t BinaryObjectParser::ParseInt(uint8_t** ptr) const {
    225   int32_t* res = *reinterpret_cast<int32_t**>(ptr);
    226   byte_swapper_.SwapIfRequiredInt32(res);
    227   *ptr += 4;
    228 
    229   return *res;
    230 }
    231 
    232 int64_t BinaryObjectParser::ParseLong(uint8_t** ptr) const {
    233   int64_t* res = *reinterpret_cast<int64_t**>(ptr);
    234   byte_swapper_.SwapIfRequiredInt64(res);
    235   *ptr += 8;
    236 
    237   return *res;
    238 }
    239 
    240 float BinaryObjectParser::ParseFloat(uint8_t** ptr) const {
    241   float* res = *reinterpret_cast<float**>(ptr);
    242   byte_swapper_.SwapIfRequiredFloat(res);
    243   *ptr += 4;
    244 
    245   return *res;
    246 }
    247 
    248 double BinaryObjectParser::ParseDouble(uint8_t** ptr) const {
    249   double* res = *reinterpret_cast<double**>(ptr);
    250   byte_swapper_.SwapIfRequiredDouble(res);
    251   *ptr += 8;
    252 
    253   return *res;
    254 }
    255 
    256 bool BinaryObjectParser::ParseBool(uint8_t** ptr) const {
    257   bool res = **reinterpret_cast<bool**>(ptr);
    258   *ptr += 1;
    259 
    260   return res;
    261 }
    262 
    263 string BinaryObjectParser::ParseString(uint8_t** ptr) const {
    264   int32_t length = ParseInt(ptr);
    265   string res(*reinterpret_cast<char**>(ptr), length);
    266   *ptr += length;
    267 
    268   return res;
    269 }
    270 
    271 uint8_t* BinaryObjectParser::ParseByteArr(uint8_t** ptr, int length) const {
    272   uint8_t* res = *reinterpret_cast<uint8_t**>(ptr);
    273   *ptr += length;
    274 
    275   return res;
    276 }
    277 
    278 int16_t* BinaryObjectParser::ParseShortArr(uint8_t** ptr, int length) const {
    279   int16_t* res = *reinterpret_cast<int16_t**>(ptr);
    280   byte_swapper_.SwapIfRequiredInt16Arr(res, length);
    281   *ptr += length * 2;
    282 
    283   return res;
    284 }
    285 
    286 uint16_t* BinaryObjectParser::ParseUnsignedShortArr(uint8_t** ptr,
    287                                                     int length) const {
    288   uint16_t* res = *reinterpret_cast<uint16_t**>(ptr);
    289   byte_swapper_.SwapIfRequiredUnsignedInt16Arr(res, length);
    290   *ptr += length * 2;
    291 
    292   return res;
    293 }
    294 
    295 int32_t* BinaryObjectParser::ParseIntArr(uint8_t** ptr, int length) const {
    296   int32_t* res = *reinterpret_cast<int32_t**>(ptr);
    297   byte_swapper_.SwapIfRequiredInt32Arr(res, length);
    298   *ptr += length * 4;
    299 
    300   return res;
    301 }
    302 
    303 int64_t* BinaryObjectParser::ParseLongArr(uint8_t** ptr, int length) const {
    304   int64_t* res = *reinterpret_cast<int64_t**>(ptr);
    305   byte_swapper_.SwapIfRequiredInt64Arr(res, length);
    306   *ptr += length * 8;
    307 
    308   return res;
    309 }
    310 
    311 float* BinaryObjectParser::ParseFloatArr(uint8_t** ptr, int length) const {
    312   float* res = *reinterpret_cast<float**>(ptr);
    313   byte_swapper_.SwapIfRequiredFloatArr(res, length);
    314   *ptr += length * 4;
    315 
    316   return res;
    317 }
    318 
    319 double* BinaryObjectParser::ParseDoubleArr(uint8_t** ptr, int length) const {
    320   double* res = *reinterpret_cast<double**>(ptr);
    321   byte_swapper_.SwapIfRequiredDoubleArr(res, length);
    322   *ptr += length * 8;
    323 
    324   return res;
    325 }
    326 
    327 bool* BinaryObjectParser::ParseBoolArr(uint8_t** ptr, int length) const {
    328   bool* res = *reinterpret_cast<bool**>(ptr);
    329   *ptr += length;
    330 
    331   return res;
    332 }
    333 
    334 }  // namespace tensorflow
    335