dot<T extends SizedNativeType> function
Implementation
/// Computes the dot product of [a] and [b].
///
/// The case analysis follows NumPy's `numpy.dot`:
///
/// * If either operand is 0-D, the result is `F.multiply(a, b)`.
/// * If both are 1-D, the result is the inner product: the element-wise
///   product reduced with `F.reduceSum` over all axes.
/// * If [a] is N-D (N > 1) and [b] is 1-D, the element-wise product is summed
///   over the last axis.
/// * If both are 2-D, the result is `F.matMul(a, b)`.
/// * Otherwise ([a] is at least 1-D and [b] at least 2-D), the result is a
///   sum product over the last axis of [a] and the second-to-last axis of
///   [b]. Its shape is `a.shape` without its last axis, followed by
///   `b.shape` without its second-to-last axis. This covers vector·matrix,
///   matrix·N-D and N-D·M-D inputs.
///
/// The lengths of the contracted axes are checked with `MnnAssert`
/// ('shapes not aligned'). An [ArgumentError] is thrown for any dimension
/// combination not covered above. The type parameter [T] is not used by the
/// computation.
VARP dot<T extends ffi.SizedNativeType>(VARP a, VARP b) {
  final ad = a.ndim!;
  final bd = b.ndim!;
  if (ad == 0 || bd == 0) {
    return F.multiply(a, b);
  }
  if (ad == 1 && bd == 1) {
    MnnAssert(a.shape![0] == b.shape![0], 'shapes not aligned');
    return F.reduceSum(F.multiply(a, b));
  }
  if (ad > 1 && bd == 1) {
    MnnAssert(a.shape!.last == b.shape![0], 'shapes not aligned');
    // b is broadcast along a's trailing axis; that axis is then summed away.
    return F.reduceSum(F.multiply(a, b), axis: [-1]);
  }
  if (ad == 2 && bd == 2) {
    MnnAssert(a.shape![1] == b.shape![0], 'shapes not aligned');
    return F.matMul(a, b);
  }
  // General case, lowered to a single 2-D matMul. The lowering is valid for
  // any ad >= 1 and bd >= 2: a 1-D `a` simply becomes one row of `newA`.
  if (ad >= 1 && bd >= 2) {
    final reduceDim = a.shape!.last;
    MnnAssert(reduceDim == b.shape![bd - 2], 'shapes not aligned');
    // Output shape: a's shape minus its last (contracted) axis, then b's
    // shape minus its second-to-last (contracted) axis.
    final aShape = List.of(a.shape!)..removeLast();
    final bShape = List.of(b.shape!)..removeAt(bd - 2);
    final dstShape = [...aShape, ...bShape];
    // Collapse a to [rows, reduceDim].
    final newA = reshape(a, [-1, reduceDim]);
    // Bring b's contracted axis to the front, then collapse the remaining
    // axes so b becomes [reduceDim, cols].
    var newB = moveaxis(b, [bd - 2], [0]);
    newB = reshape(newB, [reduceDim, -1]);
    final res = F.matMul(newA, newB);
    return reshape(res, dstShape);
  }
  throw ArgumentError('dot not implemented for dimensions $ad and $bd');
}