TransformerDecoderBlock constructor

TransformerDecoderBlock(
  int embedSize,
  int numHeads,
  int encoderEmbedSize,
  int maxSeqLen,
)

Implementation

/// Creates a decoder block and runs `_initializeWeights()` afterwards.
///
/// The block is made of:
/// - a masked [MultiHeadAFT] (`selfAttention`),
/// - a [MultiHeadAFTCross] (`crossAttention`),
/// - a [FeedForward] (`ffn`),
/// - three [LayerNorm]s (`ln1`, `ln2`, `ln3`).
///
/// Parameters:
/// - [embedSize]: stored on the instance (`this.embedSize`) and passed to
///   every sub-layer constructor above.
/// - [numHeads]: passed as the first argument to both [MultiHeadAFT] and
///   [MultiHeadAFTCross].
/// - [encoderEmbedSize]: passed only to [MultiHeadAFTCross]. Presumably this
///   is the feature width of the encoder output being attended to; confirm
///   against `MultiHeadAFTCross`.
/// - [maxSeqLen]: passed to [MultiHeadAFT], and passed twice to
///   [MultiHeadAFTCross].
///
/// NOTE(review): Dart evaluates initializer-list entries in the order they
/// are written. If any sub-layer constructor draws random initial weights
/// from a shared generator, reordering these entries would change the
/// resulting initialization.
TransformerDecoderBlock(
  this.embedSize,
  int numHeads,
  int encoderEmbedSize,
  int maxSeqLen,
) : selfAttention = MultiHeadAFT(
      numHeads,
      embedSize,
      maxSeqLen,
      // Presumably enables causal masking, so each position sees only itself
      // and earlier positions, as a decoder usually requires. Confirm in
      // MultiHeadAFT.
      masked: true,
    ),
    crossAttention = MultiHeadAFTCross(
      numHeads,
      embedSize,
      encoderEmbedSize,
      // maxSeqLen is passed for both of the next two arguments. These are
      // presumably the decoder (query) and encoder (key/value) maximum
      // lengths, which would cap encoder inputs at maxSeqLen as well.
      // NOTE(review): confirm the parameter meanings in MultiHeadAFTCross.
      maxSeqLen,
      maxSeqLen,
    ),
    ffn = FeedForward(embedSize),
    ln1 = LayerNorm(embedSize),
    ln2 = LayerNorm(embedSize),
    ln3 = LayerNorm(embedSize) {
  // Runs after all initializer-list fields above have been assigned.
  _initializeWeights();
}