// main function
//
//     void main()
//
// Implementation

void main() {
  // --- Inline Helper Function for Complex Data ---

  /// Fills [inputs]/[targets] with [numSamples] synthetic sequences.
  ///
  /// Each input is a [sequenceLength]-step univariate series sampled at 0.1
  /// time-step increments starting from [startOffset] (+0.5 per sample for
  /// spacing); the target is the clean (noise-free) value one step beyond
  /// the end of the sequence.
  void prepareComplexRnnData({
    required List<Tensor<Matrix>> inputs,
    required List<Tensor<Vector>> targets,
    required int numSamples,
    required int sequenceLength,
    required double startOffset,
  }) {
    final Random rng = Random();

    // Clean signal at time [t]: the sum of three sinusoids — a slow, strong
    // yearly trend (~365 steps at the 0.1 step size), a medium monthly trend
    // (~30 steps) and a fast weekly pattern (~7 steps).
    double cleanSignal(double t) =>
        sin(t * (2 * pi / 36.5)) +
        0.5 * cos(t * (2 * pi / 3.0)) +
        0.2 * sin(t * (2 * pi / 0.7));

    for (int sample = 0; sample < numSamples; sample++) {
      final double start = startOffset + sample * 0.5; // More spacing between samples
      final Matrix sequence = <Vector>[
        for (int step = 0; step < sequenceLength; step++)
          <double>[
            // Observed value = clean signal + small uniform noise in [-0.05, 0.05).
            cleanSignal(start + step * 0.1) + (rng.nextDouble() - 0.5) * 0.1,
          ],
      ];
      inputs.add(Tensor<Matrix>(sequence));

      // Target is the next value in the sequence — taken from the noiseless
      // signal so the model is scored against the true underlying process.
      final double finalTimeStep = start + sequenceLength * 0.1;
      targets.add(Tensor<Vector>(<double>[cleanSignal(finalTimeStep)]));
    }
  }

  /// Returns the mean squared error of [model] over the test set.
  ///
  /// Runs [model.predict] on every sample in [testX] and averages the MSE
  /// against the corresponding entry in [testY]. Returns 0.0 for an empty
  /// test set (the original 0/0 division would have produced NaN).
  double calculateMSE(SNetwork model, List<Tensor<Matrix>> testX, List<Tensor<Vector>> testY) {
    if (testX.isEmpty) return 0.0; // guard: avoid 0/0 -> NaN
    double totalLoss = 0.0;
    for (int i = 0; i < testX.length; i++) {
      totalLoss += mse(model.predict(testX[i]) as Tensor<Vector>, testY[i]).value;
    }
    return totalLoss / testX.length;
  }

  // Trial runner now accepts a total time limit for the entire training run.
  //
  // Trains [model] on (trainX, trainY) with plain SGD for up to [maxEpochs]
  // epochs and returns the final MSE on (testX, testY). When [timeLimit] is
  // non-null, training stops early once the elapsed time exceeds it; the
  // check happens only at epoch boundaries, so a run may overshoot the
  // budget by up to one epoch. Pass null for no limit (calibration runs).
  double runSingleTrial({
    required SNetwork model,
    required List<Tensor<Matrix>> trainX,
    required List<Tensor<Vector>> trainY,
    required List<Tensor<Matrix>> testX,
    required List<Tensor<Vector>> testY,
    required int maxEpochs,
    required double learningRate,
    required Duration? timeLimit,
  }) {
    // Warm-up forward pass before constructing the optimizer — presumably
    // SNetwork builds its parameters lazily on first call, so they must
    // exist before `model.parameters` is captured below. NOTE(review):
    // confirm against SNetwork; also assumes trainX is non-empty.
    model.call(trainX[0]);
    SGD optimizer = SGD(model.parameters, learningRate: learningRate);
    Stopwatch trainingStopwatch = Stopwatch()..start();
    int epochsCompleted = 0;

    for (int epoch = 0; epoch < maxEpochs; epoch++) {
      // Budget check at the top of each epoch only (may overshoot by one epoch).
      if (timeLimit != null && trainingStopwatch.elapsed > timeLimit) {
        print('  -> Time limit reached. Stopping after $epochsCompleted epochs.');
        break;
      }
      // Per-sample (batch size 1) SGD: zero grads, forward, backward, step.
      for (int i = 0; i < trainX.length; i++) {
        optimizer.zeroGrad();
        Tensor<Scalar> loss = mse(model.call(trainX[i]) as Tensor<Vector>, trainY[i]);
        loss.backward();
        optimizer.step();
      }
      epochsCompleted++;
    }
    trainingStopwatch.stop();
    print('  -> Trained for $epochsCompleted epochs in ${trainingStopwatch.elapsedMilliseconds}ms.');
    // Evaluate on the held-out test set after training completes.
    return calculateMSE(model, testX, testY);
  }

  // --- 1. Experiment Configuration ---
  print('🔬 Setting up TIME-FAIR comparison on a complex signal...');
  // Each entry lists the MultiTier tier clock cycles to test; an empty list
  // selects the standard single-tier LSTM baseline.
  final List<List<int>> configurationsToTest = [
    [],        // Baseline: Standard LSTMLayer
    [7],       // 2-Tier: Aims to capture weekly patterns
    [30],      // 2-Tier: Aims to capture monthly patterns
    [7, 4],    // 3-Tier: Aims to capture weekly and monthly patterns (~7 days, ~28 days)
  ];

  // --- 2. Global Hyperparameters and Data ---
  final int sequenceLength = 60; // Longer sequence to capture multiple cycles
  final int hiddenSize = 8;
  final int epochsForBaseline = 30;
  final double learningRate = 0.03;
  final int numTrainSamples = 400;
  final int numTestSamples = 50;

  final List<Tensor<Matrix>> trainX = [];
  final List<Tensor<Vector>> trainY = [];
  final List<Tensor<Matrix>> testX = [];
  final List<Tensor<Vector>> testY = [];
  prepareComplexRnnData(inputs: trainX, targets: trainY, numSamples: numTrainSamples, sequenceLength: sequenceLength, startOffset: 0.0);
  // Start the test data past the training range (+50.0 gap) so the sets do not overlap.
  prepareComplexRnnData(inputs: testX, targets: testY, numSamples: numTestSamples, sequenceLength: sequenceLength, startOffset: numTrainSamples * 0.5 + 50.0);

  print('📊 Complex Data Prepared. Calibrating time budget...');
  print('---');

  // --- 3. Calibration Step ---
  // Train the baseline for a fixed epoch count and use its wall-clock time as
  // the shared budget, so every configuration gets equal training time.
  final SNetwork baselineModel = SNetwork([LSTMLayer(hiddenSize), DenseLayer(1)]);

  print('⏱️  Calibrating time budget with Standard LSTM for $epochsForBaseline epochs...');
  final Stopwatch calibrationStopwatch = Stopwatch()..start();
  runSingleTrial(
    model: baselineModel, trainX: trainX, trainY: trainY, testX: testX, testY: testY,
    maxEpochs: epochsForBaseline, learningRate: learningRate, timeLimit: null,
  );
  calibrationStopwatch.stop();
  final Duration timeBudget = calibrationStopwatch.elapsed;
  print('✅ Time budget set to: ${timeBudget.inMilliseconds}ms');
  print('---');

  // --- 4. Main Comparison Loop ---
  final Map<String, double> results = {};

  for (final List<int> clockCycles in configurationsToTest) {
    Layer coreLayer;
    String modelName;
    if (clockCycles.isEmpty) {
      modelName = 'Standard LSTM';
      coreLayer = LSTMLayer(hiddenSize);
    } else {
      modelName = 'MultiTier LSTM (Cycles: $clockCycles)';
      coreLayer = MultiTierLSTMLayer(hiddenSize, tierClockCycles: clockCycles);
    }
    print('🏋️ Training $modelName with a time budget of ${timeBudget.inMilliseconds}ms...');
    final SNetwork model = SNetwork([coreLayer, DenseLayer(1)]);
    // Generous epoch cap (3x the baseline); the time budget is the real limit.
    // (The original `.toInt()` on an int product was redundant.)
    final int maxEpochs = epochsForBaseline * 3;

    final double finalTestLoss = runSingleTrial(
      model: model, trainX: trainX, trainY: trainY, testX: testX, testY: testY,
      maxEpochs: maxEpochs, learningRate: learningRate, timeLimit: timeBudget,
    );
    print('✅ Finished training. Final Test MSE: ${finalTestLoss.toStringAsFixed(8)}');
    print('---');
    results[modelName] = finalTestLoss;
  }

  // --- 5. Final Results Summary ---
  print('\n\n--- 🏆 FINAL TIME-FAIR COMPARISON RESULTS (Complex Signal) ---');
  print('Total training budget for each model: ${timeBudget.inMilliseconds}ms');
  print('--------------------------------------------------');
  String bestModel = '';
  double lowestLoss = double.infinity;

  // Dart map literals preserve insertion order, so this prints the models in
  // the order they were trained while tracking the lowest test loss.
  for (final MapEntry<String, double> entry in results.entries) {
    print('${entry.key.padRight(35)} | Final Test MSE: ${entry.value.toStringAsFixed(8)}');
    if (entry.value < lowestLoss) {
      lowestLoss = entry.value;
      bestModel = entry.key;
    }
  }

  print('--------------------------------------------------');
  print('🥇 Best performing model (given equal time): $bestModel');
}