# Source code for zounds.learn.loss

import torch
from scipy.signal import gaussian
from torch import nn
from torch.autograd import Variable
from torch.nn import functional as F

from .dct_transform import DctTransform
from zounds.spectral import fir_filter_bank
from zounds.timeseries import SampleRate

class PerceptualLoss(nn.Module):
    """
    `PerceptualLoss` computes loss/distance in a feature space that roughly
    approximates early stages of the human audio processing pipeline, instead
    of computing raw sample loss.  It decomposes a 1D (audio) signal into
    frequency bands using an FIR filter bank whose frequencies are centered
    according to a user-defined scale, performs half-wave rectification, puts
    amplitudes on a log scale, and finally optionally applies a re-weighting
    of frequency bands.

    Args:
        scale (FrequencyScale): a scale defining frequencies at which the FIR
            filters will be centered
        samplerate (SampleRate): samplerate needed to construct the FIR filter
            bank
        frequency_window (ndarray): window determining how narrow or wide
            filter responses should be
        basis_size (int): The kernel size, or number of "taps" for each filter
        lap (int): The filter stride
        log_factor (int): How much compression should be applied in the log
            amplitude stage
        frequency_weighting (FrequencyWeighting): an optional frequency
            weighting to be applied after log amplitude scaling
        cosine_similarity (bool): If `True`, compute the cosine similarity
            between spectrograms, otherwise, compute the mean squared error
    """

    def __init__(
            self,
            scale,
            samplerate,
            frequency_window=gaussian(100, 3),
            basis_size=512,
            lap=2,
            log_factor=100,
            frequency_weighting=None,
            cosine_similarity=True):

        super(PerceptualLoss, self).__init__()

        self.cosine_similarity = cosine_similarity
        self.log_factor = log_factor
        self.scale = scale
        self.lap = lap
        self.basis_size = basis_size

        basis = fir_filter_bank(
            scale, basis_size, samplerate, frequency_window)
        weights = torch.from_numpy(basis).float()

        # out channels x in channels x kernel width
        #
        # NOTE: the filter bank is stored as a plain tensor attribute (not a
        # parameter or registered buffer), which is why `cuda()` and `to()`
        # are overridden below to move it explicitly.
        self.weights = weights.view(len(scale), 1, basis_size).contiguous()

        self.frequency_weights = None
        if frequency_weighting:
            fw = frequency_weighting._wdata(self.scale)
            # shaped (1, bands, 1) so it broadcasts over batch and time
            self.frequency_weights = torch.from_numpy(fw) \
                .float().view(1, len(self.scale), 1)

    def cuda(self, device=None):
        """
        Move the filter bank (and frequency weights, if any) to the GPU along
        with the rest of the module.
        """
        self.weights = self.weights.cuda(device=device)
        if self.frequency_weights is not None:
            self.frequency_weights = \
                self.frequency_weights.cuda(device=device)
        return super(PerceptualLoss, self).cuda(device=device)

    def to(self, device=None):
        """
        Move the filter bank (and frequency weights, if any) to `device`
        along with the rest of the module.
        """
        self.weights = self.weights.to(device=device)
        if self.frequency_weights is not None:
            self.frequency_weights = \
                self.frequency_weights.to(device=device)
        return super(PerceptualLoss, self).to(device=device)

    def _transform(self, x):
        """
        Map a batch of signals to log-amplitude filter bank features of shape
        (batch, bands, frames).
        """
        x = x.view(x.shape[0], 1, -1)

        # frequency decomposition
        features = F.conv1d(
            x, self.weights, stride=self.lap, padding=self.basis_size)

        # half-wave rectification
        features = F.relu(features)

        # log magnitude
        features = torch.log(1 + features * self.log_factor)

        # perceptual frequency weighting
        if self.frequency_weights is not None:
            features = features * self.frequency_weights

        return features

    def forward(self, input, target):
        """
        Compare `input` and `target` in the perceptual feature space.

        Returns the negated mean cosine similarity between the flattened
        feature maps when `cosine_similarity` is set (so that lower is
        better), otherwise the mean squared error between them.
        """
        batch_size = input.shape[0]
        input = input.view(batch_size, 1, -1)
        target = target.view(batch_size, 1, -1)

        input_features = self._transform(input).view(batch_size, -1)
        target_features = self._transform(target).view(batch_size, -1)

        if self.cosine_similarity:
            spectral_error = \
                -(F.cosine_similarity(input_features, target_features).mean())
            return spectral_error
        else:
            return ((input_features - target_features) ** 2).mean()
class BandLoss(nn.MSELoss):
    """
    Mean squared error computed on a band-wise decomposition of the signal
    rather than on raw samples.

    Each band is normalized to (approximately) unit norm, giving a "fine"
    representation, while the vector of per-band norms, itself normalized,
    gives a "coarse" description of the spectral shape.  The loss is
    `fine_mse + spectral_shape_weight * coarse_mse`.

    Args:
        factors: passed through to `DctTransform.frequency_decomposition`
            (presumably the resampling factors defining the bands; confirm
            against `DctTransform`)
        spectral_shape_weight: multiplier applied to the coarse term
    """

    def __init__(self, factors, spectral_shape_weight=1):
        super(BandLoss, self).__init__()
        self.spectral_shape_weight = spectral_shape_weight
        self.factors = factors
        self.dct_transform = DctTransform()

    def cuda(self, device=None):
        # the transform is rebuilt rather than moved
        self.dct_transform = DctTransform(use_cuda=True)
        return super(BandLoss, self).cuda(device=device)

    def _transform(self, x):
        """
        Return `(fine, coarse)`: the concatenated norm-scaled bands, and the
        normalized vector of per-band norms.
        """
        bands = self.dct_transform.frequency_decomposition(
            x, self.factors, axis=-1)

        norms = [torch.norm(b, dim=-1, keepdim=True) for b in bands]
        # epsilon guards against division by zero for silent bands
        bands = [b / (n + 1e-8) for (b, n) in zip(bands, norms)]

        fine = torch.cat(bands, dim=-1)
        coarse = torch.cat(norms, dim=-1)

        coarse_norms = torch.norm(coarse, dim=-1, keepdim=True)
        coarse = coarse / (coarse_norms + 1e-8)
        return fine, coarse

    def forward(self, input, target):
        input_bands, input_coarse = self._transform(input)
        target_bands, target_coarse = self._transform(target)

        fine = super(BandLoss, self).forward(input_bands, target_bands)
        coarse = super(BandLoss, self).forward(input_coarse, target_coarse)
        return fine + (coarse * self.spectral_shape_weight)


class CategoricalLoss(object):
    """
    Negative log-likelihood loss over mu-law-companded, discretized sample
    values.

    Signals are peak-normalized, mu-law companded and mapped onto the range
    `[0, n_categories]`, i.e. `n_categories + 1` classes.

    Args:
        n_categories (int): the highest class index; there are
            `n_categories + 1` classes in total
    """

    def __init__(self, n_categories):
        super(CategoricalLoss, self).__init__()
        self.n_categories = n_categories
        self.use_cuda = False
        # `nn.NLLLoss` accepts (N, C, d1, ...) input directly; the deprecated
        # `nn.NLLLoss2d` was only an alias for it
        self.loss = nn.NLLLoss()

    def cuda(self, device=None):
        self.use_cuda = True
        self.loss = self.loss.cuda(device=device)
        return self

    def _variable(self, x, *args, **kwargs):
        v = Variable(x, *args, **kwargs)
        if self.use_cuda:
            v = v.cuda()
        return v

    def _mu_law(self, x):
        """
        Mu-law companding: `sign(x) * log(1 + n * |x|) / log(1 + n)`, with
        `n = n_categories`.
        """
        m = self._variable(torch.FloatTensor([self.n_categories + 1]))
        s = torch.sign(x)
        x = torch.abs(x)
        x = s * (torch.log(1 + (self.n_categories * x)) / torch.log(m))
        return x

    def _shift_and_scale(self, x):
        """
        Map values from `[-1, 1]` onto `[0, n_categories]`.
        """
        x = x + 1
        x = x * ((self.n_categories) / 2.)
        return x

    def _one_hot(self, x):
        """
        Turn (possibly fractional) class positions into log-probabilities
        over a new trailing class axis.

        The very large scale factor makes the softmax effectively one-hot at
        the nearest class index.
        """
        y = self._variable(torch.arange(0, self.n_categories + 1).float())
        x = -(((x[..., None] - y) ** 2) * 1e12)
        x = F.log_softmax(x, dim=-1)
        return x

    def _discretized(self, x):
        """
        Flatten `x`, peak-normalize it over the whole tensor, and map it onto
        the `[0, n_categories]` class range.
        """
        x = x.view(-1)
        x = x / torch.abs(x).max()
        x = self._mu_law(x)
        x = self._shift_and_scale(x)
        return x

    def _categorical(self, x):
        x = self._discretized(x)
        x = self._one_hot(x)
        return x

    def forward(self, input, target):
        """
        Compute the negative log-likelihood of `target` under `input`.

        `input` is either already a tensor of log-probabilities with the
        class axis at position 1 (detected by
        `input.shape[1] == n_categories + 1`), or a raw signal, which is then
        converted to log-probabilities.  `target` is always a raw signal.
        """
        n_classes = self.n_categories + 1
        n_samples = input.shape[-1]

        if input.shape[1] == n_classes:
            categorical = input
        else:
            # `_categorical` yields (batch * time, classes); the loss expects
            # the class axis at position 1, so restore the time axis and
            # swap it with the class axis instead of re-viewing blindly
            categorical = self._categorical(input) \
                .view(-1, n_samples, n_classes) \
                .permute(0, 2, 1) \
                .contiguous()

        discretized = self._discretized(target)
        inp = categorical.view(-1, n_classes, n_samples)
        # truncation to integer class indices
        t = discretized.view(-1, target.shape[-1]).long()
        error = self.loss(inp, t)
        return error

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class BaseLoss(object):
    """
    Minimal base class for callable losses that wrap other networks.

    Subclasses implement `_cuda` (move wrapped networks to the GPU) and
    `forward` (compute the loss).
    """

    def __init__(self):
        super(BaseLoss, self).__init__()
        self.use_cuda = False

    def _cuda(self, device=None):
        raise NotImplementedError()

    def cuda(self, device=None):
        self.use_cuda = True
        self._cuda(device=device)
        return self

    def forward(self, *args, **kwargs):
        raise NotImplementedError()

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class LearnedWassersteinLoss(BaseLoss):
    """
    Generator-side Wasserstein loss: the negated mean critic output for `x`.
    """

    def __init__(self, critic):
        super(LearnedWassersteinLoss, self).__init__()
        self.critic = critic

    def _cuda(self, device=None):
        self.critic.cuda(device=device)

    def forward(self, x, **critic_kwargs):
        w = self.critic(x, **critic_kwargs)
        return -torch.mean(w)


class WassersteinCriticLoss(BaseLoss):
    """
    Critic-side Wasserstein loss: mean critic output on fake samples minus
    mean critic output on real samples.
    """

    def __init__(self, critic):
        super(WassersteinCriticLoss, self).__init__()
        self.critic = critic

    def _cuda(self, device=None):
        self.critic.cuda(device=device)

    def forward(self, real, fake, **critic_kwargs):
        d_real = self.critic(real, **critic_kwargs)
        d_fake = self.critic(fake, **critic_kwargs)
        real_mean = torch.mean(d_real)
        fake_mean = torch.mean(d_fake)
        return fake_mean - real_mean


class WassersteinGradientPenaltyLoss(BaseLoss):
    """
    Gradient penalty term: penalizes the squared deviation from 1 of the
    norm of the critic's gradient, evaluated at random interpolations
    between real and fake samples, scaled by `weight`.
    """

    def __init__(self, critic, weight=10):
        super(WassersteinGradientPenaltyLoss, self).__init__()
        self.weight = weight
        self.critic = critic

    def _cuda(self, device=None):
        self.critic.cuda(device=device)

    def forward(self, real_samples, fake_samples, **critic_kwargs):
        from torch.autograd import grad

        real_samples = real_samples.view(fake_samples.shape)
        batch_size = real_samples.shape[0]

        # one interpolation coefficient per example, broadcast over the
        # remaining dimensions
        alpha = torch.rand(batch_size)
        if self.use_cuda:
            alpha = alpha.cuda()
        alpha = alpha.view((-1,) + ((1,) * (real_samples.dim() - 1)))

        interpolates = alpha * real_samples + ((1 - alpha) * fake_samples)
        if self.use_cuda:
            interpolates = interpolates.cuda()
        interpolates = Variable(interpolates, requires_grad=True)

        d_output = self.critic(interpolates, **critic_kwargs)

        output = torch.ones(d_output.size())
        if self.use_cuda:
            output = output.cuda()

        # create_graph=True so the penalty itself can be back-propagated
        gradients = grad(
            outputs=d_output,
            inputs=interpolates,
            grad_outputs=output,
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]

        # NOTE(review): the norm is taken along dim 1 only; for inputs with
        # more than two dimensions this is not the norm over all elements of
        # each example -- confirm this is intended
        return ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.weight