train method

TrainArtifacts train(
  TrainSet trainSet, {
  double learningRate = 0.04,
  double maxErrClipAbove = 0.0,
  double skipIfErrBelow = 0.0,
  bool Function(FVector)? skipIfOutput,
})

Trains the network on a single training pair for a single epoch.

Returns the propagated error of the first layer, which is useful for chaining networks.
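
For example, a minimal training-loop sketch. The network constructor and the TrainSetInputOutput.lists helper are assumptions for illustration; only the train call itself mirrors the signature above.

// Hypothetical setup: a network whose first layer accepts 3 inputs and
// whose last layer emits 2 outputs; the constructor shape is an assumption.
final net = TfannNetwork.full([3, 4, 2]);
for (int epoch = 0; epoch < 1000; ++epoch) {
  net.train(
    TrainSetInputOutput.lists([0.5, -1.0, 0.2], [1.0, -1.0]), // assumed helper
    learningRate: 0.04,
    maxErrClipAbove: 2.0, // clip large errors; 0.0 disables clipping
  );
}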

Implementation

TrainArtifacts train(TrainSet trainSet,
    {double learningRate = 0.04,
    double maxErrClipAbove = 0.0,
    double skipIfErrBelow = 0.0,
    bool Function(FVector)? skipIfOutput}) {
  assert(trainSet.input.length == layers.first.inputLength);
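  // Forward pass: artifacts[i] holds the activation entering layer i
  // (index 0 is the raw input paired with a zero-derivative placeholder),
  // so artifacts[i + 1] carries layer i's activation and derivative.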
  FVector nextInputs = trainSet.input;
  List<FeedArtifacts> artifacts = [
    FeedArtifacts(nextInputs, FVector.zero(nextInputs.length))
  ];
  for (TfannLayer l in layers) {
    artifacts.add(l.createFeedArtifacts(nextInputs));
    nextInputs = artifacts.last.activation;
  }
  FVector netOutput = nextInputs;

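  // Determine the output error: either measured against a target output,
  // or supplied directly (e.g. propagated back from a downstream network).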
  FVector netErrors;
  if (trainSet is TrainSetInputOutput) {
    assert(trainSet.output.length == layers.last.outputLength);
    netErrors = netOutput - trainSet.output;
  } else {
    final errorSet = trainSet as TrainSetInputError;
    assert(errorSet.error.length == layers.last.outputLength);
    netErrors = errorSet.error;
  }
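  // Caller-supplied predicate: when the output is already acceptable,
  // report the error but skip the weight update.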
  if (skipIfOutput?.call(netOutput) ?? false) {
    return TrainArtifacts(netErrors, FVector.zero(layers.first.inputLength));
  }
  FVector normalizedErrors = netErrors;
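  // Error clipping, with the largest absolute element as the norm. Note that
  // skipIfErrBelow is only consulted when clipping is enabled
  // (maxErrClipAbove > 0.0).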
  if (maxErrClipAbove > 0.0) {
    double norm = normalizedErrors.abs().largestElement();
    if (norm < skipIfErrBelow) {
      return TrainArtifacts(
          netErrors, FVector.zero(layers.first.inputLength));
    }
    //double norm = normalizedErrors.squared().sumElements();
    if (norm > maxErrClipAbove) {
      normalizedErrors = netErrors.scaled(maxErrClipAbove / norm);
    }
  }

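  // Backward pass: walk the layers in reverse, multiplying each delta by the
  // layer's activation derivative, then pushing it back through the
  // transposed weights. layerDelta is built in reverse order, so the first
  // layer's delta ends up last.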
  FVector previousDelta = normalizedErrors;
  List<FVector> layerDelta = [];
  for (int i = layers.length - 1; i >= 0; --i) {
    FVector currentDelta = artifacts[i + 1].derivative * previousDelta;
    layerDelta.add(currentDelta);
    previousDelta =
        layers[i].weights.transposed().multiplyVector(currentDelta);
  }

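  // Gradient-descent update, front to back: layerDelta is consumed from its
  // tail (the current layer's delta) and arti.current is the activation fed
  // into that layer. The cascade `..scale(learningRate)` scales the outer
  // product before it is subtracted from the weights.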
  var arti = artifacts.iterator;
  for (TfannLayer l in layers) {
    l.bias -= layerDelta.last.scaled(learningRate);
    arti.moveNext();

    l.weights -= layerDelta.last.multiplyTransposed(arti.current.activation)
      ..scale(learningRate);
    layerDelta.removeLast();
  }
  return TrainArtifacts(netErrors, previousDelta);
}
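
Because the returned TrainArtifacts carries the error propagated to the first layer, two networks can be trained back-to-back. A sketch under assumptions: the field name propagatedError, the feedForward call, and the TrainSetInputError constructor shape are illustrative, not confirmed API.

// Forward the input through the upstream network (hypothetical call),
// train the downstream network on the result, then reuse its first-layer
// error as the upstream network's training error.
final hidden = upstream.feedForward(input);
final art = downstream.train(TrainSetInputOutput(hidden, target));
upstream.train(TrainSetInputError(input, art.propagatedError));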