forward method

@override
Tensor<Tensor3D> forward(Tensor input)

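The core logic of the layer's transformation.

Subclasses must implement this method to define how they process input tensors and return an output tensor. This override performs an embedding lookup. The input is expected to be a Tensor<Matrix> of shape [batch, sequenceLength] holding token indices, stored as doubles and truncated to integers with toInt(); the cast throws if the input has any other type. The result is a Tensor<Tensor3D> of shape [batch, sequenceLength, embeddingDimension] in which each index is replaced by the corresponding row of embeddings. The returned tensor also records a backward function that adds its gradient into embeddings.grad.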
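As a minimal, self-contained sketch of the lookup described above, the function below maps a [batch][sequenceLength] matrix of indices to a [batch][sequenceLength][embeddingDimension] nested list. It assumes that Vector, Matrix and Tensor3D are plain aliases for nested List<double> types, and embeddingLookup is a hypothetical name that is not part of the library. Unlike the library implementation, this sketch copies each row and range-checks each index.

```dart
// Assumption: these aliases mirror the library's tensor types as nested lists;
// they are declared here only so the sketch is self-contained.
typedef Vector = List<double>;
typedef Matrix = List<Vector>;
typedef Tensor3D = List<Matrix>;

/// Hypothetical helper: replaces every token index with its embedding row.
/// [batchIndices] has shape [batch][seqLen]; indices are stored as doubles.
/// [table] has shape [vocabSize][embeddingDimension].
/// Returns a list of shape [batch][seqLen][embeddingDimension].
Tensor3D embeddingLookup(Matrix batchIndices, Matrix table) {
  final Tensor3D out = [];
  for (final Vector wordIndices in batchIndices) {
    final Matrix sequence = [];
    for (final double indexDouble in wordIndices) {
      // toInt() truncates toward zero, e.g. 2.0 -> 2.
      final int index = indexDouble.toInt();
      if (index < 0 || index >= table.length) {
        throw RangeError.range(index, 0, table.length - 1, 'index');
      }
      // Copy the row so that later writes to the output cannot modify the table.
      sequence.add(List<double>.from(table[index]));
    }
    out.add(sequence);
  }
  return out;
}
```

For a table with rows [0.1, 0.2], [1.0, 1.5] and [2.0, 2.5], the batch [[0.0, 2.0], [1.0, 1.0]] yields [[[0.1, 0.2], [2.0, 2.5]], [[1.0, 1.5], [1.0, 1.5]]]. Copying costs one allocation per looked-up row; the library code skips it and shares rows by reference instead.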

Implementation

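@override
Tensor<Tensor3D> forward(Tensor<dynamic> input) {
  // The input is expected to be a Tensor<Matrix> of shape
  // [batch, sequenceLength] whose entries are token indices stored as doubles.
  Matrix batchIndices = (input as Tensor<Matrix>).value;
  Tensor3D outputBatch = [];

  // Forward pass: replace every index with its row of the embedding table.
  for (Vector wordIndices in batchIndices) {
    Matrix outputSequence = [];
    for (double indexDouble in wordIndices) {
      int index = indexDouble.toInt();
      // The row is shared by reference with embeddings.value (no copy is made).
      outputSequence.add(embeddings.value[index]);
    }
    outputBatch.add(outputSequence);
  }

  // Output shape: [batch, sequenceLength, embeddingDimension].
  Tensor<Tensor3D> out = Tensor<Tensor3D>(outputBatch);

  // Backward pass: scatter-add the upstream gradient into the rows of
  // embeddings.grad that were looked up. Because it uses +=, a token that
  // occurs several times accumulates the gradient from every occurrence.
  out.creator = Node([embeddings], () {
    for (int b = 0; b < batchIndices.length; b++) {
      for (int i = 0; i < batchIndices[b].length; i++) {
        int index = batchIndices[b][i].toInt();
        for (int j = 0; j < embeddingDimension; j++) {
          embeddings.grad[index][j] += out.grad[b][i][j];
        }
      }
    }
  }, opName: 'embedding_lookup_batch', cost: 0);

  return out;
}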
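The backward closure above is a scatter-add: each position's upstream gradient is added to the row of embeddings.grad selected by that position's index. The sketch below isolates that step on plain nested lists so the accumulation can be checked by hand. The function embeddingLookupBackward and the typedefs are hypothetical stand-ins, not library API, and the sketch iterates over each gradient row's length where the library uses embeddingDimension.

```dart
// Assumption: nested-list aliases standing in for the library's tensor types.
typedef Vector = List<double>;
typedef Matrix = List<Vector>;
typedef Tensor3D = List<Matrix>;

/// Hypothetical helper: accumulates upstream gradients into [tableGrad].
/// [batchIndices]: [batch][seqLen] token indices stored as doubles.
/// [outGrad]: [batch][seqLen][embeddingDimension], gradient w.r.t. the lookup output.
/// [tableGrad]: [vocabSize][embeddingDimension], updated in place.
void embeddingLookupBackward(
    Matrix batchIndices, Tensor3D outGrad, Matrix tableGrad) {
  for (int b = 0; b < batchIndices.length; b++) {
    for (int i = 0; i < batchIndices[b].length; i++) {
      final int index = batchIndices[b][i].toInt();
      final Vector row = tableGrad[index];
      for (int j = 0; j < row.length; j++) {
        // += rather than =: repeated tokens sum the gradients of all
        // positions where they occur.
        row[j] += outGrad[b][i][j];
      }
    }
  }
}
```

For example, if index 1 appears at three positions whose upstream gradients are [1.0, 2.0], [3.0, 4.0] and [0.5, 0.5], row 1 of the gradient ends up as [4.5, 6.5], while rows that are never looked up stay at zero.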