Source code for zounds.learn.spectral

import torch
from torch import nn
from torch.nn import functional as F
from zounds.spectral import morlet_filter_bank, AWeighting, FrequencyDimension
from zounds.core import ArrayWithUnits
from zounds.timeseries import TimeDimension
from .util import batchwise_mean_std_normalization


class FilterBank(nn.Module):
    """
    A torch module that convolves a 1D input signal with a bank of Morlet
    filters.

    Args:
        samplerate (SampleRate): the samplerate of the input signal
        kernel_size (int): the length in samples of each filter
        scale (FrequencyScale): a scale whose center frequencies determine
            the fundamental frequency of each filter
        scaling_factors (int or list of int): Scaling factors for each band,
            which determine the time-frequency resolution tradeoff.
            The number(s) should fall between 0 and 1, with smaller numbers
            achieving better frequency resolution, and larger numbers better
            time resolution
        normalize_filters (bool): When True, ensure that each filter in the
            bank has unit norm
        a_weighting (bool): When True, apply a perceptually-motivated
            weighting of the filters

    See Also:
        :class:`~zounds.spectral.AWeighting`
        :func:`~zounds.spectral.morlet_filter_bank`
    """

    def __init__(
            self,
            samplerate,
            kernel_size,
            scale,
            scaling_factors,
            normalize_filters=True,
            a_weighting=True):
        super(FilterBank, self).__init__()

        self.samplerate = samplerate
        filter_bank = morlet_filter_bank(
            samplerate,
            kernel_size,
            scale,
            scaling_factors,
            normalize=normalize_filters)

        if a_weighting:
            filter_bank *= AWeighting()

        self.scale = scale
        filter_bank = torch.from_numpy(filter_bank).float() \
            .view(len(scale), 1, kernel_size)
        # register as a non-trainable buffer so the filters move with the
        # module when .to()/.cuda() is called
        self.register_buffer('filter_bank', filter_bank)

    @property
    def n_bands(self):
        return len(self.scale)

    def convolve(self, x):
        # treat the input as (batch, 1, samples) and convolve it with every
        # filter in the bank
        x = x.view(-1, 1, x.shape[-1])
        x = F.conv1d(
            x, self.filter_bank, padding=self.filter_bank.shape[-1] // 2)
        return x

    def transposed_convolve(self, x):
        x = F.conv_transpose1d(
            x, self.filter_bank, padding=self.filter_bank.shape[-1] // 2)
        return x

    def log_magnitude(self, x):
        # half-wave rectify, then compress to a log-magnitude scale
        x = F.relu(x)
        x = 20 * torch.log10(1 + x)
        return x

    def temporal_pooling(self, x, kernel_size, stride):
        x = F.avg_pool1d(x, kernel_size, stride, padding=kernel_size // 2)
        return x

    def transform(self, samples, pooling_kernel_size, pooling_stride):
        # convert the raw audio samples to a PyTorch tensor
        tensor_samples = torch.from_numpy(samples).float() \
            .to(self.filter_bank.device)

        # compute the transform
        spectral = self.convolve(tensor_samples)
        log_magnitude = self.log_magnitude(spectral)
        pooled = self.temporal_pooling(
            log_magnitude, pooling_kernel_size, pooling_stride)

        # convert back to an ArrayWithUnits instance
        samplerate = samples.samplerate
        time_frequency = pooled.data.cpu().numpy().squeeze().T
        time_frequency = ArrayWithUnits(time_frequency, [
            TimeDimension(
                frequency=samplerate.frequency * pooling_stride,
                duration=samplerate.frequency * pooling_kernel_size),
            FrequencyDimension(self.scale)
        ])
        return time_frequency

    def forward(self, x, normalize=True):
        nsamples = x.shape[-1]
        x = self.convolve(x)
        x = self.log_magnitude(x)
        if normalize:
            x = batchwise_mean_std_normalization(x)
        # symmetric padding can add an extra sample; trim back to the
        # original input length
        return x[..., :nsamples].contiguous()
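
For orientation, here is a minimal usage sketch. It is not part of the module source: all parameter values are illustrative, and it assumes that FilterBank is exported from zounds.learn and that the top-level helpers SR11025, GeometricScale, Seconds, and SineSynthesizer behave as in the library's examples.

import numpy as np
import torch
import zounds

from zounds.learn import FilterBank

# illustrative setup: 64 geometrically spaced bands between 20Hz and 5kHz
samplerate = zounds.SR11025()
scale = zounds.GeometricScale(20, 5000, 0.05, 64)

bank = FilterBank(
    samplerate,
    kernel_size=512,
    scale=scale,
    # per-band scaling factors between 0 and 1; smaller values favor
    # frequency resolution, larger values favor time resolution
    scaling_factors=np.linspace(0.1, 1.0, len(scale)),
    normalize_filters=True,
    a_weighting=True)

# forward() takes a raw (batch, samples) tensor and returns a normalized
# log-magnitude representation with shape (batch, n_bands, samples)
signal = torch.randn(1, 11025)  # one second of noise at 11025Hz
features = bank(signal)
print(features.shape)  # torch.Size([1, 64, 11025])

# transform() instead takes zounds AudioSamples (a numpy array that carries
# its samplerate) and returns a pooled time-frequency ArrayWithUnits
synth = zounds.SineSynthesizer(samplerate)
audio = synth.synthesize(zounds.Seconds(2), [220., 440.])
tf = bank.transform(audio, pooling_kernel_size=128, pooling_stride=64)

Note the division of labor: forward keeps full time resolution for use inside a network, while transform applies average pooling and returns a dimensioned array suitable for display or downstream zounds processing.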