fit method
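void fit(List<List<double>> inputs, List<List<double>> targets, {int epochs = 100, bool averageWeight = false, bool debug = true})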
Implementation
/// Trains the model on the paired [inputs] and [targets] for [epochs] passes.
///
/// For every sample the input and target rows are wrapped in a
/// `Tensor<Vector>`, the input is run through `forward`, the result is scored
/// against the target with `mse`, the loss is back-propagated, and the
/// optimizer takes one step before its gradients are cleared.
///
/// Parameters:
/// * [inputs] - one feature row per training sample.
/// * [targets] - one expected-output row per sample, matched to [inputs] by
///   index. Rows beyond `inputs.length` are ignored.
/// * [epochs] - number of full passes over [inputs]; values <= 0 run no pass.
/// * [averageWeight] - when true (and [debug] is true), the mean absolute
///   value over all `Vector`/`Matrix` entries in `parameters` is appended to
///   the epoch line roughly ten times per run.
/// * [debug] - when true, prints start/finish banners, a live progress bar,
///   and the per-epoch average loss.
///
/// Throws a [RangeError] before any training happens when [targets] has fewer
/// rows than [inputs], instead of failing part-way through the first epoch
/// with the weights already modified.
void fit(List<List<double>> inputs, List<List<double>> targets,
    {int epochs = 100, bool averageWeight = false, bool debug = true}) {
  final int sampleCount = inputs.length;

  // Fail fast: indexing targets[i] below would otherwise raise only after
  // some optimizer steps have already been applied.
  if (targets.length < sampleCount) {
    throw RangeError.range(targets.length, sampleCount, null, 'targets.length',
        'Every input row needs a matching target row');
  }

  // Width, in characters, of the textual progress bar.
  const int barWidth = 20;

  // Loop-invariant: weight statistics are logged about ten times per run,
  // and at least on every epoch when there are fewer than ten epochs.
  final int logInterval = max(1, (epochs / 10).round());

  if (debug) {
    print('--- STARTING TRAINING ---');
  }
  Stopwatch stopwatch = Stopwatch()..start();

  for (int epoch = 0; epoch < epochs; epoch++) {
    double epochLoss = 0.0;

    // Last bar state written to the terminal; -1 forces the first draw.
    int lastCompleted = -1;
    int lastPercent = -1;

    for (int i = 0; i < sampleCount; i++) {
      Tensor<Vector> input = Tensor<Vector>(inputs[i]);
      Tensor<Vector> target = Tensor<Vector>(targets[i]);

      Tensor<Vector> finalOutput = forward(input) as Tensor<Vector>;
      Tensor<Scalar> loss = mse(finalOutput, target);
      epochLoss += loss.value;

      loss.backward();
      optimizer.step();
      optimizer.zeroGrad();

      if (debug) {
        double progress = (i + 1) / sampleCount;
        int completed = (progress * barWidth).round();
        int percent = (progress * 100).round();

        // Redraw only when the rendered line would actually change; writing
        // an identical line for every sample is wasted terminal I/O.
        if (completed != lastCompleted || percent != lastPercent) {
          String bar = '=' * completed + '>' + ' ' * (barWidth - completed);
          // The carriage return rewinds to the start of the line so the bar
          // is updated in place.
          stdout.write('\rEpoch ${epoch + 1}/$epochs: [$bar] $percent%');
          lastCompleted = completed;
          lastPercent = percent;
        }
      }
    }

    if (debug) {
      // With no samples there is nothing to average; report 0 rather than
      // the NaN that 0.0 / 0 would produce.
      double avgLoss = sampleCount == 0 ? 0.0 : epochLoss / sampleCount;
      String fullBar = '=' * barWidth + '>';

      // Overwrite the finished progress bar with the epoch summary.
      stdout.write('\rEpoch ${epoch + 1}/$epochs: [$fullBar] 100%, '
          'Avg Loss: ${avgLoss.toStringAsFixed(6)}');

      bool isLogInterval = (epoch + 1) % logInterval == 0;
      if (averageWeight && isLogInterval) {
        // Mean absolute value over every Vector / Matrix entry found in
        // the model parameters.
        double totalWeightSum = 0;
        int totalWeightCount = 0;
        for (Tensor param in parameters) {
          if (param.value is Vector) {
            for (double weight in (param.value as Vector)) {
              totalWeightSum += weight.abs();
              totalWeightCount++;
            }
          } else if (param.value is Matrix) {
            for (Vector row in (param.value as Matrix)) {
              for (double weight in row) {
                totalWeightSum += weight.abs();
                totalWeightCount++;
              }
            }
          }
        }
        if (totalWeightCount > 0) {
          double avg = totalWeightSum / totalWeightCount;
          stdout.write(', Avg Weight Mag: ${avg.toStringAsFixed(4)}');
        }
      }

      // Terminate the in-place line so the next epoch starts on a new one.
      print('');
    }
  }

  stopwatch.stop();
  if (debug) {
    print('--- TRAINING FINISHED in ${stopwatch.elapsedMilliseconds}ms ---\n');
  }
}