TextDecoderBlock constructor

TextDecoderBlock(
  int embedSize,
  int numHeads,
  int blockSize
)

Implementation

/// Creates one decoder block from three kinds of sub-layers:
/// - two [AFTAttention] modules: a self-attention built with `masked: true`
///   and a cross-attention built without a `masked` argument;
/// - a two-layer feed-forward network (`ff1`, `ff2`);
/// - three [LayerNorm]s (`norm1`, `norm2`, `norm3`).
///
/// Parameters:
/// - [embedSize]: feature width passed to every sub-layer. The feed-forward
///   hidden width is `embedSize * 4`.
/// - [numHeads]: head count, passed unchanged to both [AFTAttention] modules.
/// - [blockSize]: passed unchanged to both [AFTAttention] modules. It is
///   presumably the maximum sequence length (TODO: confirm in AFTAttention).
///
/// NOTE(review): Dart runs initializer-list entries in the order written. If
/// these sub-layer constructors draw random initial weights from a shared
/// generator, reordering the entries below changes the initial parameters.
TextDecoderBlock(int embedSize, int numHeads, int blockSize)
  : selfAttention = AFTAttention(
      embedSize,
      numHeads,
      blockSize,
      masked: true,
    ), // masked: true -- the author labels this causal (positions presumably see only earlier ones)
    // Presumably paired with selfAttention -- confirm the pre/post-norm order in the forward pass.
    norm1 = LayerNorm(embedSize),
    crossAttention = AFTAttention(
      embedSize,
      numHeads,
      blockSize,
    ), // no `masked` argument, so AFTAttention's default applies (presumably unmasked / "standard" -- confirm)
    // Presumably paired with crossAttention -- confirm in the forward pass.
    norm2 = LayerNorm(embedSize),
    // Feed-forward: widen embedSize -> embedSize * 4 with GELU...
    ff1 = Layer(embedSize, embedSize * 4, useGelu: true),
    // ...then project back to embedSize, with GELU disabled on this layer.
    ff2 = Layer(embedSize * 4, embedSize, useGelu: false),
    // Presumably paired with the feed-forward sub-layer -- confirm in the forward pass.
    norm3 = LayerNorm(embedSize);