// computeCrossEntropy method — implementation.
/// Builds the element-wise cross-entropy terms `-onehot(targets) * log(pred)`.
///
/// [pred] must be a rank-2 tensor of shape `[T, V]` with `V == vocabSize`.
/// Because the terms are formed from `pred.log()`, [pred] is expected to hold
/// positive values such as softmax probabilities. Raw logits would produce NaN
/// for non-positive entries. NOTE(review): confirm what callers pass.
///
/// [targets] holds one class index per row of [pred]. Indices are stored as
/// doubles and truncated with `toInt()`, and each must lie in `[0, vocabSize)`.
///
/// Every intermediate tensor created here is appended to [tracker] so the
/// caller can dispose of it manually. That covers the one-hot target matrix,
/// `log(pred)`, the product and the `-1` constant. The returned tensor is
/// not added.
///
/// Returns a `[T, V]` tensor whose on-target entries are
/// `-log(pred[t, targets[t]])`. No reduction (sum/mean) is applied.
/// NOTE(review): an entry of `pred` equal to 0 gives `log(0) = -inf`, and
/// `0 * -inf` is NaN, so even off-target entries can become NaN.
///
/// Throws [ArgumentError] in any of these cases:
/// - [pred] is not rank 2.
/// - Its second dimension differs from [vocabSize].
/// - `targets.length` differs from its first dimension.
/// - Any target index is out of range.
Tensor computeCrossEntropy(
  Tensor pred,
  List<double> targets,
  int vocabSize,
  List<Tensor> tracker,
) {
  // Shapes come from the predictions, which are laid out as [T, V].
  if (pred.shape.length != 2) {
    throw ArgumentError(
      "Expected pred of rank 2 [T, V], got shape ${pred.shape}",
    );
  }
  final seqLen = pred.shape[0];
  final vocab = pred.shape[1];
  if (vocab != vocabSize) {
    throw ArgumentError("vocabSize $vocabSize must match pred V: $vocab");
  }
  if (targets.length != seqLen) {
    throw ArgumentError(
      "Target length ${targets.length} must match Logits T: $seqLen",
    );
  }

  // Expand the class indices into a dense one-hot [T, V] matrix so it can be
  // multiplied element-wise with log(pred).
  final oneHot = List<double>.filled(seqLen * vocab, 0.0);
  for (var t = 0; t < seqLen; t++) {
    final target = targets[t].toInt();
    if (target < 0 || target >= vocab) {
      throw ArgumentError.value(
        targets[t],
        "targets[$t]",
        "must be a class index in [0, $vocab)",
      );
    }
    oneHot[t * vocab + target] = 1.0;
  }
  final targetsMatrix = Tensor.fromList([seqLen, vocab], oneHot);

  // 1. logPred = log(pred)
  final logPred = pred.log();
  // 2. product = target * log(pred)
  final product = targetsMatrix * logPred;
  // 3. Negate: loss = product * -1, using a constant tensor filled with -1.0.
  final negOne = Tensor.fill(product.shape, -1.0);
  final loss = product * negOne;

  // Hand every temporary to the caller for manual disposal later.
  tracker.addAll([logPred, product, negOne, targetsMatrix]);
  return loss;
}