softmaxMatrix function

Tensor<Matrix> softmaxMatrix(
  Tensor<Matrix> m
)

Implementation

/// Applies softmax independently to every row of the matrix held by [m].
///
/// For each row the largest entry is subtracted before exponentiating, and
/// the exponentials are then divided by their sum. The column count is read
/// from the first row only and reused for every row (assumes all rows have
/// at least that many entries — TODO confirm against callers).
///
/// The returned tensor gets a `creator` node whose callback accumulates into
/// `m.grad`, for each row, `y_c * (dy_c - sum_k(dy_k * y_k))`, where `y` is
/// read from `out.data` and `dy` from `out.grad`, both indexed at
/// `row * columnCount + column`.
Tensor<Matrix> softmaxMatrix(Tensor<Matrix> m) {
  final Matrix source = m.value;
  final int rowCount = source.length;
  final int colCount = rowCount > 0 ? source[0].length : 0;

  final Matrix result = [];

  for (int i = 0; i < rowCount; i++) {
    final Vector current = source[i];

    // Largest entry of the row; used as the shift applied before exp().
    double peak = -double.infinity;
    for (int j = 0; j < colCount; j++) {
      final double candidate = current[j];
      if (candidate > peak) {
        peak = candidate;
      }
    }

    // Shifted exponentials and their running total.
    final Vector shifted = [];
    double total = 0.0;
    for (int j = 0; j < colCount; j++) {
      final double e = exp(current[j] - peak);
      shifted.add(e);
      total += e;
    }

    // Normalise so that the row's entries are divided by their sum.
    final Vector normalized = [];
    for (final e in shifted) {
      normalized.add(e / total);
    }
    result.add(normalized);
  }

  final Tensor<Matrix> out = Tensor<Matrix>(result);

  out.creator = Node(
    [m],
    () {
      for (int i = 0; i < rowCount; i++) {
        final int base = i * colCount;

        // Sum over the row of (incoming gradient * softmax output).
        double weighted = 0.0;
        for (int j = 0; j < colCount; j++) {
          final int k = base + j;
          final double y = out.data[k];
          final double dy = out.grad[k];
          weighted += dy * y;
        }

        // Accumulate (not overwrite) the input gradient for this row.
        for (int j = 0; j < colCount; j++) {
          final int k = base + j;
          final double y = out.data[k];
          final double dy = out.grad[k];
          m.grad[k] += y * (dy - weighted);
        }
      }
    },
    opName: 'softmax_matrix',
    cost: rowCount * colCount * 2,
  );

  return out;
}