Source code for zounds.learn.supervised

from .trainer import Trainer
import numpy as np
import warnings


class SupervisedTrainer(Trainer):
    def __init__(
            self,
            model,
            loss,
            optimizer,
            epochs,
            batch_size,
            holdout_percent=0.0,
            data_preprocessor=lambda x: x,
            label_preprocessor=lambda x: x,
            checkpoint_epochs=1):

        super(SupervisedTrainer, self).__init__(
            epochs, batch_size, checkpoint_epochs=checkpoint_epochs)
        self.label_preprocessor = label_preprocessor
        self.data_preprocessor = data_preprocessor
        self.holdout_percent = holdout_percent
        # the optimizer argument is a factory that is handed the model
        self.optimizer = optimizer(model)
        self.loss = loss
        self.network = model
        self.register_batch_complete_callback(self._log)
        self.samples = None

    def _log(self, *args, **kwargs):
        # only report progress when the batch index is a multiple of ten
        if kwargs['batch'] % 10:
            return
        msg = \
            'Epoch {epoch}, batch {batch}, train error ' \
            '{train_error}, test error {test_error}'
        print(msg.format(**kwargs))

    def random_sample(self):
        if self.samples is None:
            raise RuntimeError(
                'There are no samples yet. Has training started?')
        index = np.random.randint(0, len(self.samples))
        inp, label = self.samples[index]
        return inp, label

    def _cuda(self, device=None):
        self.network = self.network.cuda()
        self.loss = self.loss.cuda()

    def train(self, data):
        data, labels = data['data'], data['labels']

        # reserve the first holdout_percent of the data as a test set
        test_size = int(self.holdout_percent * len(data))
        test_data, test_labels = data[:test_size], labels[:test_size]
        data, labels = data[test_size:], labels[test_size:]

        def batch(d, l, test=False):
            d = self.data_preprocessor(d)
            l = self.label_preprocessor(l)

            inp_v = self._variable(d, volatile=test)
            output = self.network(inp_v)

            labels_v = self._variable(l)
            error = self.loss(output, labels_v)

            if not test:
                error.backward()
                self.optimizer.step()

            self.samples = list(zip(inp_v, output))

            return inp_v, output, error.data.item()

        # a single call to train() advances at most checkpoint_epochs epochs
        start = self._current_epoch
        stop = self._current_epoch + self.checkpoint_epochs

        for epoch in range(start, stop):
            if epoch >= self.epochs:
                break

            for i in range(0, len(data), self.batch_size):
                self.network.zero_grad()

                # training batch
                minibatch_slice = slice(i, i + self.batch_size)
                minibatch_data = data[minibatch_slice]
                minibatch_labels = labels[minibatch_slice]

                try:
                    inp, output, e = batch(
                        minibatch_data, minibatch_labels, test=False)
                except RuntimeError as err:
                    if 'Assert' in str(err):
                        warnings.warn(str(err))
                        continue
                    else:
                        raise

                # test batch, drawn at random from the holdout set
                if test_size:
                    indices = np.random.randint(
                        0, test_size, self.batch_size)
                    test_batch_data = test_data[indices, ...]
                    test_batch_labels = test_labels[indices, ...]

                    inp, output, te = batch(
                        test_batch_data, test_batch_labels, test=True)
                else:
                    te = 'n/a'

                self.on_batch_complete(
                    epoch=epoch,
                    batch=i,
                    train_error=e,
                    test_error=te,
                    samples=self.samples)

            self._current_epoch += 1

        return self.network
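
The following is a minimal usage sketch, not taken from the zounds documentation: the toy model, optimizer factory, and random data are illustrative only, and it assumes PyTorch is installed in a version compatible with the Variable-based API used above, and that the base Trainer's _variable helper accepts the numpy arrays passed in. As the source shows, train() expects a mapping with 'data' and 'labels' entries, the optimizer argument is a factory that receives the model, and a single call advances at most checkpoint_epochs epochs, so the sketch sets checkpoint_epochs equal to epochs to finish in one call.

import numpy as np
from torch import nn, optim

from zounds.learn.supervised import SupervisedTrainer

# A toy regression model; any torch.nn.Module whose input and output
# shapes match the data and labels below would do.
model = nn.Linear(16, 1)

trainer = SupervisedTrainer(
    model=model,
    loss=nn.MSELoss(),
    # factory that builds the optimizer from the model
    optimizer=lambda m: optim.Adam(m.parameters(), lr=1e-3),
    epochs=10,
    batch_size=64,
    holdout_percent=0.1,
    # run all ten epochs in a single call to train()
    checkpoint_epochs=10)

samples = {
    'data': np.random.normal(0, 1, (1000, 16)).astype(np.float32),
    'labels': np.random.normal(0, 1, (1000, 1)).astype(np.float32)
}

trained_network = trainer.train(samples)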