// main function: void main() — implementation below.
/// Runs a time-fair benchmark comparing a standard LSTM against several
/// MultiTier LSTM configurations on a synthetic multi-scale signal.
///
/// "Time-fair" means: the baseline LSTM is first trained for a fixed number
/// of epochs to calibrate a wall-clock budget, and every contender is then
/// trained under that same wall-clock budget (not the same epoch count).
void main() {
  // --- Inline Helper Function for Complex Data ---
  //
  // Fills [inputs]/[targets] with a synthetic time series built from three
  // superimposed sinusoids plus noise. Time advances 0.1 units per step, so
  // the "day" analogies below map to steps as: period 36.5 time units
  // ~ 365 steps (yearly), 3.0 ~ 30 steps (monthly), 0.7 = 7 steps (weekly).
  void prepareComplexRnnData({
    required List<Tensor<Matrix>> inputs,
    required List<Tensor<Vector>> targets,
    required int numSamples,
    required int sequenceLength,
    required double startOffset,
  }) {
    // Fixed seed: every run trains and evaluates on identical data, so the
    // model comparison is reproducible across runs. (Previously unseeded,
    // which made run-to-run results incomparable.)
    final Random noiseGenerator = Random(42);
    for (int i = 0; i < numSamples; i++) {
      final Matrix sequence = <Vector>[];
      final double start = startOffset + i * 0.5; // More spacing between samples
      for (int j = 0; j < sequenceLength; j++) {
        final double timeStep = start + j * 0.1;
        // 1. A slow, strong "yearly" trend (period 36.5 time units ~ 365 steps)
        final double yearlyTrend = sin(timeStep * (2 * pi / 36.5));
        // 2. A medium, clear "monthly" trend (period 3.0 time units ~ 30 steps)
        final double monthlyTrend = 0.5 * cos(timeStep * (2 * pi / 3.0));
        // 3. A noisy, fast "weekly" pattern (period 0.7 time units = 7 steps)
        final double weeklyNoise = 0.2 * sin(timeStep * (2 * pi / 0.7)) + (noiseGenerator.nextDouble() - 0.5) * 0.1;
        sequence.add(<double>[yearlyTrend + monthlyTrend + weeklyNoise]);
      }
      inputs.add(Tensor<Matrix>(sequence));
      // Target is the next value in the sequence. Note: the target is the
      // NOISE-FREE signal (no random term), so models are scored against the
      // clean value even though their inputs carry irreducible noise.
      final double finalTimeStep = start + sequenceLength * 0.1;
      final double nextYearly = sin(finalTimeStep * (2 * pi / 36.5));
      final double nextMonthly = 0.5 * cos(finalTimeStep * (2 * pi / 3.0));
      final double nextWeekly = 0.2 * sin(finalTimeStep * (2 * pi / 0.7));
      targets.add(Tensor<Vector>(<double>[nextYearly + nextMonthly + nextWeekly]));
    }
  }

  // Returns the mean squared error of [model] over the given test set.
  //
  // Throws [ArgumentError] on an empty test set (previously this silently
  // produced NaN via 0/0).
  double calculateMSE(SNetwork model, List<Tensor<Matrix>> testX, List<Tensor<Vector>> testY) {
    if (testX.isEmpty) {
      throw ArgumentError('testX must not be empty');
    }
    double totalLoss = 0.0;
    for (int i = 0; i < testX.length; i++) {
      totalLoss += mse(model.predict(testX[i]) as Tensor<Vector>, testY[i]).value;
    }
    return totalLoss / testX.length;
  }

  // Trains [model] with SGD and returns its final test MSE.
  //
  // Training stops at [maxEpochs], or earlier once [timeLimit] (if non-null)
  // is exceeded. The limit is only checked at epoch boundaries, so the budget
  // may be overshot by up to one epoch's duration.
  double runSingleTrial({
    required SNetwork model,
    required List<Tensor<Matrix>> trainX,
    required List<Tensor<Vector>> trainY,
    required List<Tensor<Matrix>> testX,
    required List<Tensor<Vector>> testY,
    required int maxEpochs,
    required double learningRate,
    required Duration? timeLimit,
  }) {
    // Warm-up forward pass so the network's parameters exist before the
    // optimizer captures them.
    model.call(trainX[0]);
    final SGD optimizer = SGD(model.parameters, learningRate: learningRate);
    final Stopwatch trainingStopwatch = Stopwatch()..start();
    int epochsCompleted = 0;
    for (int epoch = 0; epoch < maxEpochs; epoch++) {
      if (timeLimit != null && trainingStopwatch.elapsed > timeLimit) {
        print('   -> Time limit reached. Stopping after $epochsCompleted epochs.');
        break;
      }
      for (int i = 0; i < trainX.length; i++) {
        optimizer.zeroGrad();
        final Tensor<Scalar> loss = mse(model.call(trainX[i]) as Tensor<Vector>, trainY[i]);
        loss.backward();
        optimizer.step();
      }
      epochsCompleted++;
    }
    trainingStopwatch.stop();
    print('   -> Trained for $epochsCompleted epochs in ${trainingStopwatch.elapsedMilliseconds}ms.');
    return calculateMSE(model, testX, testY);
  }

  // --- 1. Experiment Configuration ---
  print('🔬 Setting up TIME-FAIR comparison on a complex signal...');
  // Each entry is a tier-clock configuration; the empty list is the baseline.
  final List<List<int>> configurationsToTest = [
    [], // Baseline: Standard LSTMLayer
    [7], // 2-Tier: Aims to capture weekly patterns
    [30], // 2-Tier: Aims to capture monthly patterns
    [7, 4], // 3-Tier: Aims to capture weekly and monthly patterns (~7 days, ~28 days)
  ];

  // --- 2. Global Hyperparameters and Data ---
  final int sequenceLength = 60; // Longer sequence to capture multiple cycles
  final int hiddenSize = 8;
  final int epochsForBaseline = 30;
  final double learningRate = 0.03;
  final int numTrainSamples = 400;
  final int numTestSamples = 50;
  final List<Tensor<Matrix>> trainX = [];
  final List<Tensor<Vector>> trainY = [];
  final List<Tensor<Matrix>> testX = [];
  final List<Tensor<Vector>> testY = [];
  prepareComplexRnnData(inputs: trainX, targets: trainY, numSamples: numTrainSamples, sequenceLength: sequenceLength, startOffset: 0.0);
  // Test data starts well past the training range (train spans
  // numTrainSamples * 0.5 time units, plus a 50-unit gap) so the test set is
  // a genuine extrapolation, not an overlap of the training window.
  prepareComplexRnnData(inputs: testX, targets: testY, numSamples: numTestSamples, sequenceLength: sequenceLength, startOffset: numTrainSamples * 0.5 + 50.0);
  print('📊 Complex Data Prepared. Calibrating time budget...');
  print('---');

  // --- 3. Calibration Step ---
  // Train the baseline for a fixed epoch count with NO time limit; its total
  // wall-clock time becomes the budget every contender gets below.
  Duration timeBudget;
  final SNetwork baselineModel = SNetwork([LSTMLayer(hiddenSize), DenseLayer(1)]);
  print('⏱️ Calibrating time budget with Standard LSTM for $epochsForBaseline epochs...');
  final Stopwatch calibrationStopwatch = Stopwatch()..start();
  runSingleTrial(
    model: baselineModel, trainX: trainX, trainY: trainY, testX: testX, testY: testY,
    maxEpochs: epochsForBaseline, learningRate: learningRate, timeLimit: null,
  );
  calibrationStopwatch.stop();
  timeBudget = calibrationStopwatch.elapsed;
  print('✅ Time budget set to: ${timeBudget.inMilliseconds}ms');
  print('---');

  // --- 4. Main Comparison Loop ---
  final Map<String, double> results = {};
  for (final List<int> clockCycles in configurationsToTest) {
    Layer coreLayer;
    String modelName;
    if (clockCycles.isEmpty) {
      modelName = 'Standard LSTM';
      coreLayer = LSTMLayer(hiddenSize);
    } else {
      modelName = 'MultiTier LSTM (Cycles: $clockCycles)';
      coreLayer = MultiTierLSTMLayer(hiddenSize, tierClockCycles: clockCycles);
    }
    print('🏋️ Training $modelName with a time budget of ${timeBudget.inMilliseconds}ms...');
    final SNetwork model = SNetwork([coreLayer, DenseLayer(1)]);
    // Generous epoch ceiling: the wall-clock budget, not this cap, is the
    // binding limit in practice. (int * int is already int; no cast needed.)
    final int maxEpochs = epochsForBaseline * 3;
    final double finalTestLoss = runSingleTrial(
      model: model, trainX: trainX, trainY: trainY, testX: testX, testY: testY,
      maxEpochs: maxEpochs, learningRate: learningRate, timeLimit: timeBudget,
    );
    print('✅ Finished training. Final Test MSE: ${finalTestLoss.toStringAsFixed(8)}');
    print('---');
    results[modelName] = finalTestLoss;
  }

  // --- 5. Final Results Summary ---
  print('\n\n--- 🏆 FINAL TIME-FAIR COMPARISON RESULTS (Complex Signal) ---');
  print('Total training budget for each model: ${timeBudget.inMilliseconds}ms');
  print('--------------------------------------------------');
  String bestModel = '';
  double lowestLoss = double.infinity;
  // for-in over entries instead of forEach-with-closure (Effective Dart).
  for (final MapEntry<String, double> entry in results.entries) {
    print('${entry.key.padRight(35)} | Final Test MSE: ${entry.value.toStringAsFixed(8)}');
    if (entry.value < lowestLoss) {
      lowestLoss = entry.value;
      bestModel = entry.key;
    }
  }
  print('--------------------------------------------------');
  print('🥇 Best performing model (given equal time): $bestModel');
}