main function

void main()

Implementation

void main() async {
  // Upper bound on how many leading parameters are echoed at each checkpoint.
  const previewLimit = 5;
  // Location the parameters are written to and later read back from.
  const weightsPath = 'transformer_weights.json';

  // Both models in this demo share one configuration, so construction lives
  // in a single place. Each call yields a separate instance.
  EncoderDecoderTransformer buildModel() {
    return EncoderDecoderTransformer(
      sourceVocabSize: 100,
      targetVocabSize: 100,
      embedSize: 32,
      sourceBlockSize: 10,
      targetBlockSize: 10,
      numLayers: 2,
      numHeads: 4,
    );
  }

  // Step 1: create the model whose parameters will be saved.
  final transformer = buildModel();

  print('Initial parameters before saving:');
  final initialParams = transformer.parameters();
  // Show only the leading entries (fewer if the model has fewer than the cap).
  final initialShown =
      initialParams.length < previewLimit ? initialParams.length : previewLimit;
  for (var i = 0; i < initialShown; i++) {
    print('Param $i: ${initialParams[i].data}');
  }
  print('Total initial parameters: ${initialParams.length}');

  // Step 2: hand the model and the target path to saveModuleParameters.
  await saveModuleParameters(transformer, weightsPath);

  // Step 3: stand up a second, freshly constructed model with the same
  // configuration to play the role of the instance being restored into.
  final newTransformer = buildModel();

  print('\nParameters of new transformer before loading:');
  final newParamsBeforeLoad = newTransformer.parameters();
  final beforeShown = newParamsBeforeLoad.length < previewLimit
      ? newParamsBeforeLoad.length
      : previewLimit;
  for (var i = 0; i < beforeShown; i++) {
    print('Param $i: ${newParamsBeforeLoad[i].data}');
  }

  // Step 4: read the saved file back into the second model.
  await loadModuleParameters(newTransformer, weightsPath);

  print('\nParameters of new transformer after loading:');
  final newParamsAfterLoad = newTransformer.parameters();
  final afterShown = newParamsAfterLoad.length < previewLimit
      ? newParamsAfterLoad.length
      : previewLimit;
  for (var i = 0; i < afterShown; i++) {
    print('Param $i: ${newParamsAfterLoad[i].data}');
  }

  // Step 5: compare the parameters captured before saving with those of the
  // second model after loading. A length mismatch fails immediately;
  // otherwise the scan stops at the first pair whose `data` values are not
  // equal under `==`.
  var weightsMatch = initialParams.length == newParamsAfterLoad.length;
  for (var i = 0; weightsMatch && i < initialParams.length; i++) {
    weightsMatch = initialParams[i].data == newParamsAfterLoad[i].data;
  }
  print('\nDo initial and loaded weights match? $weightsMatch');

  // Disabled sample: running a forward pass on the restored model. Left
  // commented out because it needs working implementations of the
  // transformer's sub-modules.
  // try {
  //   List<int> sourceInput = List.generate(10, (index) => index % 100);
  //   List<int> targetInput = List.generate(8, (index) => index % 100);
  //   final logits = newTransformer.forward(sourceInput, targetInput);
  //   print('Forward pass successful. Logits length: ${logits.length}');
  // } catch (e) {
  //   print('Error during forward pass: $e');
  // }
}