Anny on

Source: Validate.js

import _ from 'lodash'
import { type } from './Util'

/**
 * Error type thrown when a validation check fails.
 *
 * Accepts either a single string or an array of string fragments. Fragments
 * are joined with no separator, so callers must include any spacing they want
 * inside the fragments themselves.
 * @param {string|string[]} message - The error message, or an array of
 *   message fragments to concatenate.
 * @extends Error
 */
class ValidationError extends Error {
  constructor(message) {
    // [].concat() normalizes a lone string and an array to the same shape.
    super([].concat(message).join(''))
    this.name = 'ValidationError'

    // Error.captureStackTrace is a V8 extension; guard it so constructing the
    // error does not itself throw on engines that lack it. The second argument
    // must be a function: passing this constructor drops its own frame from
    // the captured stack so the trace starts at the code that threw.
    if (typeof Error.captureStackTrace === 'function') {
      Error.captureStackTrace(this, ValidationError)
    }
  }
}

/**
 * Collection of validators for training data and training options.
 *
 * Every validator returns nothing on success and throws a `ValidationError`
 * with a descriptive message on failure.
 * @namespace
 */
const validate = {
  /**
   * Asserts that the training data set is an array.
   * @param {array} data The training data set.
   * @throws {ValidationError} If `data` is not an array.
   */
  dataIsArray: (data) => {
    if (!_.isArray(data)) {
      throw new ValidationError(`Training data must be an array, not: ${type(data)}`)
    }
  },

  /**
   * Asserts that the training data set contains at least one element.
   * @param {array} data The training data set.
   * @throws {ValidationError} If `data` is empty according to `_.isEmpty`.
   */
  dataIsNotEmpty: (data) => {
    if (_.isEmpty(data)) {
      throw new ValidationError('Training data array must not be empty.')
    }
  },

  /**
   * Asserts that a training sample is a plain object.
   * @param {{}} sample A single object from a training data set.
   * @param {number} i The index of the sample in the training set.
   * @throws {ValidationError} If `sample` is not a plain object.
   */
  sampleIsObject: (sample, i) => {
    if (!_.isPlainObject(sample)) {
      throw new ValidationError([
        `Training data array elements must be objects. Element at index ${i}`,
        ` is of type: ${type(sample)}`,
      ])
    }
  },

  /**
   * Asserts that a training sample has an array `input` property.
   * @param {{}} sample A single object from a training data set.
   * @param {number} i The index of the sample in the training set.
   * @throws {ValidationError} If `sample.input` is not an array.
   */
  sampleHasInput: (sample, i) => {
    if (!_.isArray(sample.input)) {
      throw new ValidationError([
        'Training object "input" property must be an array.',
        ` data[${i}]["input"] is type: ${type(sample.input)}`,
      ])
    }
  },

  /**
   * Asserts that a training sample has an array `output` property.
   * @param {{}} sample A single object from a training data set.
   * @param {number} i The index of the sample in the training set.
   * @throws {ValidationError} If `sample.output` is not an array.
   */
  sampleHasOutput: (sample, i) => {
    if (!_.isArray(sample.output)) {
      throw new ValidationError([
        'Training object "output" property must be an array.',
        ` data[${i}]["output"] is type: ${type(sample.output)}`,
      ])
    }
  },

  /**
   * Asserts that a training sample's `input` array is not empty.
   * @param {{}} sample A single object from a training data set.
   * @param {number} i The index of the sample in the training set.
   * @throws {ValidationError} If `sample.input` is empty.
   */
  sampleInputIsNotEmpty: (sample, i) => {
    if (_.isEmpty(sample.input)) {
      throw new ValidationError([
        'Training object "input" array must not be empty.',
        ` See data[${i}]["input"].`,
      ])
    }
  },

  /**
   * Asserts that a training sample's `output` array is not empty.
   * @param {{}} sample A single object from a training data set.
   * @param {number} i The index of the sample in the training set.
   * @throws {ValidationError} If `sample.output` is empty.
   */
  sampleOutputIsNotEmpty: (sample, i) => {
    if (_.isEmpty(sample.output)) {
      throw new ValidationError([
        'Training object "output" array must not be empty.',
        ` See data[${i}]["output"].`,
      ])
    }
  },

  /**
   * Asserts that every value in a training sample's `input` is a number.
   * Fails on the first offending value.
   *
   * NOTE(review): `_.isNumber` also accepts `NaN` and `Infinity`, so those
   * values pass this check.
   * @param {{}} sample A single object from a training data set.
   * @param {number} i The index of the sample in the training set.
   * @throws {ValidationError} If any `sample.input` value is not a number.
   */
  sampleInputIsOnlyNumbers: (sample, i) => {
    _.forEach(sample.input, (n) => {
      if (!_.isNumber(n)) {
        throw new ValidationError([
          'Training object "input" property values must be numbers.',
          ` data[${i}]["input"] contains: ${type(n)}`,
        ])
      }
    })
  },

  /**
   * Asserts that every value in a training sample's `output` is a number.
   * Fails on the first offending value.
   *
   * NOTE(review): `_.isNumber` also accepts `NaN` and `Infinity`, so those
   * values pass this check.
   * @param {{}} sample A single object from a training data set.
   * @param {number} i The index of the sample in the training set.
   * @throws {ValidationError} If any `sample.output` value is not a number.
   */
  sampleOutputIsOnlyNumbers: (sample, i) => {
    _.forEach(sample.output, (n) => {
      if (!_.isNumber(n)) {
        throw new ValidationError([
          'Training object "output" property values must be numbers.',
          ` data[${i}]["output"] contains: ${type(n)}`,
        ])
      }
    })
  },

  /**
   * Asserts that a sample's `input` length equals the value returned by
   * `network.inputLayer.size()`.
   * @param {{}} sample A single object from a training data set.
   * @param {number} i The index of the sample in the training set.
   * @param {Network} network The Network to be trained on the training data
   *   set.
   * @throws {ValidationError} If the two sizes differ.
   */
  sampleInputFitsNetwork: (sample, i, network) => {
    const inputLayerSize = network.inputLayer.size()
    if (inputLayerSize !== sample.input.length) {
      throw new ValidationError([
        'Training data "input" and network input layer size must be equal.',
        ` data[${i}] input size is ${sample.input.length}.`,
        ` Network input size is ${inputLayerSize}.`,
      ])
    }
  },

  /**
   * Asserts that a sample's `output` length equals the value returned by
   * `network.outputLayer.size()`.
   * @param {{}} sample A single object from a training data set.
   * @param {number} i The index of the sample in the training set.
   * @param {Network} network The Network to be trained on the training data
   *   set.
   * @throws {ValidationError} If the two sizes differ.
   */
  sampleOutputFitsNetwork: (sample, i, network) => {
    const outputLayerSize = network.outputLayer.size()
    if (outputLayerSize !== sample.output.length) {
      throw new ValidationError([
        'Training data "output" and network output layer size must be equal.',
        ` data[${i}] output size is ${sample.output.length}.`,
        ` Network output size is ${outputLayerSize}.`,
      ])
    }
  },

  /**
   * Ensures that a training set is in valid form.
   *
   * Runs the data-level checks first, then every per-sample check in order
   * for each element. The order matters: later checks read `sample.input` and
   * `sample.output` and rely on the earlier checks having confirmed that the
   * sample is an object holding arrays.
   * @param {Network} network The Network to be trained.
   * @param {object[]} data The data set to train the Network on.
   * @throws {ValidationError} On the first check that fails.
   */
  trainingData: (network, data) => {
    validate.dataIsArray(data)
    validate.dataIsNotEmpty(data)

    _.forEach(data, (sample, i) => {
      validate.sampleIsObject(sample, i)
      validate.sampleHasInput(sample, i)
      validate.sampleHasOutput(sample, i)
      validate.sampleInputIsNotEmpty(sample, i)
      validate.sampleOutputIsNotEmpty(sample, i)
      validate.sampleInputIsOnlyNumbers(sample, i)
      validate.sampleOutputIsOnlyNumbers(sample, i)
      validate.sampleInputFitsNetwork(sample, i, network)
      validate.sampleOutputFitsNetwork(sample, i, network)
    })
  },

  /**
   * Ensures that a training options object is in valid form.
   *
   * Rules enforced:
   * - `options` must be a plain object containing only known option keys.
   * - `batch` must be a boolean or a number.
   * - `errorThreshold`, `frequency` and `maxEpochs` must be numbers.
   * - `onFail`, `onProgress` and `onSuccess` are optional, but must be
   *   functions when present as direct properties of `options`.
   *
   * Throws `ValidationError` like every other validator in this namespace.
   * `ValidationError` extends `Error`, so handlers that catch `Error` are
   * unaffected.
   * @param {object} options The training options to validate.
   * @throws {ValidationError} On the first rule that is violated.
   */
  trainingOptions(options) {
    if (!_.isPlainObject(options)) {
      throw new ValidationError('training "options" must be a plain object.')
    }

    const validOptions = [
      'batch',
      'errorThreshold',
      'frequency',
      'maxEpochs',
      'onFail',
      'onProgress',
      'onSuccess',
    ]

    _.forEach(options, (val, key) => {
      if (_.includes(validOptions, key)) return
      throw new ValidationError(`Unknown training option "${key}", try: ${validOptions}`)
    })

    if (!_.isBoolean(options.batch) && !_.isNumber(options.batch)) {
      throw new ValidationError('training option "batch" must be a boolean or number.')
    }

    if (!_.isNumber(options.errorThreshold)) {
      throw new ValidationError('training option "errorThreshold" must be a number.')
    }

    if (!_.isNumber(options.frequency)) {
      throw new ValidationError('training option "frequency" must be a number.')
    }

    if (!_.isNumber(options.maxEpochs)) {
      throw new ValidationError('training option "maxEpochs" must be a number')
    }

    // Callbacks are optional: only type-check the ones actually supplied.
    if (_.has(options, 'onFail') && !_.isFunction(options.onFail)) {
      throw new ValidationError('training option "onFail" must be a function.')
    }

    if (_.has(options, 'onProgress') && !_.isFunction(options.onProgress)) {
      throw new ValidationError('training option "onProgress" must be a function.')
    }

    if (_.has(options, 'onSuccess') && !_.isFunction(options.onSuccess)) {
      throw new ValidationError('training option "onSuccess" must be a function.')
    }
  },
}

export default validate