Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
#include "tensorflow/core/kernels/spectrogram_test_utils.h"

#include <math.h>
#include <stddef.h>
#include <string.h>

#include <algorithm>
#include <complex>
#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/wav/wav_io.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
     28 
     29 namespace tensorflow {
     30 
     31 bool ReadWaveFileToVector(const string& file_name, std::vector<double>* data) {
     32   string wav_data;
     33   if (!ReadFileToString(Env::Default(), file_name, &wav_data).ok()) {
     34     LOG(ERROR) << "Wave file read failed for " << file_name;
     35     return false;
     36   }
     37   std::vector<float> decoded_data;
     38   uint32 decoded_sample_count;
     39   uint16 decoded_channel_count;
     40   uint32 decoded_sample_rate;
     41   if (!wav::DecodeLin16WaveAsFloatVector(
     42            wav_data, &decoded_data, &decoded_sample_count,
     43            &decoded_channel_count, &decoded_sample_rate)
     44            .ok()) {
     45     return false;
     46   }
     47   // Convert from float to double for the output value.
     48   data->resize(decoded_data.size());
     49   for (int i = 0; i < decoded_data.size(); ++i) {
     50     (*data)[i] = decoded_data[i];
     51   }
     52   return true;
     53 }
     54 
     55 bool ReadRawFloatFileToComplexVector(
     56     const string& file_name, int row_length,
     57     std::vector<std::vector<std::complex<double> > >* data) {
     58   data->clear();
     59   string data_string;
     60   if (!ReadFileToString(Env::Default(), file_name, &data_string).ok()) {
     61     LOG(ERROR) << "Failed to open file " << file_name;
     62     return false;
     63   }
     64   float real_out;
     65   float imag_out;
     66   const int kBytesPerValue = 4;
     67   CHECK_EQ(sizeof(real_out), kBytesPerValue);
     68   std::vector<std::complex<double> > data_row;
     69   int row_counter = 0;
     70   int offset = 0;
     71   const int end = data_string.size();
     72   while (offset < end) {
     73 #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
     74     char arr[4];
     75     for (int i = 0; i < kBytesPerValue; ++i) {
     76       arr[3 - i] = *(data_string.data() + offset + i);
     77     }
     78     memcpy(&real_out, arr, kBytesPerValue);
     79     offset += kBytesPerValue;
     80     for (int i = 0; i < kBytesPerValue; ++i) {
     81       arr[3 - i] = *(data_string.data() + offset + i);
     82     }
     83     memcpy(&imag_out, arr, kBytesPerValue);
     84     offset += kBytesPerValue;
     85 #else
     86     memcpy(&real_out, data_string.data() + offset, kBytesPerValue);
     87     offset += kBytesPerValue;
     88     memcpy(&imag_out, data_string.data() + offset, kBytesPerValue);
     89     offset += kBytesPerValue;
     90 #endif
     91     if (row_counter >= row_length) {
     92       data->push_back(data_row);
     93       data_row.clear();
     94       row_counter = 0;
     95     }
     96     data_row.push_back(std::complex<double>(real_out, imag_out));
     97     ++row_counter;
     98   }
     99   if (row_counter >= row_length) {
    100     data->push_back(data_row);
    101   }
    102   return true;
    103 }
    104 
    105 void ReadCSVFileToComplexVectorOrDie(
    106     const string& file_name,
    107     std::vector<std::vector<std::complex<double> > >* data) {
    108   data->clear();
    109   string data_string;
    110   if (!ReadFileToString(Env::Default(), file_name, &data_string).ok()) {
    111     LOG(FATAL) << "Failed to open file " << file_name;
    112     return;
    113   }
    114   std::vector<string> lines = str_util::Split(data_string, '\n');
    115   for (const string& line : lines) {
    116     if (line.empty()) {
    117       continue;
    118     }
    119     std::vector<std::complex<double> > data_line;
    120     std::vector<string> values = str_util::Split(line, ',');
    121     for (std::vector<string>::const_iterator i = values.begin();
    122          i != values.end(); ++i) {
    123       // each element of values may be in the form:
    124       // 0.001+0.002i, 0.001, 0.001i, -1.2i, -1.2-3.2i, 1.5, 1.5e-03+21.0i
    125       std::vector<string> parts;
    126       // Find the first instance of + or - after the second character
    127       // in the string, that does not immediately follow an 'e'.
    128       size_t operator_index = i->find_first_of("+-", 2);
    129       if (operator_index < i->size() &&
    130           i->substr(operator_index - 1, 1) == "e") {
    131         operator_index = i->find_first_of("+-", operator_index + 1);
    132       }
    133       parts.push_back(i->substr(0, operator_index));
    134       if (operator_index < i->size()) {
    135         parts.push_back(i->substr(operator_index, string::npos));
    136       }
    137 
    138       double real_part = 0.0;
    139       double imaginary_part = 0.0;
    140       for (std::vector<string>::const_iterator j = parts.begin();
    141            j != parts.end(); ++j) {
    142         if (j->find_first_of("ij") != string::npos) {
    143           strings::safe_strtod((*j).c_str(), &imaginary_part);
    144         } else {
    145           strings::safe_strtod((*j).c_str(), &real_part);
    146         }
    147       }
    148       data_line.push_back(std::complex<double>(real_part, imaginary_part));
    149     }
    150     data->push_back(data_line);
    151   }
    152 }
    153 
    154 void ReadCSVFileToArrayOrDie(const string& filename,
    155                              std::vector<std::vector<float> >* array) {
    156   string contents;
    157   TF_CHECK_OK(ReadFileToString(Env::Default(), filename, &contents));
    158   std::vector<string> lines = str_util::Split(contents, '\n');
    159   contents.clear();
    160 
    161   array->clear();
    162   std::vector<float> values;
    163   for (int l = 0; l < lines.size(); ++l) {
    164     values.clear();
    165     CHECK(str_util::SplitAndParseAsFloats(lines[l], ',', &values));
    166     array->push_back(values);
    167   }
    168 }
    169 
    170 bool WriteDoubleVectorToFile(const string& file_name,
    171                              const std::vector<double>& data) {
    172   std::unique_ptr<WritableFile> file;
    173   if (!Env::Default()->NewWritableFile(file_name, &file).ok()) {
    174     LOG(ERROR) << "Failed to open file " << file_name;
    175     return false;
    176   }
    177   for (int i = 0; i < data.size(); ++i) {
    178     if (!file->Append(StringPiece(reinterpret_cast<const char*>(&(data[i])),
    179                                   sizeof(data[i])))
    180              .ok()) {
    181       LOG(ERROR) << "Failed to append to file " << file_name;
    182       return false;
    183     }
    184   }
    185   if (!file->Close().ok()) {
    186     LOG(ERROR) << "Failed to close file " << file_name;
    187     return false;
    188   }
    189   return true;
    190 }
    191 
    192 bool WriteFloatVectorToFile(const string& file_name,
    193                             const std::vector<float>& data) {
    194   std::unique_ptr<WritableFile> file;
    195   if (!Env::Default()->NewWritableFile(file_name, &file).ok()) {
    196     LOG(ERROR) << "Failed to open file " << file_name;
    197     return false;
    198   }
    199   for (int i = 0; i < data.size(); ++i) {
    200     if (!file->Append(StringPiece(reinterpret_cast<const char*>(&(data[i])),
    201                                   sizeof(data[i])))
    202              .ok()) {
    203       LOG(ERROR) << "Failed to append to file " << file_name;
    204       return false;
    205     }
    206   }
    207   if (!file->Close().ok()) {
    208     LOG(ERROR) << "Failed to close file " << file_name;
    209     return false;
    210   }
    211   return true;
    212 }
    213 
    214 bool WriteDoubleArrayToFile(const string& file_name, int size,
    215                             const double* data) {
    216   std::unique_ptr<WritableFile> file;
    217   if (!Env::Default()->NewWritableFile(file_name, &file).ok()) {
    218     LOG(ERROR) << "Failed to open file " << file_name;
    219     return false;
    220   }
    221   for (int i = 0; i < size; ++i) {
    222     if (!file->Append(StringPiece(reinterpret_cast<const char*>(&(data[i])),
    223                                   sizeof(data[i])))
    224              .ok()) {
    225       LOG(ERROR) << "Failed to append to file " << file_name;
    226       return false;
    227     }
    228   }
    229   if (!file->Close().ok()) {
    230     LOG(ERROR) << "Failed to close file " << file_name;
    231     return false;
    232   }
    233   return true;
    234 }
    235 
    236 bool WriteFloatArrayToFile(const string& file_name, int size,
    237                            const float* data) {
    238   std::unique_ptr<WritableFile> file;
    239   if (!Env::Default()->NewWritableFile(file_name, &file).ok()) {
    240     LOG(ERROR) << "Failed to open file " << file_name;
    241     return false;
    242   }
    243   for (int i = 0; i < size; ++i) {
    244     if (!file->Append(StringPiece(reinterpret_cast<const char*>(&(data[i])),
    245                                   sizeof(data[i])))
    246              .ok()) {
    247       LOG(ERROR) << "Failed to append to file " << file_name;
    248       return false;
    249     }
    250   }
    251   if (!file->Close().ok()) {
    252     LOG(ERROR) << "Failed to close file " << file_name;
    253     return false;
    254   }
    255   return true;
    256 }
    257 
    258 bool WriteComplexVectorToRawFloatFile(
    259     const string& file_name,
    260     const std::vector<std::vector<std::complex<double> > >& data) {
    261   std::unique_ptr<WritableFile> file;
    262   if (!Env::Default()->NewWritableFile(file_name, &file).ok()) {
    263     LOG(ERROR) << "Failed to open file " << file_name;
    264     return false;
    265   }
    266   for (int i = 0; i < data.size(); ++i) {
    267     for (int j = 0; j < data[i].size(); ++j) {
    268       const float real_part(real(data[i][j]));
    269       if (!file->Append(StringPiece(reinterpret_cast<const char*>(&real_part),
    270                                     sizeof(real_part)))
    271                .ok()) {
    272         LOG(ERROR) << "Failed to append to file " << file_name;
    273         return false;
    274       }
    275 
    276       const float imag_part(imag(data[i][j]));
    277       if (!file->Append(StringPiece(reinterpret_cast<const char*>(&imag_part),
    278                                     sizeof(imag_part)))
    279                .ok()) {
    280         LOG(ERROR) << "Failed to append to file " << file_name;
    281         return false;
    282       }
    283     }
    284   }
    285   if (!file->Close().ok()) {
    286     LOG(ERROR) << "Failed to close file " << file_name;
    287     return false;
    288   }
    289   return true;
    290 }
    291 
    292 void SineWave(int sample_rate, float frequency, float duration_seconds,
    293               std::vector<double>* data) {
    294   data->clear();
    295   for (int i = 0; i < static_cast<int>(sample_rate * duration_seconds); ++i) {
    296     data->push_back(
    297         sin(2.0 * M_PI * i * frequency / static_cast<double>(sample_rate)));
    298   }
    299 }
    300 
    301 }  // namespace tensorflow
    302