causalMask static method

Tensor causalMask(int size)

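Creates a causal (lower-triangular) attention mask of shape [size, size]. The mask is additive: entries on or below the diagonal (j <= i) are 0.0, and entries above it are -1e9. Adding it to attention scores before softmax keeps position i from attending to later positions.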

Implementation

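static Tensor causalMask(int size) {
  // Flat row-major buffer for a [size, size] matrix.
  final data = Float32List(size * size);
  for (int i = 0; i < size; i++) {
    for (int j = 0; j < size; j++) {
      // Row i is the query position, column j the key position:
      // 0.0 keeps j <= i visible, -1e9 suppresses future positions.
      data[i * size + j] = j <= i ? 0.0 : -1e9;
    }
  }
  return Tensor(data, [size, size]);
}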
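
Example

The sketch below shows how an additive mask of this kind might be built and applied. It is standalone and does not use the Tensor class: buildCausalMask and applyMask are hypothetical helpers introduced only for illustration, and the score values are made up. It assumes the scores are stored as a flat row-major [size, size] buffer, matching the layout used in the implementation above.

```dart
import 'dart:typed_data';

// Hypothetical helper (not part of the library): same loop as causalMask,
// but returns the flat row-major buffer instead of a Tensor.
Float32List buildCausalMask(int size) {
  final data = Float32List(size * size);
  for (var i = 0; i < size; i++) {
    for (var j = 0; j < size; j++) {
      data[i * size + j] = j <= i ? 0.0 : -1e9;
    }
  }
  return data;
}

// Hypothetical helper: element-wise sum of raw attention scores and the mask.
// Both buffers are assumed to have the same length.
Float32List applyMask(Float32List scores, Float32List mask) {
  final out = Float32List(scores.length);
  for (var k = 0; k < scores.length; k++) {
    out[k] = scores[k] + mask[k];
  }
  return out;
}

void main() {
  const size = 3;
  final mask = buildCausalMask(size);

  // Made-up scores for illustration only.
  final scores = Float32List.fromList(
      [0.5, 0.2, 0.1, 0.3, 0.9, 0.4, 0.7, 0.6, 0.8]);
  final masked = applyMask(scores, mask);

  // Print one row per line; entries above the diagonal are hugely negative.
  for (var i = 0; i < size; i++) {
    print(masked.sublist(i * size, (i + 1) * size));
  }
}
```

After masking, row i keeps its original scores in columns 0 through i, while the remaining columns become large negative numbers that a subsequent softmax would map to approximately zero weight.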