// `main` function
//
// Signature: `void main()`
//
// Implementation:

/// Runs a sequence of small autograd demos on [Tensor].
///
/// Each demo builds an expression, calls `backward()` on the result, and
/// prints the computed values and gradients next to hand-derived expected
/// values so the output can be checked by eye.
void main() {
  _demoBasicArithmetic();
  _demoMultiplication();
  _demoMatrixMultiplication();
  _demoPolynomial();
  _demoPower();
  _demoNegationAndDivision();
  _demoSigmoid();
  _demoReluNegative();
  _demoReluPositive();
  _demoComposite();
  _demoQuadraticLoss();
  _demoChainRule();
  _demoNegatedSum();
  _demoChainOfOperations();
  _demoComplexExpression();
  _demoReshape();
}

/// Example 1: gradients of `c = a + b` with respect to both inputs.
void _demoBasicArithmetic() {
  print('--- Example 1: Basic Arithmetic ---');
  final a = Tensor.fill([1, 1], 2.0);
  final b = Tensor.fill([1, 1], 3.0);
  final c = a + b;
  c.backward();
  print('c: ${c.data[0]}  // Expected: 5.0');
  print('a: ${a.grad[0]}  // Expected: 1.0');
  print('b: ${b.grad[0]}  // Expected: 1.0');
}

/// Example 2: gradients of `d = a * b`; each input's gradient is the other.
void _demoMultiplication() {
  print('\n--- Example 2: Multiplication ---');
  final a = Tensor.fill([1, 1], 2.0);
  final b = Tensor.fill([1, 1], 3.0);
  final d = a * b;
  d.backward();
  print('d: ${d.data[0]}  // Expected: 6.0');
  print('a: ${a.grad[0]}  // Expected: 3.0');
  print('b: ${b.grad[0]}  // Expected: 2.0');
}

/// Product of a 4x3 and a 3x4 tensor, with output and input gradients.
void _demoMatrixMultiplication() {
  print('\n--- Matrix Multiplication ---');
  // Matrix A (4x3)
  final matrixA = Tensor.fromList(
    [4, 3],
    [
      1, 2, 3, // row 0
      4, 5, 6, // row 1
      7, 8, 9, // row 2
      1, 1, 1, // row 3
    ],
  );

  // Matrix B (3x4)
  final matrixB = Tensor.fromList(
    [3, 4],
    [
      9, 8, 7, 6, // row 0
      5, 4, 3, 2, // row 1
      1, 2, 3, 4, // row 2
    ],
  );

  // NOTE(review): the expected values below are the matrix product A·B, so
  // this relies on `*` dispatching to matrix multiplication for these 2-D
  // operands — confirm against Tensor's operator*; `matmul` is the explicit
  // alternative.
  final matMuled = matrixA * matrixB;
  print('matMuled:\n${matMuled.printMatrix()}');
  print('''Expected:
 22.0  22.0  22.0  22.0
 67.0  64.0  61.0  58.0
 112.0 106.0 100.0  94.0
 15.0  14.0  13.0  12.0''');

  matMuled.backward();
  print('gradients: ${matMuled.printGradient()}');
  print('''Expected gradients:
 1.0 1.0 1.0 1.0
 1.0 1.0 1.0 1.0
 1.0 1.0 1.0 1.0
 1.0 1.0 1.0 1.0''');

  // With an all-ones upstream gradient G: dA = G·Bᵀ, so every row of dA
  // holds B's row sums; dB = Aᵀ·G, so row k of dB is filled with the sum of
  // A's column k.
  print('matrixA gradients: ${matrixA.printGradient()}');
  print('''Expected matrixA gradients:
 30.0 14.0 10.0
 30.0 14.0 10.0
 30.0 14.0 10.0
 30.0 14.0 10.0''');
  print('matrixB gradients: ${matrixB.printGradient()}');
  print('''Expected matrixB gradients:
 13.0 13.0 13.0 13.0
 16.0 16.0 16.0 16.0
 19.0 19.0 19.0 19.0''');
}

/// Example 3: `y = x^2 + 3x + 1`, whose derivative is `2x + 3`.
void _demoPolynomial() {
  print('\n--- Example 3: Polynomial y = x^2 + 3x + 1 ---');
  final x1 = Tensor.fill([1, 1], 2.0);
  final y1 = (x1 * x1) + (x1 * 3.0) + 1.0;
  y1.backward();
  print('y1: ${y1.data[0]}  // Expected: 11.0');
  print('x1: ${x1.grad[0]}  // Expected: 7.0');
}

/// Example 4: `y = x^3`, whose derivative is `3x^2`.
void _demoPower() {
  print('\n--- Example 4: Power y = x^3 ---');
  final x2 = Tensor.fill([1, 1], 2.0);
  final y2 = x2.pow(3);
  y2.backward();
  print('y2: ${y2.data[0]}  // Expected: 8.0');
  print('x2: ${x2.grad[0]}  // Expected: 12.0');
}

/// Example 5: `y = -a / b`; dy/da = -1/b and dy/db = a/b^2.
void _demoNegationAndDivision() {
  print('\n--- Example 5: Negative and Division y = -a / b ---');
  final a2 = Tensor.fill([1, 1], 4.0);
  final b2 = Tensor.fill([1, 1], 2.0);
  final y3 = (-a2) / b2;
  y3.backward();
  print('y3: ${y3.data[0]}  // Expected: -2.0');
  print('a2: ${a2.grad[0]}  // Expected: -0.5');
  print('b2: ${b2.grad[0]}  // Expected: 1.0');
}

/// Example 6: sigmoid at 1.0; its derivative is `s * (1 - s)`.
void _demoSigmoid() {
  print('\n--- Example 6: Sigmoid Activation ---');
  final x3 = Tensor.fill([1, 1], 1.0);
  final y4 = x3.sigmoid();
  y4.backward();
  print('y4: ${y4.data[0].toStringAsFixed(4)}  // Expected ≈ 0.7311');
  print('x3: ${x3.grad[0].toStringAsFixed(4)}  // Expected grad ≈ 0.1966');
}

/// Example 7: ReLU on a negative input yields zero output and zero gradient.
void _demoReluNegative() {
  print('\n--- Example 7: ReLU Activation (x < 0) ---');
  final x4 = Tensor.fill([1, 1], -2.0);
  final y5 = x4.relu();
  y5.backward();
  print('y5: ${y5.data[0]}  // Expected: 0.0');
  print('x4: ${x4.grad[0]}  // Expected: 0.0');
}

/// Example 8: ReLU on a positive input passes value and gradient through.
void _demoReluPositive() {
  print('\n--- Example 8: ReLU Activation (x > 0) ---');
  final x5 = Tensor.fill([1, 1], 3.0);
  final y6 = x5.relu();
  y6.backward();
  print('y6: ${y6.data[0]}  // Expected: 3.0');
  print('x5: ${x5.grad[0]}  // Expected: 1.0');
}

/// Example 9: `y = sigmoid(a * x + b) * c` with a=3, x=2, b=1, c=-1.
void _demoComposite() {
  print('\n--- Example 9: Composite Expression y = sigmoid(a * x + b) * c ---');
  final xc = Tensor.fill([1, 1], 2.0);
  final ac = Tensor.fill([1, 1], 3.0);
  final bc = Tensor.fill([1, 1], 1.0);
  final cc = Tensor.fill([1, 1], -1.0);
  final yc = ((ac * xc + bc).sigmoid()) * cc;
  yc.backward();
  // sigmoid(7) ≈ 0.99909, so yc ≈ -0.99909.
  print('yc: ${yc.data[0].toStringAsFixed(5)}  // Expected ≈ -0.99909');
  print('xc: ${xc.grad[0].toStringAsFixed(5)}  // Expected ≈ -0.00273');
  print('ac: ${ac.grad[0].toStringAsFixed(5)}  // Expected ≈ -0.00182');
  print('bc: ${bc.grad[0].toStringAsFixed(5)}  // Expected ≈ -0.00091');
  print('cc: ${cc.grad[0].toStringAsFixed(5)}  // Expected ≈ 0.99909');
}

/// Example 10: squared error `(yTrue - w * x)^2` and its input gradients.
void _demoQuadraticLoss() {
  print('\n--- Example 10: Quadratic Loss = (yTrue - yPred)^2 ---');
  final x6 = Tensor.fill([1, 1], 2.0);
  final w = Tensor.fill([1, 1], 3.0);
  final yPred = w * x6;
  final yTrue = Tensor.fill([1, 1], 10.0);
  final loss = (yTrue - yPred).pow(2);
  loss.backward();
  print('loss: ${loss.data[0]}  // Expected: 16.0');
  print('x6: ${x6.grad[0]}  // Expected: -24');
  print('w : ${w.grad[0]}  // Expected: -16');
}

/// Example 11: `(3x + 5)^2`, exercising the chain rule through three ops.
void _demoChainRule() {
  print('\n--- Example 11: Chain Rule ---');
  final x7 = Tensor.fill([1, 1], 2.0);
  final y7 = x7 * 3.0;
  final z7 = y7 + 5.0;
  final out7 = z7.pow(2);
  out7.backward();
  print('out7: ${out7.data[0]}  // Expected: 121.0');
  print('x7: ${x7.grad[0]}  // Expected: 66.0');
}

/// Example 12: `-(x + y)`; both inputs receive a gradient of -1.
void _demoNegatedSum() {
  print('\n--- Example 12: Simple Addition with Negation ---');
  final x8 = Tensor.fill([1, 1], 5.0);
  final y8 = Tensor.fill([1, 1], 3.0);
  final z8 = -(x8 + y8);
  z8.backward();
  print('z8: ${z8.data[0]}  // Expected: -8.0');
  print('x8: ${x8.grad[0]}  // Expected: -1.0');
  print('y8: ${y8.grad[0]}  // Expected: -1.0');
}

/// Example 13: `(x + 2) * (y - 1)`; each gradient is the other factor.
void _demoChainOfOperations() {
  print('\n--- Example 13: Chain of Operations (x + 2) * (y - 1) ---');
  final x9 = Tensor.fill([1, 1], 4.0);
  final y9 = Tensor.fill([1, 1], 6.0);
  final z9 = (x9 + 2.0) * (y9 - 1.0);
  z9.backward();
  print('z9: ${z9.data[0]}  // Expected: 30.0');
  print('x9: ${x9.grad[0]}  // Expected: 5.0');
  print('y9: ${y9.grad[0]}  // Expected: 6.0');
}

/// Example 14: `(2x + 3)^2 / 4`, including division by a plain number.
void _demoComplexExpression() {
  print('\n--- Example 14: More Complex Expression ---');
  final x10 = Tensor.fill([1, 1], 1.0);
  final y10 = x10 * 2.0;
  final z10 = (y10 + 3.0).pow(2);
  final out10 = z10 / 4;
  out10.backward();
  print('out10: ${out10.data[0]}  // Expected: 6.25');
  print('x10: ${x10.grad[0]}  // Expected: 5.0');
}

/// Example 15: gradients flow back through a reshape to the original tensor.
void _demoReshape() {
  print('\n--- Example 15: Reshape and Backprop ---');
  // Start with a 1x4 vector
  final x11 = Tensor.fill([1, 4], 2.0);
  // Reshape to 2x2
  final reshaped = x11.reshape([2, 2]);
  // Perform an operation on the reshaped version
  final out11 = reshaped * 3.0;

  out11.backward();

  print('reshaped shape: ${reshaped.shape}  // Expected: [2, 2]');
  print('out11 data[0]: ${out11.data[0]}  // Expected: 6.0');
  print('x11 grad[0]: ${x11.grad[0]}  // Expected: 3.0');
}