/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
     15 
     16 #include "tensorflow/core/kernels/spectrogram.h"
     17 
     18 #include <math.h>
     19 
     20 #include "third_party/fft2d/fft.h"
     21 #include "tensorflow/core/lib/core/bits.h"
     22 
     23 namespace tensorflow {
     24 
     25 using std::complex;
     26 
     27 namespace {
     28 // Returns the default Hann window function for the spectrogram.
     29 void GetPeriodicHann(int window_length, std::vector<double>* window) {
     30   // Some platforms don't have M_PI, so define a local constant here.
     31   const double pi = std::atan(1) * 4;
     32   window->resize(window_length);
     33   for (int i = 0; i < window_length; ++i) {
     34     (*window)[i] = 0.5 - 0.5 * cos((2 * pi * i) / window_length);
     35   }
     36 }
     37 }  // namespace
     38 
     39 bool Spectrogram::Initialize(int window_length, int step_length) {
     40   std::vector<double> window;
     41   GetPeriodicHann(window_length, &window);
     42   return Initialize(window, step_length);
     43 }
     44 
     45 bool Spectrogram::Initialize(const std::vector<double>& window,
     46                              int step_length) {
     47   window_length_ = window.size();
     48   window_ = window;  // Copy window.
     49   if (window_length_ < 2) {
     50     LOG(ERROR) << "Window length too short.";
     51     initialized_ = false;
     52     return false;
     53   }
     54 
     55   step_length_ = step_length;
     56   if (step_length_ < 1) {
     57     LOG(ERROR) << "Step length must be positive.";
     58     initialized_ = false;
     59     return false;
     60   }
     61 
     62   fft_length_ = NextPowerOfTwo(window_length_);
     63   CHECK(fft_length_ >= window_length_);
     64   output_frequency_channels_ = 1 + fft_length_ / 2;
     65 
     66   // Allocate 2 more than what rdft needs, so we can rationalize the layout.
     67   fft_input_output_.assign(fft_length_ + 2, 0.0);
     68 
     69   int half_fft_length = fft_length_ / 2;
     70   fft_double_working_area_.assign(half_fft_length, 0.0);
     71   fft_integer_working_area_.assign(2 + static_cast<int>(sqrt(half_fft_length)),
     72                                    0);
     73   // Set flag element to ensure that the working areas are initialized
     74   // on the first call to cdft.  It's redundant given the assign above,
     75   // but keep it as a reminder.
     76   fft_integer_working_area_[0] = 0;
     77   input_queue_.clear();
     78   samples_to_next_step_ = window_length_;
     79   initialized_ = true;
     80   return true;
     81 }
     82 
     83 template <class InputSample, class OutputSample>
     84 bool Spectrogram::ComputeComplexSpectrogram(
     85     const std::vector<InputSample>& input,
     86     std::vector<std::vector<complex<OutputSample>>>* output) {
     87   if (!initialized_) {
     88     LOG(ERROR) << "ComputeComplexSpectrogram() called before successful call "
     89                << "to Initialize().";
     90     return false;
     91   }
     92   CHECK(output);
     93   output->clear();
     94   int input_start = 0;
     95   while (GetNextWindowOfSamples(input, &input_start)) {
     96     DCHECK_EQ(input_queue_.size(), window_length_);
     97     ProcessCoreFFT();  // Processes input_queue_ to fft_input_output_.
     98     // Add a new slice vector onto the output, to save new result to.
     99     output->resize(output->size() + 1);
    100     // Get a reference to the newly added slice to fill in.
    101     auto& spectrogram_slice = output->back();
    102     spectrogram_slice.resize(output_frequency_channels_);
    103     for (int i = 0; i < output_frequency_channels_; ++i) {
    104       // This will convert double to float if it needs to.
    105       spectrogram_slice[i] = complex<OutputSample>(
    106           fft_input_output_[2 * i], fft_input_output_[2 * i + 1]);
    107     }
    108   }
    109   return true;
    110 }
    111 // Instantiate it four ways:
    112 template bool Spectrogram::ComputeComplexSpectrogram(
    113     const std::vector<float>& input, std::vector<std::vector<complex<float>>>*);
    114 template bool Spectrogram::ComputeComplexSpectrogram(
    115     const std::vector<double>& input,
    116     std::vector<std::vector<complex<float>>>*);
    117 template bool Spectrogram::ComputeComplexSpectrogram(
    118     const std::vector<float>& input,
    119     std::vector<std::vector<complex<double>>>*);
    120 template bool Spectrogram::ComputeComplexSpectrogram(
    121     const std::vector<double>& input,
    122     std::vector<std::vector<complex<double>>>*);
    123 
    124 template <class InputSample, class OutputSample>
    125 bool Spectrogram::ComputeSquaredMagnitudeSpectrogram(
    126     const std::vector<InputSample>& input,
    127     std::vector<std::vector<OutputSample>>* output) {
    128   if (!initialized_) {
    129     LOG(ERROR) << "ComputeSquaredMagnitudeSpectrogram() called before "
    130                << "successful call to Initialize().";
    131     return false;
    132   }
    133   CHECK(output);
    134   output->clear();
    135   int input_start = 0;
    136   while (GetNextWindowOfSamples(input, &input_start)) {
    137     DCHECK_EQ(input_queue_.size(), window_length_);
    138     ProcessCoreFFT();  // Processes input_queue_ to fft_input_output_.
    139     // Add a new slice vector onto the output, to save new result to.
    140     output->resize(output->size() + 1);
    141     // Get a reference to the newly added slice to fill in.
    142     auto& spectrogram_slice = output->back();
    143     spectrogram_slice.resize(output_frequency_channels_);
    144     for (int i = 0; i < output_frequency_channels_; ++i) {
    145       // Similar to the Complex case, except storing the norm.
    146       // But the norm function is known to be a performance killer,
    147       // so do it this way with explicit real and imagninary temps.
    148       const double re = fft_input_output_[2 * i];
    149       const double im = fft_input_output_[2 * i + 1];
    150       // Which finally converts double to float if it needs to.
    151       spectrogram_slice[i] = re * re + im * im;
    152     }
    153   }
    154   return true;
    155 }
    156 // Instantiate it four ways:
    157 template bool Spectrogram::ComputeSquaredMagnitudeSpectrogram(
    158     const std::vector<float>& input, std::vector<std::vector<float>>*);
    159 template bool Spectrogram::ComputeSquaredMagnitudeSpectrogram(
    160     const std::vector<double>& input, std::vector<std::vector<float>>*);
    161 template bool Spectrogram::ComputeSquaredMagnitudeSpectrogram(
    162     const std::vector<float>& input, std::vector<std::vector<double>>*);
    163 template bool Spectrogram::ComputeSquaredMagnitudeSpectrogram(
    164     const std::vector<double>& input, std::vector<std::vector<double>>*);
    165 
    166 // Return true if a full window of samples is prepared; manage the queue.
    167 template <class InputSample>
    168 bool Spectrogram::GetNextWindowOfSamples(const std::vector<InputSample>& input,
    169                                          int* input_start) {
    170   auto input_it = input.begin() + *input_start;
    171   int input_remaining = input.end() - input_it;
    172   if (samples_to_next_step_ > input_remaining) {
    173     // Copy in as many samples are left and return false, no full window.
    174     input_queue_.insert(input_queue_.end(), input_it, input.end());
    175     *input_start += input_remaining;  // Increases it to input.size().
    176     samples_to_next_step_ -= input_remaining;
    177     return false;  // Not enough for a full window.
    178   } else {
    179     // Copy just enough into queue to make a new window, then trim the
    180     // front off the queue to make it window-sized.
    181     input_queue_.insert(input_queue_.end(), input_it,
    182                         input_it + samples_to_next_step_);
    183     *input_start += samples_to_next_step_;
    184     input_queue_.erase(
    185         input_queue_.begin(),
    186         input_queue_.begin() + input_queue_.size() - window_length_);
    187     DCHECK_EQ(window_length_, input_queue_.size());
    188     samples_to_next_step_ = step_length_;  // Be ready for next time.
    189     return true;  // Yes, input_queue_ now contains exactly a window-full.
    190   }
    191 }
    192 
    193 void Spectrogram::ProcessCoreFFT() {
    194   for (int j = 0; j < window_length_; ++j) {
    195     fft_input_output_[j] = input_queue_[j] * window_[j];
    196   }
    197   // Zero-pad the rest of the input buffer.
    198   for (int j = window_length_; j < fft_length_; ++j) {
    199     fft_input_output_[j] = 0.0;
    200   }
    201   const int kForwardFFT = 1;  // 1 means forward; -1 reverse.
    202   // This real FFT is a fair amount faster than using cdft here.
    203   rdft(fft_length_, kForwardFFT, &fft_input_output_[0],
    204        &fft_integer_working_area_[0], &fft_double_working_area_[0]);
    205   // Make rdft result look like cdft result;
    206   // unpack the last real value from the first position's imag slot.
    207   fft_input_output_[fft_length_] = fft_input_output_[1];
    208   fft_input_output_[fft_length_ + 1] = 0;
    209   fft_input_output_[1] = 0;
    210 }
    211 
    212 }  // namespace tensorflow
    213