fit method
void fit(
List<List<double>> X,
List<List<double>> Y, {
int? batchSize,
bool verbose = false,
String optimizer = 'sgd',
double momentum = 0.9,
double beta1 = 0.9,
double beta2 = 0.999,
double epsilon = 1e-8,
double l2 = 0.0,
String lrSchedule = 'constant',
int stepSize = 10,
double stepDecay = 0.5,
double expDecay = 0.99,
int epochsOverride = -1,
})
Convenience training wrapper that calls the main fit implementation. (This class intentionally exposes a single training API.)
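A minimal usage sketch. The enclosing class is not named in this section, so MLP and its constructor arguments below are placeholders; only the fit parameters come from the signature above.
// Hypothetical example: train a small network on a 4-sample dataset.
final net = MLP(/* layer sizes, learning rate, epochs, ... */);
final X = [
[0.0, 0.0],
[0.0, 1.0],
[1.0, 0.0],
[1.0, 1.0],
];
final Y = [
[0.0],
[1.0],
[1.0],
[0.0],
];
net.fit(
X,
Y,
batchSize: 2,
optimizer: 'adam',
lrSchedule: 'exp',
verbose: true,
);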
Implementation
void fit(
List<List<double>> X,
List<List<double>> Y, {
int? batchSize,
bool verbose = false,
String optimizer = 'sgd',
double momentum = 0.9,
double beta1 = 0.9,
double beta2 = 0.999,
double epsilon = 1e-8,
// l2 regularization (weight decay)
double l2 = 0.0,
// learning rate schedule: 'constant', 'step', 'exp'
String lrSchedule = 'constant',
int stepSize = 10,
double stepDecay = 0.5,
double expDecay = 0.99,
int epochsOverride = -1,
}) {
if (X.isEmpty) throw ArgumentError('Empty dataset');
if (X.length != Y.length) {
throw ArgumentError('X and Y must have same number of rows');
}
_initParams();
final n = X.length;
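// a null or oversized batchSize falls back to full-batch gradient descent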
final useBatch = (batchSize == null || batchSize >= n) ? n : batchSize;
final int runEpochs = epochsOverride > 0 ? epochsOverride : epochs;
// optimizer state (for momentum/adam)
List<List<List<double>>> vW = [];
List<List<double>> vB = [];
List<List<List<double>>> mW = [];
List<List<List<double>>> vAdamW = [];
List<List<double>> mB = [];
List<List<double>> vAdamB = [];
var t = 0;
// initialize optimizer accumulators
if (optimizer == 'momentum') {
for (var l = 0; l < weights.length; l++) {
vW.add(
List.generate(
weights[l].length,
(_) => List<double>.filled(weights[l][0].length, 0.0),
),
);
vB.add(List<double>.filled(biases[l].length, 0.0));
}
} else if (optimizer == 'adam') {
for (var l = 0; l < weights.length; l++) {
mW.add(
List.generate(
weights[l].length,
(_) => List<double>.filled(weights[l][0].length, 0.0),
),
);
vAdamW.add(
List.generate(
weights[l].length,
(_) => List<double>.filled(weights[l][0].length, 0.0),
),
);
mB.add(List<double>.filled(biases[l].length, 0.0));
vAdamB.add(List<double>.filled(biases[l].length, 0.0));
}
}
for (var epoch = 0; epoch < runEpochs; epoch++) {
// compute current learning rate based on schedule
double currentLr = lr;
if (lrSchedule == 'step') {
currentLr = lr * pow(stepDecay, (epoch / stepSize).floor());
} else if (lrSchedule == 'exp') {
currentLr = lr * pow(expDecay, epoch);
}
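// with the defaults, 'step' halves the rate every 10 epochs and 'exp' multiplies it by 0.99 each epoch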
// create shuffled indices for mini-batching
final indices = List<int>.generate(n, (i) => i);
if (useBatch < n) {
for (var i = indices.length - 1; i > 0; i--) {
final j = _rand.nextInt(i + 1);
final tmp = indices[i];
indices[i] = indices[j];
indices[j] = tmp;
}
}
for (var batchStart = 0; batchStart < n; batchStart += useBatch) {
final batchEnd = (batchStart + useBatch).clamp(0, n);
final bsize = batchEnd - batchStart;
if (bsize <= 0) continue;
// prepare batch activations and preacts
final activations = <List<List<double>>>[];
final preacts = <List<List<double>>>[];
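// forward pass per sample: cache each layer's pre-activation z and activation a for backprop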
for (var bi = batchStart; bi < batchEnd; bi++) {
final idx = indices[bi];
var a = X[idx];
final acts = <List<double>>[a];
final pres = <List<double>>[];
for (var l = 0; l < weights.length; l++) {
final w = weights[l];
final b = biases[l];
final z = List<double>.filled(w.length, 0.0);
for (var ii = 0; ii < w.length; ii++) {
var s = b[ii];
for (var jj = 0; jj < w[ii].length; jj++) {
s += w[ii][jj] * a[jj];
}
z[ii] = s;
}
pres.add(z);
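// hidden layers use ReLU; the final layer uses sigmoid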
if (l == weights.length - 1) {
a = z.map(_sigmoid).toList();
} else {
a = z.map(_relu).toList();
}
acts.add(a);
}
activations.add(acts);
preacts.add(pres);
}
// compute gradients for batch
final gradW = List.generate(
weights.length,
(l) => List.generate(
weights[l].length,
(_) => List<double>.filled(weights[l][0].length, 0.0),
),
);
final gradB = List.generate(
biases.length,
(l) => List<double>.filled(biases[l].length, 0.0),
);
for (var bi = 0; bi < activations.length; bi++) {
final idx = indices[batchStart + bi];
final y = Y[idx];
final acts = activations[bi];
final pres = preacts[bi];
final out = acts.last;
final delta = List<double>.filled(out.length, 0.0);
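// output-layer delta for the MSE loss: (prediction - target) * sigmoid'(output)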
for (var k = 0; k < out.length; k++) {
final err = out[k] - y[k];
delta[k] = err * _sigmoidDeriv(out[k]);
}
var curDelta = delta;
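// walk layers from output to input, accumulating gradients and propagating the delta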
for (var l = weights.length - 1; l >= 0; l--) {
final aPrev = acts[l];
for (var iOut = 0; iOut < weights[l].length; iOut++) {
gradB[l][iOut] += curDelta[iOut];
for (var iIn = 0; iIn < weights[l][iOut].length; iIn++) {
gradW[l][iOut][iIn] += curDelta[iOut] * aPrev[iIn];
}
}
if (l > 0) {
final nextDelta = List<double>.filled(weights[l - 1].length, 0.0);
for (var iPrev = 0; iPrev < weights[l - 1].length; iPrev++) {
var s = 0.0;
for (var iOut = 0; iOut < weights[l].length; iOut++) {
s += weights[l][iOut][iPrev] * curDelta[iOut];
}
nextDelta[iPrev] = s * _reluDeriv(pres[l - 1][iPrev]);
}
curDelta = nextDelta;
}
}
}
// average gradients over batch
for (var l = 0; l < weights.length; l++) {
for (var iOut = 0; iOut < weights[l].length; iOut++) {
for (var iIn = 0; iIn < weights[l][iOut].length; iIn++) {
gradW[l][iOut][iIn] /= bsize;
}
gradB[l][iOut] /= bsize;
}
}
// optimizer updates (use currentLr)
if (optimizer == 'sgd') {
for (var l = 0; l < weights.length; l++) {
for (var iOut = 0; iOut < weights[l].length; iOut++) {
for (var iIn = 0; iIn < weights[l][iOut].length; iIn++) {
weights[l][iOut][iIn] -= currentLr * gradW[l][iOut][iIn];
// weight decay
if (l2 > 0) weights[l][iOut][iIn] *= (1 - currentLr * l2);
}
biases[l][iOut] -= currentLr * gradB[l][iOut];
}
}
} else if (optimizer == 'momentum') {
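// classical momentum: velocity = momentum * velocity + lr * grad, then subtract the velocity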
for (var l = 0; l < weights.length; l++) {
for (var iOut = 0; iOut < weights[l].length; iOut++) {
for (var iIn = 0; iIn < weights[l][iOut].length; iIn++) {
vW[l][iOut][iIn] =
momentum * vW[l][iOut][iIn] +
currentLr * gradW[l][iOut][iIn];
weights[l][iOut][iIn] -= vW[l][iOut][iIn];
if (l2 > 0) weights[l][iOut][iIn] *= (1 - currentLr * l2);
}
vB[l][iOut] = momentum * vB[l][iOut] + currentLr * gradB[l][iOut];
biases[l][iOut] -= vB[l][iOut];
}
}
} else if (optimizer == 'adam') {
t += 1;
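// Adam: update biased first (m) and second (v) moment estimates, then take a bias-corrected step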
for (var l = 0; l < weights.length; l++) {
for (var iOut = 0; iOut < weights[l].length; iOut++) {
for (var iIn = 0; iIn < weights[l][iOut].length; iIn++) {
final g = gradW[l][iOut][iIn];
mW[l][iOut][iIn] = beta1 * mW[l][iOut][iIn] + (1 - beta1) * g;
vAdamW[l][iOut][iIn] =
beta2 * vAdamW[l][iOut][iIn] + (1 - beta2) * g * g;
final mHat = mW[l][iOut][iIn] / (1 - pow(beta1, t));
final vHat = vAdamW[l][iOut][iIn] / (1 - pow(beta2, t));
weights[l][iOut][iIn] -=
currentLr * mHat / (sqrt(vHat) + epsilon);
if (l2 > 0) weights[l][iOut][iIn] *= (1 - currentLr * l2);
}
final gb = gradB[l][iOut];
mB[l][iOut] = beta1 * mB[l][iOut] + (1 - beta1) * gb;
vAdamB[l][iOut] = beta2 * vAdamB[l][iOut] + (1 - beta2) * gb * gb;
final mHatB = mB[l][iOut] / (1 - pow(beta1, t));
final vHatB = vAdamB[l][iOut] / (1 - pow(beta2, t));
biases[l][iOut] -= currentLr * mHatB / (sqrt(vHatB) + epsilon);
}
}
} else {
throw ArgumentError('Unknown optimizer: $optimizer');
}
}
// compute epoch loss and store (mean over samples) using helper
lastLoss = _mseLoss(predict(X), Y);
lossHistory.add(lastLoss!);
if (verbose) print('epoch=$epoch loss=$lastLoss');
}
}