BartDecoder constructor

BartDecoder({
  int decoderLayers = 4,
  int maxPositionEmbeddings = 1536,
  int vocabSize = 57522,
  int embedDim = 1024,
  int ffnDim = 4096,
  int numHeads = 16,
})

Implementation

/// Creates a BART decoder and instantiates its sub-modules.
///
/// All arguments are optional named parameters:
/// - [decoderLayers]: how many [BartDecoderLayer]s are built (default 4).
/// - [maxPositionEmbeddings]: base size of the position embedding table
///   (default 1536); two extra entries are added on top of it.
/// - [vocabSize]: size of the token embedding table and output width of
///   the LM head (default 57522).
/// - [embedDim]: width shared by the embeddings, the layer norm, the LM
///   head input, and every decoder layer (default 1024).
/// - [ffnDim]: forwarded to every decoder layer (default 4096).
/// - [numHeads]: forwarded to every decoder layer (default 16).
BartDecoder({
  this.decoderLayers = 4,
  this.maxPositionEmbeddings = 1536,
  this.vocabSize = 57522,
  this.embedDim = 1024,
  this.ffnDim = 4096,
  this.numHeads = 16,
}) {
  // Token lookup table: vocabSize entries of width embedDim.
  embedTokens = Embedding(vocabSize, embedDim);

  // BART applies an offset to its position embeddings, hence the two
  // additional entries beyond maxPositionEmbeddings.
  embedPositions = Embedding(maxPositionEmbeddings + 2, embedDim);

  layerNorm = LayerNorm(embedDim);

  // Bias-free linear map from embedDim to vocabSize.
  lmHead = Linear(embedDim, vocabSize, useBias: false);

  // embedScale is the square root of embedDim.
  // NOTE(review): presumably multiplied into the token embeddings in the
  // forward pass — confirm against the caller.
  final double width = embedDim.toDouble();
  embedScale = math.sqrt(width);

  // Every layer is configured identically, so the index is unused.
  BartDecoderLayer buildLayer(int _) => BartDecoderLayer(
        embedDim: embedDim,
        ffnDim: ffnDim,
        numHeads: numHeads,
      );
  layers = List.generate(decoderLayers, buildLayer);
}