Anny on

Source: Trainer.js

import _ from 'lodash'
import validate from './Validate'

/**
 * A Trainer teaches a {@link Network} how to correctly classify some `data`.
 *
 * Training repeatedly runs every training sample through the network (one
 * full pass is an epoch) until the summed average sample error is at or below
 * `errorThreshold`, an `onProgress` callback returns `false`, or `maxEpochs`
 * epochs have been run.
 *
 * @example
 * const network = new anny.Network([2, 1])
 * const trainer = new Trainer()
 *
 * trainer.train(network, anny.DATA.ORGate)
 *
 * network.activate([0, 0]) // => 0.000836743108
 * network.activate([0, 1]) // => 0.998253857294
 */
class Trainer {
  /**
   * @param {object} [options]
   * @param {boolean|number} [options.batch=false]
   *   Use batch, online (stochastic), or mini-batch learning modes.
   *
   *   Batch `true`: Connection weights are updated once after iterating
   *   through all the training samples in the training data (an epoch).
   *
   *   Online `false`: Connection weights are updated after every training
   *   sample in the training data.
   *
   *   Mini-batch `<number>`: Connection weights are updated every `<number>`
   *   training samples.
   * @param {number} [options.errorThreshold=0.001]
   *   The target `error` value. The goal of the Trainer is to train the
   *   Network until the `error` is at or below this value.
   * @param {number} [options.frequency=100]
   *   How many iterations through the training data between calling
   *   `options.onProgress`.
   * @param {number} [options.maxEpochs=20000]
   *   The max training iterations. The Trainer will stop training after
   *   iterating through the training data this number of times.  One full loop
   *   through the training data is counted as one epoch.
   * @param {Trainer~onFail} [options.onFail]
   *   Called if the Network `error` is still above the `errorThreshold`
   *   after `maxEpochs`.
   * @param {Trainer~onProgress} [options.onProgress]
   *   Called every `frequency` epochs. Return `false` to stop training.
   * @param {Trainer~onSuccess} [options.onSuccess]
   *   Called if the Network `error` reaches the `errorThreshold` during
   *   training.
   * @constructor
   */
  constructor(options = {}) {
    const defaultOptions = {
      batch: false,
      errorThreshold: 0.001,
      frequency: 100,
      maxEpochs: 20000,
    }

    // `defaultOptions` is a fresh literal, so merging into it is safe and
    // leaves the caller's `options` object untouched.
    const mergedOptions = _.merge(defaultOptions, options)
    validate.trainingOptions(mergedOptions)
    this.options = mergedOptions
  }

  /**
   * Train the `network` to classify the `data`.
   *
   * Training stops as soon as one of these happens:
   * - the epoch error is at or below `errorThreshold` (`onSuccess` is called
   *   if provided),
   * - `onProgress` returns `false`,
   * - `maxEpochs` epochs have been run (`onFail` is called if provided).
   *
   * @param {Network} network - The Network to be trained.
   * @param {object[]} data - Array of objects in the form
   * `{input: [], output: []}`.
   * @param {object} [overrides] - Options for this call only. Accepts the same
   *   keys as the constructor `options`; they take precedence over this
   *   trainer instance's options without modifying them.
   * @see Network
   * @see Data
   */
  train(network, data, overrides = {}) {
    validate.trainingData(network, data)
    // Merge into a fresh object: merging straight into `this.options` would
    // make per-call overrides (including callbacks) stick to the instance and
    // silently affect every later `train()` call.
    const mergedOptions = _.merge({}, this.options, overrides)
    validate.trainingOptions(mergedOptions)
    // TODO: normalize data to the range of the activation functions
    const {
      batch,
      errorThreshold,
      frequency,
      maxEpochs,
      onFail,
      onProgress,
      onSuccess,
    } = mergedOptions

    const isBatch = batch === true
    const isOnline = batch === false
    const isMiniBatch = _.isNumber(batch)
    const lastSampleIndex = data.length - 1

    // track samples iterated, allows for mini-batches that span epochs
    let sampleCounter = 1

    // Runs a single sample through the network and returns that sample's
    // share of the epoch's average error.
    const getAverageSampleError = (sample, sampleIndex) => {
      const shouldUpdate = isOnline
        || (isMiniBatch && sampleCounter % batch === 0)
        || (isBatch && sampleIndex === lastSampleIndex)

      // propagation: set inputs & outputs, then error & deltas
      network.activate(sample.input)
      network.backprop(sample.output)

      // weight updates: update weights || accumulate weight gradients
      if (shouldUpdate) network.updateWeights()
      else network.accumulateGradients()

      sampleCounter += 1
      return network.error / data.length
    }

    for (let epochCount = 1; epochCount <= maxEpochs; epochCount += 1) {
      // Sum the average error of all training samples. The sample index must
      // be passed along for batch mode; lodash's `_.sumBy` only hands the
      // value to its iteratee, so an index-aware reduce is used instead.
      const error = data.reduce(
        (sum, sample, sampleIndex) => sum + getAverageSampleError(sample, sampleIndex),
        0
      )

      // call onProgress every `frequency` epochs; returning false stops training
      if (onProgress && epochCount % frequency === 0) {
        const stop = onProgress(error, epochCount) === false
        if (stop) break
      }

      // success: stop training whether or not a callback was supplied
      if (error <= errorThreshold) {
        if (onSuccess) onSuccess(error, epochCount)
        break
      }

      // fail: the final epoch finished without reaching the threshold
      if (onFail && epochCount === maxEpochs) onFail(error, epochCount)
    }
  }

  /**
   * Called if the `network` error reaches the `errorThreshold`.
   * @callback Trainer~onSuccess
   * @param {number} error
   *   The `network` error value at the time of success.
   * @param {number} epoch
   *   Indicates on which iteration through the training data the `network`
   *   became successful.
   */

  /**
   * Called if the `network` error is not at or below the `errorThreshold`
   * after `maxEpochs` iterations through the training data set.
   * @callback Trainer~onFail
   * @param {number} error
   *   The `network` error value after the final epoch.
   * @param {number} epoch
   *   The number of iterations through the training data that were run
   *   (equal to `maxEpochs`).
   */

  /**
   * Called every `frequency` epochs while training is in progress.
   * @callback Trainer~onProgress
   * @param {number} error
   *   The current `network` error value.
   * @param {number} epoch
   *   The number of iterations through the training data completed so far.
   * @returns {boolean|undefined}
   *   Return `false` to stop training; any other return value continues.
   */
}

export default Trainer