train method
Train network with a single training pair, for a single epoch.
Returns the error propagated back through the first layer, which is useful for chained networks.
Implementation
/// Trains the network on a single [trainSet] for a single epoch using one
/// forward pass followed by one backward pass.
///
/// If [trainSet] is a [TrainSetInputOutput], the output error is the network
/// output minus the expected output; otherwise [trainSet] is cast to
/// [TrainSetInputError] and its `error` vector is used directly.
///
/// * [learningRate] scales every bias and weight correction.
/// * [maxErrClipAbove], when positive, rescales the error vector so that its
///   largest absolute element does not exceed this value.
/// * [skipIfErrBelow], when positive, leaves the network untouched if the
///   largest absolute error element is below this value. This check is
///   independent of [maxErrClipAbove].
/// * [skipIfOutput], when it returns `true` for the network output, leaves
///   the network untouched.
///
/// Returns a [TrainArtifacts] built from the unclipped output error and the
/// error propagated back through the first layer. When training is skipped,
/// the propagated error is a zero vector of the first layer's input length.
TrainArtifacts train(TrainSet trainSet,
    {double learningRate = 0.04,
    double maxErrClipAbove = 0.0,
    double skipIfErrBelow = 0.0,
    bool Function(FVector)? skipIfOutput}) {
  assert(trainSet.input.length == layers.first.inputLength);

  // Forward pass. artifacts[0] is a placeholder for the raw input, so that
  // artifacts[i] holds the input of layers[i] and artifacts[i + 1] holds its
  // output.
  FVector nextInputs = trainSet.input;
  final List<FeedArtifacts> artifacts = [
    FeedArtifacts(nextInputs, FVector.zero(nextInputs.length))
  ];
  for (final TfannLayer l in layers) {
    artifacts.add(l.createFeedArtifacts(nextInputs));
    nextInputs = artifacts.last.activation;
  }
  final FVector netOutput = nextInputs;

  // Output error: either derived from the expected output or supplied by
  // the caller.
  FVector netErrors;
  if (trainSet is TrainSetInputOutput) {
    assert(trainSet.output.length == layers.last.outputLength);
    netErrors = netOutput - trainSet.output;
  } else {
    final TrainSetInputError errorSet = trainSet as TrainSetInputError;
    assert(errorSet.error.length == layers.last.outputLength);
    netErrors = errorSet.error;
  }

  if (skipIfOutput?.call(netOutput) ?? false) {
    return TrainArtifacts(netErrors, FVector.zero(layers.first.inputLength));
  }

  // Skip / clip on the largest absolute error element. The norm is needed
  // whenever either threshold is enabled, so the skip threshold also works
  // when clipping is turned off.
  FVector normalizedErrors = netErrors;
  if (skipIfErrBelow > 0.0 || maxErrClipAbove > 0.0) {
    final double norm = netErrors.abs().largestElement();
    if (norm < skipIfErrBelow) {
      return TrainArtifacts(
          netErrors, FVector.zero(layers.first.inputLength));
    }
    if (maxErrClipAbove > 0.0 && norm > maxErrClipAbove) {
      normalizedErrors = netErrors.scaled(maxErrClipAbove / norm);
    }
  }

  // Backward pass, last layer first. All deltas are computed with the
  // current weights; the weights are only modified afterwards.
  FVector previousDelta = normalizedErrors;
  final List<FVector> layerDelta = [];
  for (int i = layers.length - 1; i >= 0; --i) {
    final FVector currentDelta = artifacts[i + 1].derivative * previousDelta;
    layerDelta.add(currentDelta);
    previousDelta =
        layers[i].weights.transposed().multiplyVector(currentDelta);
  }

  // Apply the corrections. layerDelta was filled in reverse layer order, so
  // the delta of layers[i] sits at index (layers.length - 1 - i).
  for (int i = 0; i < layers.length; ++i) {
    final TfannLayer l = layers[i];
    final FVector delta = layerDelta[layers.length - 1 - i];
    l.bias -= delta.scaled(learningRate);
    // The cascade scales the freshly built correction in place before it is
    // subtracted from the weights.
    l.weights -= delta.multiplyTransposed(artifacts[i].activation)
      ..scale(learningRate);
  }

  return TrainArtifacts(netErrors, previousDelta);
}