Source code for zounds.learn.sinclayer

import math

import torch
from torch import nn
from torch.nn import functional as F


class SincLayer(nn.Module):
    """
    A layer as described in the paper "Speaker Recognition from Raw Waveform
    with SincNet"

    .. epigraph::

        This paper proposes a novel CNN architecture, called SincNet, that
        encourages the first convolutional layer to discover more meaningful
        filters. SincNet is based on parametrized sinc functions, which
        implement band-pass filters. In contrast to standard CNNs, that learn
        all elements of each filter, only low and high cutoff frequencies are
        directly learned from data with the proposed method. This offers a
        very compact and efficient way to derive a customized filter bank
        specifically tuned for the desired application. Our experiments,
        conducted on both speaker identification and speaker verification
        tasks, show that the proposed architecture converges faster and
        performs better than a standard CNN on raw waveforms.

        -- https://arxiv.org/abs/1808.00158

    Args:
        scale (FrequencyScale): A scale defining the initial bandpass filters
        taps (int): The length of each filter, in samples
        samplerate (SampleRate): The sampling rate of incoming samples

    See Also:
        :class:`~zounds.spectral.FrequencyScale`
        :class:`~zounds.timeseries.SampleRate`
    """

    def __init__(self, scale, taps, samplerate):
        super(SincLayer, self).__init__()
        self.samplerate = int(samplerate)
        self.taps = taps
        self.scale = scale

        # each filter requires two parameters to define the filter bandwidth
        filter_parameters = torch.FloatTensor(len(scale), 2)

        # fixed (non-learned) tensors: the points at which the sinc functions
        # are evaluated, and a Hamming window applied to every filter
        self.linear = nn.Parameter(
            torch.linspace(-math.pi, math.pi, steps=taps),
            requires_grad=False)
        self.window = nn.Parameter(
            torch.hamming_window(self.taps), requires_grad=False)

        # initialize one (start, stop) parameter pair per band in the scale
        for i, band in enumerate(scale):
            start = self.samplerate / band.start_hz
            stop = self.samplerate / band.stop_hz
            filter_parameters[i, 0] = start
            filter_parameters[i, 1] = stop

        self.filter_parameters = nn.Parameter(filter_parameters)

    def _sinc(self, frequency):
        x = self.linear[None, :] * frequency[:, None]
        return torch.sin(x) / x

    def _start_frequencies(self):
        return torch.abs(self.filter_parameters[:, 0])

    def _stop_frequencies(self):
        start = self._start_frequencies()
        return start + torch.abs(self.filter_parameters[:, 1] - start)

    def _filter_bank(self):
        # following SincNet, each band-pass filter is built as the difference
        # of two sinc (low-pass) filters, then multiplied by a Hamming window
        start = self._start_frequencies()[:, None]
        stop = self._stop_frequencies()[:, None]

        start_sinc = self._sinc(start)
        stop_sinc = self._sinc(stop)

        filters = \
            (2 * stop[..., None] * stop_sinc) \
            - (2 * start[..., None] * start_sinc)

        windowed = filters * self.window[None, None, :]
        return windowed.squeeze()

    def forward(self, x):
        # treat the input as a batch of single-channel raw audio segments and
        # convolve it with the full bank of learned band-pass filters
        x = x.view(-1, 1, x.shape[-1])
        filters = self._filter_bank().view(len(self.scale), 1, self.taps)
        x = F.conv1d(x, filters, stride=1, padding=self.taps // 2)
        return x
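

# Usage sketch (not part of the original module): applies the layer to a batch
# of raw audio. To stay self-contained it builds a minimal stand-in for the
# frequency scale using only the behaviour SincLayer relies on (len(), and
# iteration yielding bands with start_hz/stop_hz attributes); in real code you
# would pass a zounds FrequencyScale and SampleRate instead.
if __name__ == '__main__':
    from collections import namedtuple

    Band = namedtuple('Band', ['start_hz', 'stop_hz'])

    samplerate = 11025
    n_bands = 64

    # 64 geometrically spaced bands covering roughly 20Hz up to the Nyquist
    # frequency (chosen arbitrarily for this sketch)
    edges = [20 * ((samplerate / 2) / 20) ** (i / n_bands)
             for i in range(n_bands + 1)]
    scale = [Band(start, stop) for start, stop in zip(edges, edges[1:])]

    layer = SincLayer(scale, taps=128, samplerate=samplerate)

    # a batch of eight one-second segments of raw audio
    batch = torch.randn(8, samplerate)
    features = layer(batch)

    # one output channel per band, at roughly the input length
    print(features.shape)  # torch.Size([8, 64, 11026])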