The `softmaxMatrix` function: row-wise softmax with a backward pass
Implementation:
/// Applies softmax independently to each row of [m] and records the
/// operation so gradients can flow back into `m.grad`.
///
/// For each row `x`, the output row is `exp(x_c - max(x)) / sum_j exp(x_j - max(x))`.
/// Subtracting the row maximum leaves the result mathematically unchanged.
/// It keeps `exp` from overflowing for large inputs.
///
/// The backward closure accumulates into `m.grad` using the softmax
/// vector-Jacobian product: `dx_c += y_c * (dy_c - sum_j dy_j * y_j)`.
///
/// NOTE(review): the backward pass indexes `out.data`, `out.grad` and
/// `m.grad` as flat buffers with index `r * numCols + c`. This presumably
/// means Tensor stores elements in row-major order; confirm against Tensor.
///
/// Throws [ArgumentError] if the rows of `m.value` are not all the same
/// length. Ragged rows would otherwise be silently truncated or indexed
/// out of range, and would break the flat indexing used in backward.
Tensor<Matrix> softmaxMatrix(Tensor<Matrix> m) {
  final Matrix inputMatrix = m.value;
  final int numRows = inputMatrix.length;
  final int numCols = numRows > 0 ? inputMatrix[0].length : 0;

  // Reject ragged input up front; every later loop assumes numCols columns.
  for (int r = 1; r < numRows; r = r + 1) {
    if (inputMatrix[r].length != numCols) {
      throw ArgumentError(
          'softmaxMatrix: row $r has length ${inputMatrix[r].length}, '
          'expected $numCols (all rows must have equal length)');
    }
  }

  Matrix outValue = [];
  for (int r = 0; r < numRows; r = r + 1) {
    Vector row = inputMatrix[r];

    // Row maximum, subtracted before exp for numerical stability.
    double maxVal = -double.infinity;
    for (int c = 0; c < numCols; c = c + 1) {
      if (row[c] > maxVal) {
        maxVal = row[c];
      }
    }

    double sumExps = 0.0;
    Vector exps = [];
    for (int c = 0; c < numCols; c = c + 1) {
      double expVal = exp(row[c] - maxVal);
      exps.add(expVal);
      sumExps = sumExps + expVal;
    }

    Vector softmaxRow = [];
    for (int c = 0; c < numCols; c = c + 1) {
      softmaxRow.add(exps[c] / sumExps);
    }
    outValue.add(softmaxRow);
  }

  Tensor<Matrix> out = Tensor<Matrix>(outValue);
  out.creator = Node(
    [m],
    () {
      for (int r = 0; r < numRows; r = r + 1) {
        // sum_j dy_j * y_j for this row; shared by every column's gradient.
        double dotProduct = 0.0;
        for (int c = 0; c < numCols; c = c + 1) {
          int flatIndex = r * numCols + c;
          double yC = out.data[flatIndex];
          double dyC = out.grad[flatIndex];
          dotProduct = dotProduct + (dyC * yC);
        }
        // Accumulate (not assign) so other consumers of m keep their grads.
        for (int c = 0; c < numCols; c = c + 1) {
          int flatIndex = r * numCols + c;
          double yC = out.data[flatIndex];
          double dyC = out.grad[flatIndex];
          m.grad[flatIndex] = m.grad[flatIndex] + (yC * (dyC - dotProduct));
        }
      }
    },
    opName: 'softmax_matrix',
    cost: numRows * numCols * 2
  );
  return out;
}