main function
void
main()
Implementation
/// Trains and evaluates a small Transformer classifier that labels sentences as
/// Hemingway (0.0) or Shakespeare (1.0).
///
/// Steps: load both corpora, build a shared vocabulary, split each corpus into
/// labelled sentences, shuffle and split 80/20 into train/test, train an
/// `SNetwork`, evaluate it on the held-out split, then classify a few
/// hand-written sentences and print the computation graph of the last one.
void main() {
  // --- 1. DATA LOADING ---
  // NOTE(review): the second path is spelled 'shakespear.txt' (no trailing 'e');
  // it must match the file on disk, so it is left unchanged — verify.
  final String hemingwayText = File('lib/examples/hemingway.txt').readAsStringSync();
  final String shakespeareText = File('lib/examples/shakespear.txt').readAsStringSync();
  Logger.cyan('Step 1: Loaded text files.', prefix: '✅');

  // --- 2. PREPROCESSING ---
  // The vocabulary is built from both corpora combined, i.e. before the
  // train/test split, so test-split words are also in the vocabulary.
  final Map<String, int> vocabulary = _buildStyleVocabulary('$hemingwayText $shakespeareText');
  Logger.cyan('Step 2: Built vocabulary with ${vocabulary.length} unique words.', prefix: '✅');

  const int maxSequenceLength = 25;
  final List<Map<String, dynamic>> dataset = [
    ..._labelledSentences(hemingwayText, 0.0), // 0.0 for Hemingway
    ..._labelledSentences(shakespeareText, 1.0), // 1.0 for Shakespeare
  ];
  Logger.cyan('Step 3: Created dataset with ${dataset.length} sentences.', prefix: '✅');

  // --- 3. TRAIN/TEST SPLIT ---
  // NOTE(review): shuffle() is unseeded, so the split (and results) differ between runs.
  dataset.shuffle();
  final int splitIndex = (dataset.length * 0.8).floor();
  final List<Map<String, dynamic>> trainData = dataset.sublist(0, splitIndex);
  final List<Map<String, dynamic>> testData = dataset.sublist(splitIndex);

  final List<Vector> trainInputs = [
    for (final item in trainData)
      preprocessSentence(item['text'] as String, vocabulary, maxSequenceLength),
  ];
  final List<Vector> trainTargets = [
    for (final item in trainData) [item['label'] as double],
  ];
  final List<Vector> testInputs = [
    for (final item in testData)
      preprocessSentence(item['text'] as String, vocabulary, maxSequenceLength),
  ];
  final List<Vector> testTargets = [
    for (final item in testData) [item['label'] as double],
  ];
  Logger.cyan('Step 4: Split data into ${trainInputs.length} training and ${testInputs.length} test samples.', prefix: '✅');

  // --- 4. MODEL DEFINITION & TRAINING ---
  final int vocabSize = vocabulary.length;
  const int dModel = 16;
  const int numHeads = 4;
  const int dff = 32;
  // How many TransformerEncoderBlocks are stacked in the network below.
  const int numEncoderBlocks = 2;
  final SNetwork styleClassifier = SNetwork([
    EmbeddingLayer(vocabSize, dModel),
    PositionalEncoding(maxSequenceLength, dModel),
    for (int i = 0; i < numEncoderBlocks; i++) TransformerEncoderBlock(dModel, numHeads, dff),
    GlobalAveragePooling1D(),
    DenseLayer(1, activation: Sigmoid()),
  ]);

  if (trainInputs.isEmpty) {
    throw StateError('No training samples: the corpora produced ${dataset.length} sentences.');
  }
  // One forward pass before compile(). Presumably this makes layers create
  // their weights so that `styleClassifier.parameters` is populated when it is
  // handed to Adam — TODO confirm against SNetwork.
  styleClassifier.predict(Tensor<Vector>(trainInputs[0]));
  styleClassifier.compile(
    configuredOptimizer: Adam(styleClassifier.parameters, learningRate: 0.001),
  );
  styleClassifier.fit(trainInputs, trainTargets, epochs: 10, debug: true);

  // --- 5. EVALUATION ---
  Logger.green('\n--- FINAL EVALUATION ON UNSEEN TEST DATA ---', prefix: '📊');
  styleClassifier.evaluate(testInputs, testTargets);

  // --- 6. TESTING ON NEW SENTENCES ---
  Logger.blue('\n--- TESTING ON NEW SENTENCES ---', prefix: '🧪');
  // Each sentence paired with its expected author; map literals keep
  // insertion order, so sentences are tested in the order listed.
  const Map<String, String> labelledTestSentences = {
    "the old man and the sea": "Hemingway",
    "wherefore art thou": "Shakespeare",
    "a farewell to arms": "Hemingway",
    "o happy dagger this is thy sheath": "Shakespeare",
    "death which cannot choose": "Hemingway",
  };

  Tensor<Vector>? lastPrediction;
  // A plain for-in (not forEach with a closure) so that lastPrediction can
  // still be type-promoted by the null check after the loop.
  for (final MapEntry<String, String> entry in labelledTestSentences.entries) {
    final String sentence = entry.key;
    final String correctLabel = entry.value;
    final Vector tokenized = preprocessSentence(sentence, vocabulary, maxSequenceLength);
    final Tensor<Vector> prediction =
        styleClassifier.predict(Tensor<Vector>(tokenized)) as Tensor<Vector>;
    lastPrediction = prediction; // Save for graph printing
    final score = prediction.value[0];
    // The model outputs > 0.5 for Shakespeare (label 1.0)
    final String predictedLabel = score > 0.5 ? "Shakespeare" : "Hemingway";
    final String resultEmoji = predictedLabel == correctLabel ? '✅ Correct' : '❌ Incorrect';
    print('Input: "$sentence"');
    print(' -> Prediction: $predictedLabel (Raw: ${score.toStringAsFixed(4)})');
    print(' -> Correct: $correctLabel');
    print(' -> Result: $resultEmoji\n');
  }

  // Print graph for the last sentence processed
  if (lastPrediction != null) {
    print('--- Computation Graph for last prediction ---');
    lastPrediction.printGraph();
    lastPrediction.printParallelGraph();
  }
}

/// Builds a word -> index vocabulary from [text].
///
/// Index 0 is reserved for '<pad>'. The text is lower-cased and every run of
/// characters that are neither `\w` nor whitespace is removed; the result is
/// split on whitespace and each distinct non-empty word gets the next index
/// (1, 2, ...) in order of first appearance.
Map<String, int> _buildStyleVocabulary(String text) {
  final String cleaned = text.toLowerCase().replaceAll(RegExp(r'[^\w\s]+'), '');
  final Map<String, int> vocabulary = {'<pad>': 0};
  for (final String word in cleaned.split(RegExp(r'\s+'))) {
    if (word.isNotEmpty) {
      // ifAbsent runs before insertion, so the new index is the current size.
      vocabulary.putIfAbsent(word, () => vocabulary.length);
    }
  }
  return vocabulary;
}

/// Splits [text] on '.', '!' and '?' and returns one `{'text': ..., 'label': [label]}`
/// entry per fragment whose trimmed length is greater than 5.
///
/// The stored 'text' is the untrimmed fragment exactly as produced by the split.
List<Map<String, dynamic>> _labelledSentences(String text, double label) => [
      for (final String sentence in text.split(RegExp(r'[\.!?]')))
        if (sentence.trim().length > 5) <String, dynamic>{'text': sentence, 'label': label},
    ];