main function

void main()

Implementation

void main() {
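  // Note: this example uses Random and min from 'dart:math', together with the
  // library's Value, ValueVector, TransformerEncoder, and SGD classes.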
  print("--- Transformer Encoder Example ---");

  final vocabSize = 20;
  final embedSize = 32;
  final blockSize = 10; // assumed to be the maximum sequence length supported
  final numLayers = 2;
  final numHeads = 4;

  // Initialize the Transformer Encoder model
  final encoder = TransformerEncoder(
    vocabSize: vocabSize,
    embedSize: embedSize,
    blockSize: blockSize,
    numLayers: numLayers,
    numHeads: numHeads,
  );

  // In a real scenario, a downstream task (e.g., classification) would define
  // the loss function and optimizer, and its loss would be backpropagated
  // through the encoder. For demonstration, we first show a plain forward pass.

  // Sample input sequence of token IDs.
  // Suppose the input is the sentence "the cat sat on the mat"
  // and each word maps to an integer ID: [1, 5, 8, 2, 1, 9].
  final sampleInputSequence = [1, 5, 8, 2, 1, 9];
  print("Input sequence: $sampleInputSequence");

  // Perform a forward pass to get contextualized embeddings
  final encodedEmbeddings = encoder.forward(sampleInputSequence);
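  // forward() should return one contextualized ValueVector per input token,
  // each of length embedSize (checked by the asserts below).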

  print(
      "\nEncoded Embeddings (first token's data - only showing first 5 values):");
  if (encodedEmbeddings.isNotEmpty) {
    print(encodedEmbeddings[0]
        .values
        .sublist(0, min(5, encodedEmbeddings[0].values.length))
        .map((v) => v.data.toStringAsFixed(4))
        .toList());
    print(
        "Shape of encoded embeddings: (${encodedEmbeddings.length}, ${encodedEmbeddings[0].values.length})");
    assert(encodedEmbeddings.length == sampleInputSequence.length);
    assert(encodedEmbeddings[0].values.length == embedSize);
    print("Encoder output shape is correct.");
  } else {
    print("No encoded embeddings produced.");
  }

  // To demonstrate backpropagation (e.g., for a classification task),
  // assume we want the first token's embedding to be close to some target vector.
  if (encodedEmbeddings.isNotEmpty) {
    print("\n--- Dummy Training Step for Encoder ---");
    // Dummy target for the first token's embedding (e.g., for a classification head attached to it)
    final dummyTarget = ValueVector.fromDoubleList(
        List.generate(embedSize, (i) => Random().nextDouble() * 2 - 1));

    // Simple mean squared error loss for demonstration
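    // MSE = (1 / embedSize) * sum_i (embedding[i] - target[i])^2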
    Value dummyLoss = Value(0.0);
    for (int i = 0; i < embedSize; i++) {
      dummyLoss +=
          (encodedEmbeddings[0].values[i] - dummyTarget.values[i]).pow(2);
    }
    dummyLoss = dummyLoss / Value(embedSize.toDouble()); // Mean squared error

    print("Initial Dummy Loss: ${dummyLoss.data.toStringAsFixed(4)}");

    // Zero gradients, backward pass, and optimizer step
    encoder.zeroGrad();
    dummyLoss.backward();

    final optimizer = SGD(encoder.parameters(), 0.01);
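    // step() presumably applies one gradient-descent update to every encoder
    // parameter: p.data -= learningRate * p.grad.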
    optimizer.step();

    // Re-run forward pass to see the effect of the update
    final encodedEmbeddingsAfterUpdate = encoder.forward(sampleInputSequence);
    Value dummyLossAfterUpdate = Value(0.0);
    for (int i = 0; i < embedSize; i++) {
      dummyLossAfterUpdate +=
          (encodedEmbeddingsAfterUpdate[0].values[i] - dummyTarget.values[i])
              .pow(2);
    }
    dummyLossAfterUpdate = dummyLossAfterUpdate / Value(embedSize.toDouble());

    print(
        "Dummy Loss After 1 Step: ${dummyLossAfterUpdate.data.toStringAsFixed(4)}");
    print(
        "This shows the encoder's parameters are updated based on a downstream loss.");
  }
}
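
As the comments above note, a real encoder is trained against a downstream loss over many steps rather than a single dummy update. A minimal sketch of such a loop, reusing the same dummy MSE objective and only the classes and methods already used above (TransformerEncoder, ValueVector, Value, SGD), might look like this; the helper name, loop count, and inputs are placeholders.

// Hypothetical helper, not part of the library: repeats the dummy training
// step from main() for several iterations.
void dummyTrainingLoop(TransformerEncoder encoder, int embedSize) {
  final optimizer = SGD(encoder.parameters(), 0.01);
  final input = [1, 5, 8, 2, 1, 9];
  final target = ValueVector.fromDoubleList(
      List.generate(embedSize, (i) => Random().nextDouble() * 2 - 1));

  for (int step = 0; step < 10; step++) {
    final encoded = encoder.forward(input);

    // Same mean squared error on the first token's embedding as in main().
    Value loss = Value(0.0);
    for (int i = 0; i < embedSize; i++) {
      loss += (encoded[0].values[i] - target.values[i]).pow(2);
    }
    loss = loss / Value(embedSize.toDouble());

    encoder.zeroGrad();
    loss.backward();
    optimizer.step();

    print("Step $step, dummy loss: ${loss.data.toStringAsFixed(4)}");
  }
}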