fit method
Implementation
/// Trains a one-hidden-layer ReLU autoencoder on [X] by full-batch gradient
/// descent on the reconstruction error.
///
/// Each row of [X] is one sample; the feature count is taken from the first
/// row. Every call re-initialises `w1`, `w2`, `b1` and `b2` from a
/// `Random(0)` generator, so repeated calls on the same data start from the
/// same weights. The parameters are then updated in place for `epochs`
/// iterations with step size `lr`.
///
/// Gradients are summed over the samples (not averaged), so the effective
/// step grows with the number of rows in [X].
///
/// Throws an [ArgumentError] if [X] has no rows.
void fit(List<List<double>> X) {
  final n = X.length;
  if (n == 0) throw ArgumentError('Empty dataset');
  final m = X[0].length;
  final rnd = Random(0);

  // Draw the encoder weights before the decoder weights so the random
  // stream is consumed in a fixed order.
  final encoder = List.generate(
    m,
    (_) => List.generate(hidden, (_) => rnd.nextDouble() * 0.01),
  );
  final decoder = List.generate(
    hidden,
    (_) => List.generate(m, (_) => rnd.nextDouble() * 0.01),
  );
  final hiddenBias = List<double>.filled(hidden, 0.0);
  final outBias = List<double>.filled(m, 0.0);

  // Publish the parameters on the instance; the training loop below mutates
  // these same list objects in place.
  w1 = encoder;
  w2 = decoder;
  b1 = hiddenBias;
  b2 = outBias;

  for (var epoch = 0; epoch < epochs; epoch++) {
    // Forward pass: hidden = relu(X * w1 + b1).
    final act = _matMul(X, encoder);
    for (final row in act) {
      for (var j = 0; j < hidden; j++) {
        row[j] = max(0.0, row[j] + hiddenBias[j]);
      }
    }
    // `err` starts as the raw product hidden * w2 and is turned into the
    // output error (reconstruction - input) row by row below.
    final err = _matMul(act, decoder);

    // Gradient accumulators. Everything is computed from the current
    // parameters before any of them is modified.
    final dW1 = List.generate(m, (_) => List<double>.filled(hidden, 0.0));
    final dW2 = List.generate(hidden, (_) => List<double>.filled(m, 0.0));
    final dB1 = List<double>.filled(hidden, 0.0);
    final dB2 = List<double>.filled(m, 0.0);
    final delta = List<double>.filled(hidden, 0.0);

    for (var t = 0; t < n; t++) {
      final input = X[t];
      final h = act[t];
      final e = err[t];

      // Output error for this sample: (hidden * w2 + b2) - x.
      for (var k = 0; k < m; k++) {
        e[k] = e[k] + outBias[k] - input[k];
        dB2[k] += e[k];
      }

      for (var j = 0; j < hidden; j++) {
        final weights = decoder[j];
        final grad = dW2[j];
        var back = 0.0;
        for (var k = 0; k < m; k++) {
          grad[k] += h[j] * e[k];
          // Propagate the output error back through the decoder weights.
          back += e[k] * weights[k];
        }
        // ReLU derivative: only units that fired pass gradient through.
        delta[j] = h[j] > 0 ? back : 0.0;
        dB1[j] += delta[j];
      }

      for (var i = 0; i < m; i++) {
        final grad = dW1[i];
        final x = input[i];
        for (var j = 0; j < hidden; j++) {
          grad[j] += x * delta[j];
        }
      }
    }

    // Parameter updates.
    for (var i = 0; i < m; i++) {
      for (var j = 0; j < hidden; j++) {
        encoder[i][j] -= lr * dW1[i][j];
      }
    }
    for (var j = 0; j < hidden; j++) {
      for (var k = 0; k < m; k++) {
        decoder[j][k] -= lr * dW2[j][k];
      }
      hiddenBias[j] -= lr * dB1[j];
    }
    for (var k = 0; k < m; k++) {
      outBias[k] -= lr * dB2[k];
    }
  }
}