decode method

void decode()

Implementation
/// Walks the serialized execution tape from the beginning and logs one
/// human-readable line per instruction.
///
/// Each instruction is an opcode (read with `_readInt`) followed by its
/// operands, read in a fixed per-opcode order with `_readString`, `_readInt`,
/// `_readFloat` and `_readBool`. The read order in every `case` must match the
/// order the tape was written in: skipping or adding a single read
/// desynchronizes every instruction after it. For that reason every operand is
/// read, and every operand read is also shown in the log line.
///
/// Resets `_offset` to 0 before decoding. The loop relies on the `_read*`
/// helpers advancing `_offset`; decoding ends when `_offset` reaches
/// `tape.length`, or at the first unrecognized opcode, which is reported
/// together with the byte offset where that instruction started.
void decode() {
  Logger.yellow('--- Decoding Execution Tape (${tape.length} bytes) ---', prefix: '📜');
  _offset = 0;
  while (_offset < tape.length) {
    // Start of this instruction, kept so an unknown opcode can be located.
    final int opStart = _offset;
    final int opCode = _readInt();
    final String opName = _getOpName(opCode);
    switch (opCode) {
      // --- Data Loading & Storage ---
      case OP_LOAD_SAMPLE:
        final String nameIn = _readString();
        final String nameOut = _readString();
        final int sampleIdx = _readInt();
        Logger.cyan('$opName: $nameOut = load_sample($nameIn, index: $sampleIdx)');
        break;
      case OP_STORE_SAMPLE:
        final String nameIn = _readString();
        final String nameDest = _readString();
        final int sampleIdx = _readInt();
        Logger.cyan('$opName: store_sample($nameIn -> $nameDest, index: $sampleIdx)');
        break;
      case OP_COPY:
        final String nameIn = _readString();
        final String nameOut = _readString();
        Logger.cyan('$opName: $nameOut = copy($nameIn)');
        break;
      case OP_FILL:
        final String nameOut = _readString();
        final double value = _readFloat();
        // Same color as the other data ops (was uncolored Logger.log).
        Logger.cyan('$opName: fill($nameOut, $value)');
        break;
      case OP_ZERO_GRAD:
        final String nameZero = _readString();
        Logger.yellow('$opName: zero_grad($nameZero)');
        break;
      // --- Basic Arithmetic ---
      case OP_ADD:
      case OP_SUBTRACT:
      case OP_MULTIPLY:
      case OP_DIVIDE:
      case OP_BROADCAST_ADD:
        final String nameA = _readString();
        final String nameB = _readString();
        final String nameC = _readString();
        Logger.cyan('$opName: $nameC = $nameA op $nameB');
        break;
      case OP_ADD_INTO:
      case OP_SUBTRACT_INTO:
        final String nameSrc = _readString();
        final String nameDest = _readString();
        Logger.cyan('$opName: $nameDest += $nameSrc');
        break;
      case OP_MULTIPLY_BACKWARD:
        final String gradOut = _readString();
        final String otherIn = _readString();
        final String gradIn = _readString();
        Logger.blue('$opName: $gradIn += $gradOut * $otherIn');
        break;
      case OP_DIVIDE_BACKWARD:
        final String nameA = _readString();
        final String nameB = _readString();
        final String gradOut = _readString();
        final String gradA = _readString();
        final String gradB = _readString();
        Logger.blue('$opName: gradA=$gradA, gradB=$gradB from div($nameA, $nameB, $gradOut)');
        break;
      case OP_ADD_SCALAR:
        final String nameInAddS = _readString();
        final String nameOutAddS = _readString();
        final double scalarAdd = _readFloat();
        Logger.cyan('$opName: $nameOutAddS = $nameInAddS + $scalarAdd');
        break;
      // --- New Element-wise Math ---
      case OP_LOG_ELEMENTWISE:
      case OP_ABS_ELEMENTWISE:
      case OP_SQRT_ELEMENTWISE:
      case OP_EXP_ELEMENTWISE:
        final String nameInMath = _readString();
        final String nameOutMath = _readString();
        Logger.cyan('$opName: $nameOutMath = op($nameInMath)');
        break;
      case OP_LOG_BACKWARD:
      case OP_ABS_BACKWARD:
        final String gradOutMath = _readString();
        final String inDataMath = _readString();
        final String gradInMath = _readString();
        Logger.blue('$opName: $gradInMath = bw($gradOutMath, inData: $inDataMath)');
        break;
      case OP_EXP_BACKWARD:
      case OP_SQRT_BACKWARD:
        final String gradOutSE = _readString();
        final String outDataSE = _readString();
        final String gradInSE = _readString();
        Logger.blue('$opName: $gradInSE = bw($gradOutSE, outData: $outDataSE)');
        break;
      case OP_POW_ELEMENTWISE:
        final String nameInPow = _readString();
        final String nameOutPow = _readString();
        final double exponent = _readFloat();
        Logger.cyan('$opName: $nameOutPow = pow($nameInPow, exp: $exponent)');
        break;
      case OP_POW_BACKWARD:
        final String gradOutPow = _readString();
        final String inDataPow = _readString();
        final String gradInPow = _readString();
        final double expBw = _readFloat();
        Logger.blue('$opName: $gradInPow = pow_bw($gradOutPow, inData: $inDataPow, exp: $expBw)');
        break;
      case OP_CLAMP_ELEMENTWISE:
        final String nameInClamp = _readString();
        final String nameOutClamp = _readString();
        final double minVal = _readFloat();
        final double maxVal = _readFloat();
        Logger.cyan('$opName: $nameOutClamp = clamp($nameInClamp, min: $minVal, max: $maxVal)');
        break;
      case OP_CLAMP_BACKWARD:
        final String gradOutClamp = _readString();
        final String inDataClamp = _readString();
        final String gradInClamp = _readString();
        final double minValBw = _readFloat();
        final double maxValBw = _readFloat();
        Logger.blue('$opName: $gradInClamp = clamp_bw($gradOutClamp, in: $inDataClamp, min: $minValBw, max: $maxValBw)');
        break;
      // --- Matrix Operations ---
      case OP_MATMUL:
        final String nameaMm = _readString();
        final String namebMm = _readString();
        final String namecMm = _readString();
        final bool transA = _readBool();
        final bool transB = _readBool();
        final double alpha = _readFloat();
        final double beta = _readFloat();
        final bool tCores = _readBool();
        Logger.cyan('$opName: $namecMm = matmul($nameaMm, $namebMm) transA:$transA transB:$transB alpha:$alpha beta:$beta tensorCores:$tCores');
        break;
      case OP_TRANSPOSE:
        final String nameInTr = _readString();
        final String nameOutTr = _readString();
        Logger.cyan('$opName: $nameOutTr = transpose($nameInTr)');
        break;
      case OP_SCALE_MATRIX:
        final String nameInSM = _readString();
        final String nameOutSM = _readString();
        final double scalarSM = _readFloat();
        Logger.cyan('$opName: $nameOutSM = scale($nameInSM, val: $scalarSM)');
        break;
      case OP_SCALE_MATRIX_BACKWARD:
        final String gradOutSMB = _readString();
        final String gradInSMB = _readString();
        final double scaleSMB = _readFloat();
        Logger.blue('$opName: $gradInSMB += $gradOutSMB * $scaleSMB');
        break;
      // --- Activations ---
      case OP_RELU:
      case OP_SIGMOID:
      case OP_TANH:
      case OP_GELU_FORWARD:
      case OP_SOFTMAX_FORWARD:
        final String nameInAct = _readString();
        final String nameOutAct = _readString();
        Logger.green('$opName: $nameOutAct = act($nameInAct)');
        break;
      case OP_RELU_BACKWARD:
      case OP_GELU_BACKWARD:
        final String nameInBack = _readString();
        final String gradOutBack = _readString();
        final String gradInBack = _readString();
        Logger.blue('$opName: $gradInBack = act_bw($nameInBack, $gradOutBack)');
        break;
      case OP_SIGMOID_BACKWARD:
      case OP_TANH_BACKWARD:
        final String outData = _readString(); // C++ expects outData first here
        final String gradOutAct = _readString();
        final String gradInAct = _readString();
        Logger.blue('$opName: $gradInAct = act_bw($gradOutAct, outData: $outData)');
        break;
      case OP_SOFTMAX_BACKWARD:
        final String gradOutSm = _readString(); // C++ expects gradOut first here
        final String outDataSm = _readString();
        final String gradInSm = _readString();
        Logger.blue('$opName: $gradInSm = softmax_bw($gradOutSm, outData: $outDataSm)');
        break;
      // --- Loss Functions ---
      case OP_MSE_LOSS_FORWARD:
      case OP_BCE_LOSS_FORWARD:
        final String namePred = _readString();
        final String nameTarget = _readString();
        final String nameOutLoss = _readString();
        Logger.green('$opName: $nameOutLoss = loss($namePred, target: $nameTarget)');
        break;
      case OP_MSE_LOSS_BACKWARD:
      case OP_BCE_LOSS_BACKWARD:
        final String gradOutLoss = _readString();
        final String namePredLoss = _readString();
        final String nameTargetLoss = _readString();
        final String gradInLoss = _readString();
        Logger.blue('$opName: $gradInLoss = loss_bw($gradOutLoss, pred: $namePredLoss, target: $nameTargetLoss)');
        break;
      // --- Optimizers ---
      case OP_SGD_UPDATE:
        final String nameDataSgd = _readString();
        final String nameGradSgd = _readString();
        final double lrSgd = _readFloat();
        Logger.yellow('$opName: sgd($nameDataSgd, grad: $nameGradSgd, lr: $lrSgd)');
        break;
      case OP_ADAM_UPDATE:
        final String nameDataAdam = _readString();
        final String nameGradAdam = _readString();
        final String nameM = _readString();
        final String nameV = _readString();
        final double lrAdam = _readFloat();
        final double b1 = _readFloat();
        final double b2 = _readFloat();
        final double epsAdam = _readFloat();
        final int step = _readInt();
        final double wd = _readFloat();
        Logger.yellow('$opName: adam($nameDataAdam, grad: $nameGradAdam, m: $nameM, v: $nameV, step: $step, lr: $lrAdam, b1: $b1, b2: $b2, eps: $epsAdam, wd: $wd)');
        break;
      case OP_CLIP_GRAD_VALUE:
        final String nameBufferClip = _readString();
        final double clipVal = _readFloat();
        Logger.yellow('$opName: clip($nameBufferClip, val: $clipVal)');
        break;
      // --- Reductions ---
      case OP_SUM_REDUCE:
      case OP_SUM_REDUCE_COLUMNS:
      case OP_SUM_REDUCE_ROWS:
        final String nameInReduce = _readString();
        final String nameOutReduce = _readString();
        Logger.cyan('$opName: $nameOutReduce = reduce($nameInReduce)');
        break;
      case OP_SUM_REDUCE_BACKWARD:
        final String gradOutReduce = _readString();
        final String gradInReduce = _readString();
        Logger.blue('$opName: $gradInReduce = reduce_bw($gradOutReduce)');
        break;
      case OP_EMBEDDING_FORWARD:
        final String nameIndicesEmb = _readString();
        final String nameWeightEmb = _readString();
        final String nameOutEmb = _readString();
        Logger.cyan('$opName: $nameOutEmb = embedding($nameIndicesEmb, $nameWeightEmb)');
        break;
      case OP_EMBEDDING_BACKWARD:
        final String gradOutEmb = _readString();
        final String nameIndicesEmbBw = _readString();
        final String gradWeightEmb = _readString();
        Logger.blue('$opName: $gradWeightEmb = embedding_bw($gradOutEmb, $nameIndicesEmbBw)');
        break;
      // --- Tensor Manipulation ---
      case OP_SLICE_ROW:
      case OP_SLICE_ROW_BACKWARD:
        final String nameInSlice = _readString();
        final String nameOutSlice = _readString();
        final int row = _readInt();
        Logger.cyan('$opName: $nameOutSlice = slice_row($nameInSlice, row: $row)');
        break;
      case OP_SLICE_COLUMN:
      case OP_SLICE_COLUMN_BACKWARD:
        final String nameInCol = _readString();
        final String nameOutCol = _readString();
        final int start = _readInt();
        final int end = _readInt();
        Logger.cyan('$opName: $nameOutCol = slice_col($nameInCol, start: $start, end: $end)');
        break;
      case OP_STACK_ROWS:
        final int countStack = _readInt();
        final List<String> namesInStack = [];
        for (int i = 0; i < countStack; i++) {
          namesInStack.add(_readString());
        }
        final String nameOutStack = _readString();
        final int axisStack = _readInt();
        Logger.cyan('$opName: $nameOutStack = stack(${namesInStack.length} tensors, axis: $axisStack)');
        break;
      case OP_STACK_ROWS_BACKWARD:
        final String gradOutStack = _readString();
        final int countBw = _readInt();
        final List<String> namesGradIn = [];
        for (int i = 0; i < countBw; i++) {
          namesGradIn.add(_readString());
        }
        final int axisBw = _readInt();
        Logger.blue('$opName: bw_stack($gradOutStack -> ${namesGradIn.length} tensors, axis: $axisBw)');
        break;
      case OP_CONCATENATE:
        final String nameaCat = _readString();
        final String namebCat = _readString();
        final String nameOutCat = _readString();
        final int axisCat = _readInt();
        Logger.cyan('$opName: $nameOutCat = concat($nameaCat, $namebCat, axis: $axisCat)');
        break;
      case OP_CONCATENATE_BACKWARD:
        final String gradOutCat = _readString();
        final String gradinaCat = _readString();
        final String gradinbCat = _readString();
        final int axisCatBw = _readInt();
        final int split = _readInt();
        Logger.blue('$opName: bw_concat($gradOutCat -> $gradinaCat, $gradinbCat, axis: $axisCatBw, split: $split)');
        break;
      case OP_PAD2D:
      case OP_PAD2D_BACKWARD:
        final String nameInPad = _readString();
        final String nameOutPad = _readString();
        final int padT = _readInt();
        final int padB = _readInt();
        final int padL = _readInt();
        final int padR = _readInt();
        Logger.cyan('$opName: $nameOutPad = pad($nameInPad, t:$padT, b:$padB, l:$padL, r:$padR)');
        break;
      // --- Advanced Layers ---
      case OP_MATMUL_BIAS_RELU_FORWARD:
        final String nameX = _readString();
        final String nameW = _readString();
        final String namebMatmul = _readString();
        final String nameReluOut = _readString();
        final String namePreROut = _readString();
        Logger.cyan('$opName: $nameReluOut = matmul_bias_relu($nameX, $nameW, $namebMatmul) preRelu:$namePreROut');
        break;
      case OP_LAYER_NORM_FORWARD:
        final String nameInNorm = _readString();
        final String nameGamma = _readString();
        final String nameBetaNorm = _readString();
        final String nameOutNorm = _readString();
        final String nameMean = _readString();
        final String nameRstd = _readString();
        final double eps = _readFloat();
        Logger.cyan('$opName: $nameOutNorm = layer_norm($nameInNorm, gamma: $nameGamma, beta: $nameBetaNorm) mean:$nameMean rstd:$nameRstd eps:$eps');
        break;
      case OP_LAYER_NORM_BACKWARD:
        final String gradOutNorm = _readString();
        final String nameInNormBw = _readString();
        final String nameGammaBw = _readString();
        final String nameMeanBw = _readString();
        final String nameRstdBw = _readString();
        final String gradInNorm = _readString();
        final String gradGamma = _readString();
        final String gradBeta = _readString();
        Logger.blue('$opName: $gradInNorm, $gradGamma, $gradBeta = bw($gradOutNorm, in:$nameInNormBw, gamma:$nameGammaBw, mean:$nameMeanBw, rstd:$nameRstdBw)');
        break;
      case OP_CONV2D_FORWARD:
        final String nameInConv = _readString();
        final String nameKernel = _readString();
        final String nameOutConv = _readString();
        Logger.cyan('$opName: $nameOutConv = conv2d($nameInConv, $nameKernel)');
        break;
      case OP_CONV2D_BACKWARD_INPUT:
        final String gradOutConv = _readString();
        final String nameKernelBw = _readString();
        final String gradInConv = _readString();
        Logger.blue('$opName: $gradInConv = conv2d_bw_in($gradOutConv, $nameKernelBw)');
        break;
      case OP_CONV2D_BACKWARD_KERNEL:
        final String nameInConvBw = _readString();
        final String gradOutConvBw = _readString();
        final String gradKernel = _readString();
        Logger.blue('$opName: $gradKernel = conv2d_bw_k($nameInConvBw, $gradOutConvBw)');
        break;
      case OP_CONV2D_MULTI_FORWARD:
        final String nameInMulti = _readString();
        final String nameWeight = _readString();
        final String nameBiasMulti = _readString();
        final String nameOutMulti = _readString();
        final int inC = _readInt();
        final int outC = _readInt();
        final int kh = _readInt();
        final int kw = _readInt();
        final int pt = _readInt();
        final int pl = _readInt();
        final int sh = _readInt();
        final int sw = _readInt();
        Logger.cyan('$opName: $nameOutMulti = conv2d_multi($nameInMulti, $nameWeight, bias: $nameBiasMulti, inC:$inC, outC:$outC) kernel:${kh}x$kw pad:${pt}x$pl stride:${sh}x$sw');
        break;
      case OP_CONV2D_MULTI_BACKWARD_INPUT:
        final String gradOutMulti = _readString();
        final String nameWeightMulti = _readString();
        final String nameGradInMulti = _readString();
        final int inCBw = _readInt();
        final int outCBw = _readInt();
        final int khBw = _readInt();
        final int kwBw = _readInt();
        final int ptBw = _readInt();
        final int plBw = _readInt();
        final int shBw = _readInt();
        final int swBw = _readInt();
        Logger.blue('$opName: $nameGradInMulti = conv2d_multi_bw_in($gradOutMulti, $nameWeightMulti) inC:$inCBw, outC:$outCBw kernel:${khBw}x$kwBw pad:${ptBw}x$plBw stride:${shBw}x$swBw');
        break;
      case OP_CONV2D_MULTI_BACKWARD_WEIGHT:
        final String nameInMultiBw = _readString();
        final String gradOutMultiBw = _readString();
        final String nameGradWeight = _readString();
        final String nameGradBias = _readString();
        final int inCW = _readInt();
        final int outCW = _readInt();
        final int khW = _readInt();
        final int kwW = _readInt();
        final int ptW = _readInt();
        final int plW = _readInt();
        final int shW = _readInt();
        final int swW = _readInt();
        Logger.blue('$opName: $nameGradWeight, $nameGradBias = conv2d_multi_bw_w($nameInMultiBw, $gradOutMultiBw) inC:$inCW, outC:$outCW kernel:${khW}x$kwW pad:${ptW}x$plW stride:${shW}x$swW');
        break;
      case OP_IM2COL:
        final String nameInIm = _readString();
        final String nameOutIm = _readString();
        final int khIm = _readInt();
        final int kwIm = _readInt();
        Logger.cyan('$opName: $nameOutIm = im2col($nameInIm) kernel: ${khIm}x$kwIm');
        break;
      case OP_COL2IM:
        final String nameColGrad = _readString();
        final String nameInGrad = _readString();
        final int khCol = _readInt();
        final int kwCol = _readInt();
        Logger.blue('$opName: $nameInGrad = col2im($nameColGrad) kernel: ${khCol}x$kwCol');
        break;
      case OP_MAX_POOL_1D_FORWARD:
      case OP_MAX_POOL_2D_FORWARD:
        final String nameInPool = _readString();
        final String nameOutPool = _readString();
        final String nameIndicesPool = _readString();
        final int poolSize = _readInt();
        final int stridePool = _readInt();
        Logger.cyan('$opName: $nameOutPool = max_pool($nameInPool, size: $poolSize, stride: $stridePool) indices: $nameIndicesPool');
        break;
      case OP_MAX_POOL_1D_BACKWARD:
      case OP_MAX_POOL_2D_BACKWARD:
        final String gradOutPool = _readString();
        final String nameIndicesPoolBw = _readString();
        final String gradInPool = _readString();
        Logger.blue('$opName: $gradInPool = max_pool_bw($gradOutPool, indices: $nameIndicesPoolBw)');
        break;
      case OP_AVG_POOL_2D_FORWARD: // Avg pool does NOT read indices in C++!
        final String nameInAvg = _readString();
        final String nameOutAvg = _readString();
        final int avgPoolSize = _readInt();
        final int avgStridePool = _readInt();
        Logger.cyan('$opName: $nameOutAvg = avg_pool($nameInAvg, size: $avgPoolSize, stride: $avgStridePool)');
        break;
      case OP_AVG_POOL_2D_BACKWARD: // Avg pool does NOT read indices in C++!
        final String gradOutAvg = _readString();
        final String gradInAvg = _readString();
        final int avgPoolSizeBw = _readInt();
        final int avgStridePoolBw = _readInt();
        Logger.blue('$opName: $gradInAvg = avg_pool_bw($gradOutAvg, size: $avgPoolSizeBw, stride: $avgStridePoolBw)');
        break;
      case OP_GLOBAL_AVG_POOL_FORWARD:
        final String nameInGap = _readString();
        final String nameOutGap = _readString();
        Logger.cyan('$opName: $nameOutGap = global_avg_pool($nameInGap)');
        break;
      case OP_GLOBAL_AVG_POOL_BACKWARD:
        final String gradOutGap = _readString();
        final String gradInGap = _readString();
        Logger.blue('$opName: $gradInGap = global_avg_pool_bw($gradOutGap)');
        break;
      case OP_BATCH_NORM_1D_FORWARD:
      case OP_BATCH_NORM_2D_FORWARD:
        final String nameInBn = _readString();
        final String nameGammaBn = _readString();
        final String nameBetaBn = _readString();
        final String nameRm = _readString();
        final String nameRv = _readString();
        final String nameOutBn = _readString();
        final String nameSm = _readString();
        final String nameSiv = _readString();
        final double momentum = _readFloat();
        final double epsilonBn = _readFloat();
        final bool isTraining = _readBool();
        Logger.cyan('$opName: $nameOutBn = batch_norm($nameInBn, gamma: $nameGammaBn, beta: $nameBetaBn, rm: $nameRm, rv: $nameRv) sm:$nameSm siv:$nameSiv momentum:$momentum eps:$epsilonBn train:$isTraining');
        break;
      case OP_BATCH_NORM_1D_BACKWARD:
      case OP_BATCH_NORM_2D_BACKWARD:
        final String gradOutBn = _readString();
        final String nameInBnBw = _readString();
        final String nameGammaBnBw = _readString();
        final String nameSmBw = _readString();
        final String nameSivBw = _readString();
        final String gradInBn = _readString();
        final String gradGammaBn = _readString();
        final String gradBetaBn = _readString();
        Logger.blue('$opName: $gradInBn, $gradGammaBn, $gradBetaBn = bw($gradOutBn, in:$nameInBnBw, gamma:$nameGammaBnBw, sm:$nameSmBw, siv:$nameSivBw)');
        break;
      case OP_DROPOUT_FORWARD:
        final String nameInDrop = _readString();
        final String nameOutDrop = _readString();
        final String nameMaskDrop = _readString();
        final double dropRate = _readFloat();
        final int seed = _readInt();
        Logger.cyan('$opName: $nameOutDrop = dropout($nameInDrop, rate: $dropRate, seed: $seed) mask: $nameMaskDrop');
        break;
      case OP_DROPOUT_BACKWARD:
        final String gradOutDrop = _readString();
        final String nameMaskDropBw = _readString();
        final String gradInDrop = _readString();
        Logger.blue('$opName: $gradInDrop = dropout_bw($gradOutDrop, mask: $nameMaskDropBw)');
        break;
      default:
        Logger.red('Unknown OpCode ($opCode) encountered at byte offset $opStart!', prefix: '⚠️');
        // The operand layout of an unknown opcode is unknown, so the rest of
        // the tape cannot be decoded reliably; stop here.
        return;
    }
  }
  Logger.yellow('--- End of Tape ---', prefix: '📜');
}