fit method
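void fit(List<List<double>> inputs, List<List<double>> targets, {int epochs = 100, bool averageWeight = false, bool debug = true})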
Implementation
/// Trains the model on paired samples, applying one optimizer update per
/// sample.
///
/// For each of [epochs] passes over the data, every `inputs[i]` is wrapped in
/// a `Tensor<Vector>`, run through `forward`, and scored against `targets[i]`
/// with `mse`. The loss is back-propagated, then `optimizer.step()` and
/// `optimizer.zeroGrad()` are called before moving on to the next sample.
///
/// * [inputs] / [targets]: sample `i` of one is paired with sample `i` of the
///   other, so both lists must have the same length.
/// * [epochs]: number of full passes over the data.
/// * [averageWeight]: only has an effect when [debug] is `true`. It also
///   prints the mean absolute value of all Vector/Matrix parameter elements
///   on roughly every tenth epoch.
/// * [debug]: prints start/finish banners, a per-sample progress bar and each
///   epoch's average loss. When `false`, nothing is printed.
///
/// Throws an [ArgumentError] if `inputs.length != targets.length`. The check
/// runs before any training, so a mismatch cannot surface (or silently drop
/// extra targets) after optimizer steps have already modified the parameters.
void fit(List<List<double>> inputs, List<List<double>> targets,
    {int epochs = 100, bool averageWeight = false, bool debug = true}) {
  if (inputs.length != targets.length) {
    throw ArgumentError('inputs and targets must have the same length '
        '(got ${inputs.length} inputs and ${targets.length} targets)');
  }
  if (debug) {
    Logger.log('--- STARTING TRAINING ---');
  }
  final Stopwatch stopwatch = Stopwatch()..start();
  // Loop-invariant: the weight summary is printed every `logInterval` epochs,
  // i.e. about ten times per run (and at least once per epoch for short runs).
  final int logInterval = max(1, (epochs / 10).round());
  for (int epoch = 0; epoch < epochs; epoch++) {
    double epochLoss = 0.0;
    for (int i = 0; i < inputs.length; i++) {
      final Tensor<Vector> input = Tensor<Vector>(inputs[i]);
      final Tensor<Vector> target = Tensor<Vector>(targets[i]);
      final Tensor<Vector> finalOutput = forward(input) as Tensor<Vector>;
      final Tensor<Scalar> loss = mse(finalOutput, target);
      epochLoss += loss.value;
      loss.backward();
      optimizer.step();
      optimizer.zeroGrad();
      if (debug) {
        final double progress = (i + 1) / inputs.length;
        final int percent = (progress * 100).round();
        // The leading '\r' redraws the same console line for every sample.
        stdout.write('\rEpoch ${epoch + 1}/$epochs: '
            '[${_fitProgressBar(progress)}] $percent%');
      }
    }
    if (debug) {
      // Avoid printing NaN (0 / 0) when there are no samples.
      final String avgLossText = inputs.isEmpty
          ? 'n/a'
          : (epochLoss / inputs.length).toStringAsFixed(6);
      stdout.write('\rEpoch ${epoch + 1}/$epochs: '
          '[${_fitProgressBar(1.0)}] 100%, Avg Loss: $avgLossText');
      if (averageWeight && (epoch + 1) % logInterval == 0) {
        _fitWriteAverageWeight();
      }
      // Presumably ends the '\r' progress line written above (assumes
      // Logger.log appends a newline) — confirm against Logger.
      Logger.log('');
    }
  }
  stopwatch.stop();
  if (debug) {
    Logger.log(
        '--- TRAINING FINISHED in ${stopwatch.elapsedMilliseconds}ms ---\n');
  }
}

/// Renders the text between the brackets of the `fit` progress bar.
///
/// The result is `round(progress * width)` '=' cells, then a '>' marker, then
/// space padding up to [width] cells. Expects `0 <= progress <= 1`.
static String _fitProgressBar(double progress, {int width = 20}) {
  final int completed = (progress * width).round();
  return '${'=' * completed}>${' ' * (width - completed)}';
}

/// Writes ", Avg Weight Mag: X" to stdout, where X is the mean absolute value
/// of every element of each Vector or Matrix parameter in `parameters`.
///
/// Parameters holding other value types are skipped. Nothing is written when
/// no elements were found.
void _fitWriteAverageWeight() {
  double totalWeightSum = 0.0;
  int totalWeightCount = 0;
  for (final Tensor<dynamic> param in parameters) {
    final dynamic value = param.value;
    if (value is Vector) {
      final Vector v = value as Vector;
      for (int w = 0; w < v.length; w++) {
        totalWeightSum += v[w].abs();
        totalWeightCount++;
      }
    } else if (value is Matrix) {
      final Matrix m = value as Matrix;
      for (int r = 0; r < m.length; r++) {
        for (int c = 0; c < m[r].length; c++) {
          totalWeightSum += m[r][c].abs();
          totalWeightCount++;
        }
      }
    }
  }
  if (totalWeightCount > 0) {
    final double avg = totalWeightSum / totalWeightCount;
    stdout.write(', Avg Weight Mag: ${avg.toStringAsFixed(4)}');
  }
}