dot<T extends SizedNativeType> function

VARP dot<T extends SizedNativeType>(
  VARP a,
  VARP b
)

Implementation

/// Computes the dot product of [a] and [b], dispatching on their ranks.
///
/// The branch taken depends on `a.ndim` and `b.ndim`:
///
/// * either operand is 0-D: returns `F.multiply(a, b)`;
/// * both are 1-D: returns the sum of the element-wise product;
/// * [a] has rank > 1 and [b] is 1-D: sums the element-wise product over
///   the last axis;
/// * both are 2-D: returns `F.matMul(a, b)`;
/// * [b] has rank >= 2 (any other rank of [a], including 1-D): contracts the
///   last axis of [a] with the second-to-last axis of [b]. The result is
///   reshaped to `a.shape` without its last axis followed by `b.shape`
///   without its second-to-last axis.
///
/// The contracted axis lengths are checked with `MnnAssert` using the
/// message `'shapes not aligned'`.
///
/// The type parameter [T] is not used by the body; it is kept so existing
/// call sites keep compiling.
///
/// The null assertions on `ndim` and `shape` throw if either is `null`.
/// Throws [ArgumentError] if the ranks match none of the cases above.
VARP dot<T extends ffi.SizedNativeType>(VARP a, VARP b) {
  final ad = a.ndim!;
  final bd = b.ndim!;

  // Scalar operand: the dot product degenerates to a plain multiplication.
  if (ad == 0 || bd == 0) {
    return F.multiply(a, b);
  }

  // Vector . vector: inner product.
  if (ad == 1 && bd == 1) {
    MnnAssert(a.shape![0] == b.shape![0], 'shapes not aligned');
    return F.reduceSum(F.multiply(a, b));
  }

  // N-D . vector: multiply against the last axis of `a`, then sum it away.
  if (ad > 1 && bd == 1) {
    MnnAssert(a.shape!.last == b.shape![0], 'shapes not aligned');
    return F.reduceSum(F.multiply(a, b), axis: [-1]);
  }

  // Matrix . matrix: direct matrix multiplication, no reshaping needed.
  if (ad == 2 && bd == 2) {
    MnnAssert(a.shape![1] == b.shape![0], 'shapes not aligned');
    return F.matMul(a, b);
  }

  // General case. At this point ad >= 1, so only `bd > 1` has to be checked.
  // The guard used to be `ad > 2 && bd > 1`, which rejected a 1-D or 2-D `a`
  // paired with a higher-rank `b` even though the flatten/matMul/reshape
  // sequence below handles those ranks as well.
  if (bd > 1) {
    final reduceDim = a.shape!.last;
    MnnAssert(reduceDim == b.shape![bd - 2], 'shapes not aligned');

    // Output shape: a.shape minus its last axis, then b.shape minus its
    // second-to-last axis.
    final aShape = List.of(a.shape!);
    final bShape = List.of(b.shape!);
    aShape.removeLast();
    bShape.removeAt(bd - 2);
    final dstShape = [...aShape, ...bShape];

    // Flatten both operands to 2-D so a single matMul performs the
    // contraction: a -> [-1, K], b -> [K, -1] with the contracted axis
    // moved to the front first.
    final newA = reshape(a, [-1, reduceDim]);
    var newB = moveaxis(b, [bd - 2], [0]);
    newB = reshape(newB, [reduceDim, -1]);

    final res = F.matMul(newA, newB);
    return reshape(res, dstShape);
  }

  throw ArgumentError('dot not implemented for dimensions $ad and $bd');
}