computeCrossEntropy method

Tensor computeCrossEntropy(
  Tensor pred,
  List<double> targets,
  int vocabSize,
  List<Tensor> tracker,
)

Implementation

/// Builds the (unreduced) cross-entropy loss tensor for [pred] against
/// integer class [targets].
///
/// [targets] holds one class index per row of [pred], stored as doubles, so
/// `targets.length` must equal `pred.shape[0]` and every entry must be a
/// whole number in `[0, vocabSize)`. The indices are expanded into a
/// row-major one-hot matrix of shape `[targets.length, vocabSize]`, and the
/// returned tensor is `-(oneHot * log(pred))`, built from the [Tensor]
/// operations `log` and `*` (presumably elementwise — confirm). No sum or
/// mean is applied here.
///
/// NOTE(review): `pred.log()` is applied to [pred] directly, so [pred] must
/// already hold probabilities (e.g. a softmax output). Raw logits would make
/// `log` produce NaN for negative entries. Confirm against callers.
///
/// Every intermediate tensor created here (including the one-hot matrix) is
/// appended to [tracker] so the caller can dispose of it later. The returned
/// loss tensor itself is not added.
///
/// Throws [ArgumentError] if `pred.shape[0]` differs from `targets.length`,
/// if `pred.shape[1]` differs from [vocabSize], or if a target is not a
/// finite whole number. Throws [RangeError] if a target index is outside
/// `[0, vocabSize)`.
Tensor computeCrossEntropy(
  Tensor pred,
  List<double> targets,
  int vocabSize,
  List<Tensor> tracker,
) {
  final numTokens = targets.length;

  // pred must be [T, V] with T == targets.length and V == vocabSize.
  if (pred.shape[0] != numTokens) {
    throw ArgumentError(
      "Target length ${targets.length} must match Logits T: ${pred.shape[0]}",
    );
  }
  if (pred.shape[1] != vocabSize) {
    throw ArgumentError(
      "Prediction width ${pred.shape[1]} must match vocabSize: $vocabSize",
    );
  }

  // Expand the class indices into a row-major one-hot [T, V] matrix so it can
  // be combined with log(pred) using the tensor operators.
  final oneHot = List<double>.filled(numTokens * vocabSize, 0.0);
  for (var t = 0; t < numTokens; t++) {
    final value = targets[t];
    if (!value.isFinite || value != value.truncateToDouble()) {
      throw ArgumentError.value(
        value,
        "targets[$t]",
        "must be a whole-number class index",
      );
    }
    final index = value.toInt();
    RangeError.checkValueInInterval(index, 0, vocabSize - 1, "targets[$t]");
    oneHot[t * vocabSize + index] = 1.0;
  }
  final targetsMatrix = Tensor.fromList([numTokens, vocabSize], oneHot);

  // 1. logPred = log(pred)
  final logPred = pred.log();
  // 2. product = target * log(pred)
  final product = targetsMatrix * logPred;

  // 3. Negate via a constant -1 tensor: loss = product * -1.
  final negOne = Tensor.fill(product.shape, -1.0);
  final loss = product * negOne;

  // Track all temporaries for manual disposal later.
  tracker.addAll([logPred, product, negOne, targetsMatrix]);

  return loss;
}