Home | History | Annotate | Download | only in wav
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
// Functions to read and write audio in WAV format.
     17 
     18 #include <math.h>
     19 #include <string.h>
     20 #include <algorithm>
     21 
     22 #include "tensorflow/core/lib/core/casts.h"
     23 #include "tensorflow/core/lib/core/coding.h"
     24 #include "tensorflow/core/lib/core/errors.h"
     25 #include "tensorflow/core/lib/wav/wav_io.h"
     26 #include "tensorflow/core/platform/cpu_info.h"
     27 #include "tensorflow/core/platform/logging.h"
     28 #include "tensorflow/core/platform/macros.h"
     29 
     30 namespace tensorflow {
     31 namespace wav {
     32 namespace {
     33 
     34 struct TF_PACKED RiffChunk {
     35   char chunk_id[4];
     36   char chunk_data_size[4];
     37   char riff_type[4];
     38 };
     39 static_assert(sizeof(RiffChunk) == 12, "TF_PACKED does not work.");
     40 
     41 struct TF_PACKED FormatChunk {
     42   char chunk_id[4];
     43   char chunk_data_size[4];
     44   char compression_code[2];
     45   char channel_numbers[2];
     46   char sample_rate[4];
     47   char bytes_per_second[4];
     48   char bytes_per_frame[2];
     49   char bits_per_sample[2];
     50 };
     51 static_assert(sizeof(FormatChunk) == 24, "TF_PACKED does not work.");
     52 
     53 struct TF_PACKED DataChunk {
     54   char chunk_id[4];
     55   char chunk_data_size[4];
     56 };
     57 static_assert(sizeof(DataChunk) == 8, "TF_PACKED does not work.");
     58 
     59 struct TF_PACKED WavHeader {
     60   RiffChunk riff_chunk;
     61   FormatChunk format_chunk;
     62   DataChunk data_chunk;
     63 };
     64 static_assert(sizeof(WavHeader) ==
     65                   sizeof(RiffChunk) + sizeof(FormatChunk) + sizeof(DataChunk),
     66               "TF_PACKED does not work.");
     67 
     68 constexpr char kRiffChunkId[] = "RIFF";
     69 constexpr char kRiffType[] = "WAVE";
     70 constexpr char kFormatChunkId[] = "fmt ";
     71 constexpr char kDataChunkId[] = "data";
     72 
     73 inline int16 FloatToInt16Sample(float data) {
     74   constexpr float kMultiplier = 1.0f * (1 << 15);
     75   return std::min<float>(std::max<float>(roundf(data * kMultiplier), kint16min),
     76                          kint16max);
     77 }
     78 
     79 inline float Int16SampleToFloat(int16 data) {
     80   constexpr float kMultiplier = 1.0f / (1 << 15);
     81   return data * kMultiplier;
     82 }
     83 
     84 Status ExpectText(const string& data, const string& expected_text,
     85                   int* offset) {
     86   const int new_offset = *offset + expected_text.size();
     87   if (new_offset > data.size()) {
     88     return errors::InvalidArgument("Data too short when trying to read ",
     89                                    expected_text);
     90   }
     91   const string found_text(data.begin() + *offset, data.begin() + new_offset);
     92   if (found_text != expected_text) {
     93     return errors::InvalidArgument("Header mismatch: Expected ", expected_text,
     94                                    " but found ", found_text);
     95   }
     96   *offset = new_offset;
     97   return Status::OK();
     98 }
     99 
    100 template <class T>
    101 Status ReadValue(const string& data, T* value, int* offset) {
    102   const int new_offset = *offset + sizeof(T);
    103   if (new_offset > data.size()) {
    104     return errors::InvalidArgument("Data too short when trying to read value");
    105   }
    106   if (port::kLittleEndian) {
    107     memcpy(value, data.data() + *offset, sizeof(T));
    108   } else {
    109     *value = 0;
    110     const uint8* data_buf =
    111         reinterpret_cast<const uint8*>(data.data() + *offset);
    112     int shift = 0;
    113     for (int i = 0; i < sizeof(T); ++i, shift += 8) {
    114       *value = *value | (data_buf[i] << shift);
    115     }
    116   }
    117   *offset = new_offset;
    118   return Status::OK();
    119 }
    120 
    121 Status ReadString(const string& data, int expected_length, string* value,
    122                   int* offset) {
    123   const int new_offset = *offset + expected_length;
    124   if (new_offset > data.size()) {
    125     return errors::InvalidArgument("Data too short when trying to read string");
    126   }
    127   *value = string(data.begin() + *offset, data.begin() + new_offset);
    128   *offset = new_offset;
    129   return Status::OK();
    130 }
    131 
    132 }  // namespace
    133 
    134 Status EncodeAudioAsS16LEWav(const float* audio, size_t sample_rate,
    135                              size_t num_channels, size_t num_frames,
    136                              string* wav_string) {
    137   constexpr size_t kFormatChunkSize = 16;
    138   constexpr size_t kCompressionCodePcm = 1;
    139   constexpr size_t kBitsPerSample = 16;
    140   constexpr size_t kBytesPerSample = kBitsPerSample / 8;
    141   constexpr size_t kHeaderSize = sizeof(WavHeader);
    142 
    143   if (audio == nullptr) {
    144     return errors::InvalidArgument("audio is null");
    145   }
    146   if (wav_string == nullptr) {
    147     return errors::InvalidArgument("wav_string is null");
    148   }
    149   if (sample_rate == 0 || sample_rate > kuint32max) {
    150     return errors::InvalidArgument("sample_rate must be in (0, 2^32), got: ",
    151                                    sample_rate);
    152   }
    153   if (num_channels == 0 || num_channels > kuint16max) {
    154     return errors::InvalidArgument("num_channels must be in (0, 2^16), got: ",
    155                                    num_channels);
    156   }
    157   if (num_frames == 0) {
    158     return errors::InvalidArgument("num_frames must be positive.");
    159   }
    160 
    161   const size_t bytes_per_second = sample_rate * kBytesPerSample * num_channels;
    162   const size_t num_samples = num_frames * num_channels;
    163   const size_t data_size = num_samples * kBytesPerSample;
    164   const size_t file_size = kHeaderSize + num_samples * kBytesPerSample;
    165   const size_t bytes_per_frame = kBytesPerSample * num_channels;
    166 
    167   // WAV represents the length of the file as a uint32 so file_size cannot
    168   // exceed kuint32max.
    169   if (file_size > kuint32max) {
    170     return errors::InvalidArgument(
    171         "Provided channels and frames cannot be encoded as a WAV.");
    172   }
    173 
    174   wav_string->resize(file_size);
    175   char* data = &wav_string->at(0);
    176   WavHeader* header = bit_cast<WavHeader*>(data);
    177 
    178   // Fill RIFF chunk.
    179   auto* riff_chunk = &header->riff_chunk;
    180   memcpy(riff_chunk->chunk_id, kRiffChunkId, 4);
    181   core::EncodeFixed32(riff_chunk->chunk_data_size, file_size - 8);
    182   memcpy(riff_chunk->riff_type, kRiffType, 4);
    183 
    184   // Fill format chunk.
    185   auto* format_chunk = &header->format_chunk;
    186   memcpy(format_chunk->chunk_id, kFormatChunkId, 4);
    187   core::EncodeFixed32(format_chunk->chunk_data_size, kFormatChunkSize);
    188   core::EncodeFixed16(format_chunk->compression_code, kCompressionCodePcm);
    189   core::EncodeFixed16(format_chunk->channel_numbers, num_channels);
    190   core::EncodeFixed32(format_chunk->sample_rate, sample_rate);
    191   core::EncodeFixed32(format_chunk->bytes_per_second, bytes_per_second);
    192   core::EncodeFixed16(format_chunk->bytes_per_frame, bytes_per_frame);
    193   core::EncodeFixed16(format_chunk->bits_per_sample, kBitsPerSample);
    194 
    195   // Fill data chunk.
    196   auto* data_chunk = &header->data_chunk;
    197   memcpy(data_chunk->chunk_id, kDataChunkId, 4);
    198   core::EncodeFixed32(data_chunk->chunk_data_size, data_size);
    199 
    200   // Write the audio.
    201   data += kHeaderSize;
    202   for (size_t i = 0; i < num_samples; ++i) {
    203     int16 sample = FloatToInt16Sample(audio[i]);
    204     core::EncodeFixed16(&data[i * kBytesPerSample],
    205                         static_cast<uint16>(sample));
    206   }
    207   return Status::OK();
    208 }
    209 
    210 Status DecodeLin16WaveAsFloatVector(const string& wav_string,
    211                                     std::vector<float>* float_values,
    212                                     uint32* sample_count, uint16* channel_count,
    213                                     uint32* sample_rate) {
    214   int offset = 0;
    215   TF_RETURN_IF_ERROR(ExpectText(wav_string, kRiffChunkId, &offset));
    216   uint32 total_file_size;
    217   TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &total_file_size, &offset));
    218   TF_RETURN_IF_ERROR(ExpectText(wav_string, kRiffType, &offset));
    219   TF_RETURN_IF_ERROR(ExpectText(wav_string, kFormatChunkId, &offset));
    220   uint32 format_chunk_size;
    221   TF_RETURN_IF_ERROR(
    222       ReadValue<uint32>(wav_string, &format_chunk_size, &offset));
    223   if ((format_chunk_size != 16) && (format_chunk_size != 18)) {
    224     return errors::InvalidArgument(
    225         "Bad file size for WAV: Expected 16 or 18, but got", format_chunk_size);
    226   }
    227   uint16 audio_format;
    228   TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, &audio_format, &offset));
    229   if (audio_format != 1) {
    230     return errors::InvalidArgument(
    231         "Bad audio format for WAV: Expected 1 (PCM), but got", audio_format);
    232   }
    233   TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, channel_count, &offset));
    234   TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, sample_rate, &offset));
    235   uint32 bytes_per_second;
    236   TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &bytes_per_second, &offset));
    237   uint16 bytes_per_sample;
    238   TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, &bytes_per_sample, &offset));
    239   // Confusingly, bits per sample is defined as holding the number of bits for
    240   // one channel, unlike the definition of sample used elsewhere in the WAV
    241   // spec. For example, bytes per sample is the memory needed for all channels
    242   // for one point in time.
    243   uint16 bits_per_sample;
    244   TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, &bits_per_sample, &offset));
    245   if (bits_per_sample != 16) {
    246     return errors::InvalidArgument(
    247         "Can only read 16-bit WAV files, but received ", bits_per_sample);
    248   }
    249   const uint32 expected_bytes_per_sample =
    250       ((bits_per_sample * *channel_count) + 7) / 8;
    251   if (bytes_per_sample != expected_bytes_per_sample) {
    252     return errors::InvalidArgument(
    253         "Bad bytes per sample in WAV header: Expected ",
    254         expected_bytes_per_sample, " but got ", bytes_per_sample);
    255   }
    256   const uint32 expected_bytes_per_second = bytes_per_sample * *sample_rate;
    257   if (bytes_per_second != expected_bytes_per_second) {
    258     return errors::InvalidArgument(
    259         "Bad bytes per second in WAV header: Expected ",
    260         expected_bytes_per_second, " but got ", bytes_per_second,
    261         " (sample_rate=", *sample_rate, ", bytes_per_sample=", bytes_per_sample,
    262         ")");
    263   }
    264   if (format_chunk_size == 18) {
    265     // Skip over this unused section.
    266     offset += 2;
    267   }
    268 
    269   bool was_data_found = false;
    270   while (offset < wav_string.size()) {
    271     string chunk_id;
    272     TF_RETURN_IF_ERROR(ReadString(wav_string, 4, &chunk_id, &offset));
    273     uint32 chunk_size;
    274     TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &chunk_size, &offset));
    275     if (chunk_id == kDataChunkId) {
    276       if (was_data_found) {
    277         return errors::InvalidArgument("More than one data chunk found in WAV");
    278       }
    279       was_data_found = true;
    280       *sample_count = chunk_size / bytes_per_sample;
    281       const uint32 data_count = *sample_count * *channel_count;
    282       float_values->resize(data_count);
    283       for (int i = 0; i < data_count; ++i) {
    284         int16 single_channel_value = 0;
    285         TF_RETURN_IF_ERROR(
    286             ReadValue<int16>(wav_string, &single_channel_value, &offset));
    287         (*float_values)[i] = Int16SampleToFloat(single_channel_value);
    288       }
    289     } else {
    290       offset += chunk_size;
    291     }
    292   }
    293   if (!was_data_found) {
    294     return errors::InvalidArgument("No data chunk found in WAV");
    295   }
    296   return Status::OK();
    297 }
    298 
    299 }  // namespace wav
    300 }  // namespace tensorflow
    301