1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Functions to write audio in WAV format. 17 18 #include <math.h> 19 #include <string.h> 20 #include <algorithm> 21 22 #include "tensorflow/core/lib/core/casts.h" 23 #include "tensorflow/core/lib/core/coding.h" 24 #include "tensorflow/core/lib/core/errors.h" 25 #include "tensorflow/core/lib/wav/wav_io.h" 26 #include "tensorflow/core/platform/cpu_info.h" 27 #include "tensorflow/core/platform/logging.h" 28 #include "tensorflow/core/platform/macros.h" 29 30 namespace tensorflow { 31 namespace wav { 32 namespace { 33 34 struct TF_PACKED RiffChunk { 35 char chunk_id[4]; 36 char chunk_data_size[4]; 37 char riff_type[4]; 38 }; 39 static_assert(sizeof(RiffChunk) == 12, "TF_PACKED does not work."); 40 41 struct TF_PACKED FormatChunk { 42 char chunk_id[4]; 43 char chunk_data_size[4]; 44 char compression_code[2]; 45 char channel_numbers[2]; 46 char sample_rate[4]; 47 char bytes_per_second[4]; 48 char bytes_per_frame[2]; 49 char bits_per_sample[2]; 50 }; 51 static_assert(sizeof(FormatChunk) == 24, "TF_PACKED does not work."); 52 53 struct TF_PACKED DataChunk { 54 char chunk_id[4]; 55 char chunk_data_size[4]; 56 }; 57 static_assert(sizeof(DataChunk) == 8, "TF_PACKED does not work."); 58 59 struct TF_PACKED WavHeader { 60 RiffChunk riff_chunk; 61 FormatChunk format_chunk; 62 DataChunk data_chunk; 63 }; 64 static_assert(sizeof(WavHeader) == 65 sizeof(RiffChunk) + sizeof(FormatChunk) + sizeof(DataChunk), 66 "TF_PACKED does not work."); 67 68 constexpr char kRiffChunkId[] = "RIFF"; 69 constexpr char kRiffType[] = "WAVE"; 70 constexpr char kFormatChunkId[] = "fmt "; 71 constexpr char kDataChunkId[] = "data"; 72 73 inline int16 FloatToInt16Sample(float data) { 74 constexpr float kMultiplier = 1.0f * (1 << 15); 75 return std::min<float>(std::max<float>(roundf(data * kMultiplier), kint16min), 76 kint16max); 77 } 78 79 inline float Int16SampleToFloat(int16 data) { 80 constexpr float kMultiplier = 1.0f / (1 << 15); 81 return data * kMultiplier; 82 } 83 84 Status ExpectText(const string& data, const string& expected_text, 85 int* offset) { 86 const int new_offset = *offset + expected_text.size(); 87 if (new_offset > data.size()) { 88 return errors::InvalidArgument("Data too short when trying to read ", 89 expected_text); 90 } 91 const string found_text(data.begin() + *offset, data.begin() + new_offset); 92 if (found_text != expected_text) { 93 return errors::InvalidArgument("Header mismatch: Expected ", expected_text, 94 " but found ", found_text); 95 } 96 *offset = new_offset; 97 return Status::OK(); 98 } 99 100 template <class T> 101 Status ReadValue(const string& data, T* value, int* offset) { 102 const int new_offset = *offset + sizeof(T); 103 if (new_offset > data.size()) { 104 return errors::InvalidArgument("Data too short when trying to read value"); 105 } 106 if (port::kLittleEndian) { 107 memcpy(value, data.data() + *offset, sizeof(T)); 108 } else { 109 *value = 0; 110 const uint8* data_buf = 111 reinterpret_cast<const uint8*>(data.data() + *offset); 112 int shift = 0; 113 for (int i = 0; i < sizeof(T); ++i, shift += 8) { 114 *value = *value | (data_buf[i] << shift); 115 } 116 } 117 *offset = new_offset; 118 return Status::OK(); 119 } 120 121 Status ReadString(const string& data, int expected_length, string* value, 122 int* offset) { 123 const int new_offset = *offset + expected_length; 124 if (new_offset > data.size()) { 125 return errors::InvalidArgument("Data too short when trying to read string"); 126 } 127 *value = string(data.begin() + *offset, data.begin() + new_offset); 128 *offset = new_offset; 129 return Status::OK(); 130 } 131 132 } // namespace 133 134 Status EncodeAudioAsS16LEWav(const float* audio, size_t sample_rate, 135 size_t num_channels, size_t num_frames, 136 string* wav_string) { 137 constexpr size_t kFormatChunkSize = 16; 138 constexpr size_t kCompressionCodePcm = 1; 139 constexpr size_t kBitsPerSample = 16; 140 constexpr size_t kBytesPerSample = kBitsPerSample / 8; 141 constexpr size_t kHeaderSize = sizeof(WavHeader); 142 143 if (audio == nullptr) { 144 return errors::InvalidArgument("audio is null"); 145 } 146 if (wav_string == nullptr) { 147 return errors::InvalidArgument("wav_string is null"); 148 } 149 if (sample_rate == 0 || sample_rate > kuint32max) { 150 return errors::InvalidArgument("sample_rate must be in (0, 2^32), got: ", 151 sample_rate); 152 } 153 if (num_channels == 0 || num_channels > kuint16max) { 154 return errors::InvalidArgument("num_channels must be in (0, 2^16), got: ", 155 num_channels); 156 } 157 if (num_frames == 0) { 158 return errors::InvalidArgument("num_frames must be positive."); 159 } 160 161 const size_t bytes_per_second = sample_rate * kBytesPerSample * num_channels; 162 const size_t num_samples = num_frames * num_channels; 163 const size_t data_size = num_samples * kBytesPerSample; 164 const size_t file_size = kHeaderSize + num_samples * kBytesPerSample; 165 const size_t bytes_per_frame = kBytesPerSample * num_channels; 166 167 // WAV represents the length of the file as a uint32 so file_size cannot 168 // exceed kuint32max. 169 if (file_size > kuint32max) { 170 return errors::InvalidArgument( 171 "Provided channels and frames cannot be encoded as a WAV."); 172 } 173 174 wav_string->resize(file_size); 175 char* data = &wav_string->at(0); 176 WavHeader* header = bit_cast<WavHeader*>(data); 177 178 // Fill RIFF chunk. 179 auto* riff_chunk = &header->riff_chunk; 180 memcpy(riff_chunk->chunk_id, kRiffChunkId, 4); 181 core::EncodeFixed32(riff_chunk->chunk_data_size, file_size - 8); 182 memcpy(riff_chunk->riff_type, kRiffType, 4); 183 184 // Fill format chunk. 185 auto* format_chunk = &header->format_chunk; 186 memcpy(format_chunk->chunk_id, kFormatChunkId, 4); 187 core::EncodeFixed32(format_chunk->chunk_data_size, kFormatChunkSize); 188 core::EncodeFixed16(format_chunk->compression_code, kCompressionCodePcm); 189 core::EncodeFixed16(format_chunk->channel_numbers, num_channels); 190 core::EncodeFixed32(format_chunk->sample_rate, sample_rate); 191 core::EncodeFixed32(format_chunk->bytes_per_second, bytes_per_second); 192 core::EncodeFixed16(format_chunk->bytes_per_frame, bytes_per_frame); 193 core::EncodeFixed16(format_chunk->bits_per_sample, kBitsPerSample); 194 195 // Fill data chunk. 196 auto* data_chunk = &header->data_chunk; 197 memcpy(data_chunk->chunk_id, kDataChunkId, 4); 198 core::EncodeFixed32(data_chunk->chunk_data_size, data_size); 199 200 // Write the audio. 201 data += kHeaderSize; 202 for (size_t i = 0; i < num_samples; ++i) { 203 int16 sample = FloatToInt16Sample(audio[i]); 204 core::EncodeFixed16(&data[i * kBytesPerSample], 205 static_cast<uint16>(sample)); 206 } 207 return Status::OK(); 208 } 209 210 Status DecodeLin16WaveAsFloatVector(const string& wav_string, 211 std::vector<float>* float_values, 212 uint32* sample_count, uint16* channel_count, 213 uint32* sample_rate) { 214 int offset = 0; 215 TF_RETURN_IF_ERROR(ExpectText(wav_string, kRiffChunkId, &offset)); 216 uint32 total_file_size; 217 TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &total_file_size, &offset)); 218 TF_RETURN_IF_ERROR(ExpectText(wav_string, kRiffType, &offset)); 219 TF_RETURN_IF_ERROR(ExpectText(wav_string, kFormatChunkId, &offset)); 220 uint32 format_chunk_size; 221 TF_RETURN_IF_ERROR( 222 ReadValue<uint32>(wav_string, &format_chunk_size, &offset)); 223 if ((format_chunk_size != 16) && (format_chunk_size != 18)) { 224 return errors::InvalidArgument( 225 "Bad file size for WAV: Expected 16 or 18, but got", format_chunk_size); 226 } 227 uint16 audio_format; 228 TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, &audio_format, &offset)); 229 if (audio_format != 1) { 230 return errors::InvalidArgument( 231 "Bad audio format for WAV: Expected 1 (PCM), but got", audio_format); 232 } 233 TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, channel_count, &offset)); 234 TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, sample_rate, &offset)); 235 uint32 bytes_per_second; 236 TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &bytes_per_second, &offset)); 237 uint16 bytes_per_sample; 238 TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, &bytes_per_sample, &offset)); 239 // Confusingly, bits per sample is defined as holding the number of bits for 240 // one channel, unlike the definition of sample used elsewhere in the WAV 241 // spec. For example, bytes per sample is the memory needed for all channels 242 // for one point in time. 243 uint16 bits_per_sample; 244 TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, &bits_per_sample, &offset)); 245 if (bits_per_sample != 16) { 246 return errors::InvalidArgument( 247 "Can only read 16-bit WAV files, but received ", bits_per_sample); 248 } 249 const uint32 expected_bytes_per_sample = 250 ((bits_per_sample * *channel_count) + 7) / 8; 251 if (bytes_per_sample != expected_bytes_per_sample) { 252 return errors::InvalidArgument( 253 "Bad bytes per sample in WAV header: Expected ", 254 expected_bytes_per_sample, " but got ", bytes_per_sample); 255 } 256 const uint32 expected_bytes_per_second = bytes_per_sample * *sample_rate; 257 if (bytes_per_second != expected_bytes_per_second) { 258 return errors::InvalidArgument( 259 "Bad bytes per second in WAV header: Expected ", 260 expected_bytes_per_second, " but got ", bytes_per_second, 261 " (sample_rate=", *sample_rate, ", bytes_per_sample=", bytes_per_sample, 262 ")"); 263 } 264 if (format_chunk_size == 18) { 265 // Skip over this unused section. 266 offset += 2; 267 } 268 269 bool was_data_found = false; 270 while (offset < wav_string.size()) { 271 string chunk_id; 272 TF_RETURN_IF_ERROR(ReadString(wav_string, 4, &chunk_id, &offset)); 273 uint32 chunk_size; 274 TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &chunk_size, &offset)); 275 if (chunk_id == kDataChunkId) { 276 if (was_data_found) { 277 return errors::InvalidArgument("More than one data chunk found in WAV"); 278 } 279 was_data_found = true; 280 *sample_count = chunk_size / bytes_per_sample; 281 const uint32 data_count = *sample_count * *channel_count; 282 float_values->resize(data_count); 283 for (int i = 0; i < data_count; ++i) { 284 int16 single_channel_value = 0; 285 TF_RETURN_IF_ERROR( 286 ReadValue<int16>(wav_string, &single_channel_value, &offset)); 287 (*float_values)[i] = Int16SampleToFloat(single_channel_value); 288 } 289 } else { 290 offset += chunk_size; 291 } 292 } 293 if (!was_data_found) { 294 return errors::InvalidArgument("No data chunk found in WAV"); 295 } 296 return Status::OK(); 297 } 298 299 } // namespace wav 300 } // namespace tensorflow 301