compileNetwork function

String compileNetwork(
  TfannNetwork network, {
  String functionName = 'tfann_evaluate',
})

Returns pure Dart code that represents the function of this network.

Implementation

/// Compiles [network] into a standalone Dart source string.
///
/// The generated code has no dependency on this package: it contains the
/// required imports (`dart:typed_data`, `dart:math`), the SIMD constants and
/// only those activation helpers actually referenced by the network's layers,
/// every layer's weights and biases embedded as literals, and a top-level
/// function named [functionName] that maps a `List<double>` input of the
/// network's input size to the network's output as a `List<double>`.
///
/// Weights and biases are serialized through their raw `Uint32List` bit
/// patterns so the `Float32` values round-trip exactly, with no decimal
/// rounding loss.
///
/// NOTE(review): assumes each layer's weight rows and bias storage are
/// padded with zeros to a multiple of 4 floats (Float32x4 granularity) —
/// confirm against the matrix/vector types used by [TfannNetwork].
String compileNetwork(TfannNetwork network,
    {String functionName = 'tfann_evaluate'}) {
  StringBuffer stringBuffer = StringBuffer();
  // Network input width, taken from the first layer's weight matrix.
  int inputSize = network.layers[0].weights.nColumns;
  // Activation types used anywhere in the network; only the helpers for
  // these types are emitted below.
  var activationsSet = network.layers.map((e) => e.activationFunc.type).toSet();

  // --- Emitted file header: imports and SIMD constants shared by helpers. ---
  stringBuffer.write("import 'dart:typed_data';\n");
  stringBuffer.write("import 'dart:math';\n\n");
  stringBuffer.write('''final Float32x4 _SIMD0 = Float32x4.zero();
final Float32x4 _SIMD0_75 = Float32x4.splat(0.75);
final Float32x4 _SIMD0_5 = Float32x4.splat(0.5);
final Float32x4 _SIMD0_25 = Float32x4.splat(0.25);
final Float32x4 _SIMD0_125 = Float32x4.splat(0.125);
final Float32x4 _SIMD0_375 = Float32x4.splat(0.375);
final Float32x4 _SIMD0_625 = Float32x4.splat(0.625);
final Float32x4 _SIMD0_0625 = Float32x4.splat(0.0625);
final Float32x4 _SIMD0_03 = Float32x4.splat(0.03);
final Float32x4 _SIMD0_033 = Float32x4.splat(0.033);
final Float32x4 _SIMD0_65625 = Float32x4.splat(0.65625);
final Float32x4 _SIMD0_065 = Float32x4.splat(0.065);
final Float32x4 _SIMD0_185 = Float32x4.splat(0.185);
final Float32x4 _SIMD0_104 = Float32x4.splat(0.104);
final Float32x4 _SIMD0_208 = Float32x4.splat(0.208);
final Float32x4 _SIMD0_704 = Float32x4.splat(0.704);
final Float32x4 _SIMD0_7424 = Float32x4.splat(0.7424);
final Float32x4 _SIMDm0_8 = Float32x4.splat(-0.8);
final Float32x4 _SIMDm1_5 = Float32x4.splat(-1.5);
final Float32x4 _SIMD0_28125 = Float32x4.splat(0.28125);
final Float32x4 _SIMD1 = Float32x4.splat(1);
final Float32x4 _SIMD1_47 = Float32x4.splat(1.47);
final Float32x4 _SIMD1_6 = Float32x4.splat(1.6);
final Float32x4 _SIMD4 = Float32x4.splat(4);
final Float32x4 _SIMD8 = Float32x4.splat(8);
final Float32x4 _SIMDm2 = Float32x4.splat(-2);
final Float32x4 _SIMDm3_3 = Float32x4.splat(-3.3);
final Float32x4 _SIMD0_875 = Float32x4.splat(0.875);
final Float32x4 _SIMD0_4 = Float32x4.splat(0.4);
final Float32x4 _SIMDm0_16 = Float32x4.splat(-0.16);
final Float32x4List _SimdSignMaskVector = Float32x4List.fromList(List.generate(
    16,
    (index) => Float32x4(
        (index & 1) != 0 ? -1.0 : 1.0,
        (index & 2) != 0 ? -1.0 : 1.0,
        (index & 4) != 0 ? -1.0 : 1.0,
        (index & 8) != 0 ? -1.0 : 1.0)));
\n''');
  // --- Activation helpers: emit the scalar variant for every used type, and
  // the SIMD (Float32x4) variant where one exists. ---
  if (activationsSet.contains(ActivationFunctionType.logistic)) {
    stringBuffer.write(
        "double logisticSigmoid(double x) { return 1 / (1 + exp(-x));}\n");
  }
  if (activationsSet.contains(ActivationFunctionType.abs)) {
    stringBuffer
        .write("double absSigmoid(double x) { return x / (1 + x.abs());}\n");
    stringBuffer.write(
        "Float32x4 absSigmoidX4(Float32x4 x) =>   x / (_SIMD1 + x.abs());\n");
  }
  if (activationsSet.contains(ActivationFunctionType.tanh)) {
    stringBuffer.write(
        "double tanh(double x) {  var e2x = exp(2 * x);    return (e2x - 1) / (e2x + 1); }\n");
  }
  if (activationsSet.contains(ActivationFunctionType.bell)) {
    stringBuffer.write("double bell(double x) {      return exp(-2*x*x);}\n");
  }
  if (activationsSet.contains(ActivationFunctionType.uscls)) {
    stringBuffer.write('''double uscls(double x) {
  if (x > 1) return 0.375 + 0.125 * x;
  if (x > -1.5) return 0.5 * x;
  return 0.0625 * x - 0.65625;
}\n''');
    stringBuffer.write('''Float32x4 usclsX4(Float32x4 x) {
  Int32x4 greater1 = x.greaterThan(_SIMD1);
  Float32x4 branch1Result = x.scale(0.125) + _SIMD0_375;
  Int32x4 lessThanMinus1_5 = x.lessThanOrEqual(_SIMDm1_5);
  Float32x4 branch3Result = x.scale(0.0625) - _SIMD0_65625;

  return greater1.select(
      branch1Result, lessThanMinus1_5.select(branch3Result, x.scale(0.5)));
}\n\n''');
  }
  if (activationsSet.contains(ActivationFunctionType.uscsls)) {
    stringBuffer.write('''double uscsls(double x) {
  if (x >= 1.6) return 0.065 * x + 0.704;
  if (x > -0.8) {
    var x2 = x * x;
    var x3 = x2 * x;
    return 0.125 * (x2 - x3) + 0.625 * x;
  }
  return 0.185 * x - 0.208;
}


Float32x4 uscslsX4(Float32x4 x) {
  Int32x4 greater1_6 = x.greaterThan(_SIMD1_6);
  Float32x4 x2 = x * x;

  Float32x4 branch1Result = x.scale(0.065) + _SIMD0_704;
  Float32x4 x3 = x2 * x;

  Int32x4 lessThanMinus0_8 = x.lessThanOrEqual(_SIMDm0_8);
  Float32x4 branch3Result = x.scale(0.185) - _SIMD0_208;

  return greater1_6.select(
      branch1Result,
      lessThanMinus0_8.select(
          branch3Result,
           (x2 - x3).scale(0.125) +  x.scale(0.625)));
}\n\n''');
  }
  if (activationsSet.contains(ActivationFunctionType.uacsls)) {
    stringBuffer.write('''double uacsls(double x) {
  var qx = x * 0.5;
  if (x >= 0) return qx;
  if (x > -1.5) return 0.5*qx * qx + qx;
  return 0.125 * x - 0.28125;
}

Float32x4 uacslsX4(Float32x4 x) {
  Float32x4 qx = x.scale(0.5);
  Int32x4 greaterZero = x.greaterThanOrEqual(Float32x4.zero());
  Int32x4 greaterM3div2 = x.greaterThan(_SIMDm1_5);
  return greaterZero.select(
      qx, greaterM3div2.select(qx * qx.scale(0.5) + qx, x.scale(0.125) - _SIMD0_28125));
}\n\n''');
  }
  if (activationsSet.contains(ActivationFunctionType.fastBell)) {
    stringBuffer.write('''double fastBell(double x) {
  var x2 = x * x;
  if (x2 <= 0.25) return 1 - 2 * x2;
  return (1 - x2) / (8 * x2) + 1 / 8.0;
}\n\n''');
    stringBuffer.write('''
Float32x4 fastBellX4(Float32x4 x) {
  var x2 = x * x;
  return x2
      .greaterThan(_SIMD0_25)
      .select((_SIMD1 - x2) / (x2.scale(8)) + _SIMD0_125, _SIMD1 - x2.scale(2));
}
\n\n''');
  }
  if (activationsSet.contains(ActivationFunctionType.divlineSigmoid)) {
    stringBuffer.write('''double divlineSigmoid(double x)
{
  var absX = x.abs();
  if(absX<=0.75)
  {
    return x;
  }
  var ftxmo=absX*4-1;
  return x.sign *(1-1/(ftxmo*ftxmo));
}\n''');
    stringBuffer.write('''Float32x4 divlineSigmoidX4(Float32x4 x) {
  var absX = x.abs();
  var ftxmo = absX.scale(4) - _SIMD1;
  return absX
      .greaterThan(_SIMD0_75)
      .select(_SimdSignMaskVector[x.signMask] * (_SIMD1-(ftxmo*ftxmo).reciprocal()), x);
}\n\n''');
  }
  if (activationsSet.contains(ActivationFunctionType.cubicSigmoid)) {
    stringBuffer.write(
'''double cubicSigmoid(double x)
{
  if (x.abs()>=1)
    return 0.03*x+x.sign*0.96;
  return -0.48*x*x*x+1.47*x;
}

Float32x4 cubicSigmoidX4(Float32x4 x)
{
  return x.abs().greaterThanOrEqual(_SIMD1).select(x.scale(0.03)+_SimdSignMaskVector[x.signMask].scale(0.96), x.scale(1.47)-x*x*x.scale(0.48));
}\n\n''');
  }
  if(activationsSet.contains(ActivationFunctionType.line))
  {
    stringBuffer.write('double simpleLine(double x) => x;\n');
    stringBuffer.write('Float32x4 simpleLineX4(Float32x4 x) => x;\n\n');
  }
  if(activationsSet.contains(ActivationFunctionType.squartered))
  {
    stringBuffer.write('double squartered(double x) =>x*x/4;\n');
    stringBuffer.write('Float32x4 squarteredX4(Float32x4 x) =>x*x.scale(0.25);\n\n');
  }
  if(activationsSet.contains(ActivationFunctionType.funnyHat))
  {
    stringBuffer.write('''double funnyHat(double x) {
  double x2=x*x;
  if (x>=0)
  {
    if(x>=1.6)
    {
      return -0.16*x+0.104;
    }
    return 0.5*x*x2-1.25*x2+1;
  }
  if(x<=-3.3)
  {
    return 0.033*x-0.7424;
  }
  return 1-0.1*x*x2-0.5*x2;
}
Float32x4 funnyHatX4(Float32x4 x) {
  var x2 = x*x;
  var g0 = x.greaterThan(_SIMD0);
  var x3 = x2*x;
  var g1_6 =x.greaterThanOrEqual(_SIMD1_6);
  var gm3_3 =x.greaterThan(_SIMDm3_3);
  return g0.select(g1_6.select(x.scale(-0.16)+_SIMD0_104, x3.scale(0.5)-x2.scale(1.25)+_SIMD1), gm3_3.select(_SIMD1-x3.scale(0.1)-x2.scale(0.5), x.scale(0.033)-_SIMD0_7424));
}\n\n''');
  }

  // --- Embedded per-layer parameters. ---
  // (Removed a dead local here: `weightsWidth` was computed from
  // roundUp4(nColumns) ~/ 2 but never read.)
  network.layers.asMap().forEach((i, layer) {
    stringBuffer
        .write("final List<Float32x4List> Lweight_${functionName}_$i = [");

    // One Float32x4List per weight-matrix row, written as the raw 32-bit
    // patterns of the floats so values round-trip exactly.
    stringBuffer.write(layer.weights.rowsData
        .map((row) =>
            "Uint32List.fromList(${row.buffer.asUint32List().toList()}).buffer.asFloat32x4List()")
        .join(", "));
    stringBuffer.write("];\n");

    stringBuffer.write(
        "final Float32x4List Lbias_${functionName}_$i = Uint32List.fromList(${layer.bias.columnData.buffer.asUint32List().toList()}).buffer.asFloat32x4List();\n");
  });
  // --- The evaluation function itself. ---
  stringBuffer
      .write("\n\nList<double> ${functionName}(List<double> inData) \n{\n");
  stringBuffer.write("  assert(inData.length == $inputSize);\n");
  // Copy the input into a zero-initialized Float32List padded up to a
  // multiple of 4, then view it as Float32x4 lanes.
  stringBuffer
      .write("  Float32List input = Float32List(${roundUp4(inputSize)});\n");
  stringBuffer
      .write("  for (int i = 0; i< $inputSize; ++i) input[i] = inData[i];\n");

  stringBuffer.write(
      "  Float32x4List currentTensor = input.buffer.asFloat32x4List();\n");
  stringBuffer.write("  Float32List outputTensor;\n");
  // Per layer: matrix-vector product, bias addition, then activation.
  network.layers.asMap().forEach((i, layer) {
    stringBuffer.write(
        "  outputTensor = Float32List(${roundUp4(layer.weights.nRows)});\n");
    stringBuffer
        .write("  for (int r = 0; r < ${layer.weights.nRows}; ++r)\n  {\n");
    stringBuffer.write(
        "    Float32x4List weightRow = Lweight_${functionName}_$i[r];\n");
    // Number of Float32x4 chunks per weight row; unroll the trivial
    // single-chunk case.
    int columns4 = (layer.weights.nColumns + 3) ~/ 4;
    if (columns4 == 1) {
      stringBuffer
          .write("    Float32x4 sum = currentTensor[0]*weightRow[0];\n");
    } else {
      stringBuffer.write(
          "    Float32x4 sum = Float32x4.zero();\n    for (int i = 0; i < $columns4; ++i)\n    {     sum+=currentTensor[i]*weightRow[i];   }\n");
    }
    // Horizontal add of only the lanes that can carry data for narrow rows;
    // for rows of 4+ columns all four lanes are summed (padding lanes hold
    // zero products).
    stringBuffer.write(
        "    outputTensor[r] = sum.x ${layer.weights.nColumns >= 2 ? '+ sum.y' : ''} ${layer.weights.nColumns >= 3 ? '+ sum.z' : ''} ${layer.weights.nColumns >= 4 ? '+ sum.w' : ''};\n  }\n");
    stringBuffer
        .write("  currentTensor = outputTensor.buffer.asFloat32x4List();\n");
    // Add the bias SIMD-chunk-wise; a single chunk needs no loop.
    int biasDiv4 = (layer.bias.length + 3) ~/ 4;
    if (biasDiv4 > 1) {
      stringBuffer.write("  for (int i = 0; i < ${biasDiv4}; ++i)\n");
    }
    stringBuffer.write(
        "    currentTensor[${biasDiv4 == 1 ? "0" : "i"}]+=Lbias_${functionName}_$i[${biasDiv4 == 1 ? "0" : "i"}];\n");
    // Helper names indexed by ActivationFunctionType.index; an empty SIMD
    // entry means that type has no SIMD variant, so the scalar loop below is
    // emitted instead. NOTE(review): order must match the enum declaration —
    // keep in sync with ActivationFunctionType.
    var currentX4Func = [
      '',
      '',
      'absSigmoidX4',
      '',
      'usclsX4',
      'uscslsX4',
      'uacslsX4',
      'fastBellX4',
      'divlineSigmoidX4',
          'simpleLineX4',
          'funnyHatX4',
          'cubicSigmoidX4',
          'squarteredX4'
    ][layer.activationFunc.type.index];
    var currentFunc = [
      'logisticSigmoid',
      'tanh',
      'absSigmoid',
      'bell',
      'uscls',
      'uscsls',
      'uacsls',
      'fastBell',
      'divlineSigmoid',
          'simpleLine',
          'funnyHat',
          'cubicSigmoid',
          'squartered'
    ][layer.activationFunc.type.index];
    if (currentX4Func.isNotEmpty) {
      // SIMD path: apply the X4 helper to every full group of 4, then the
      // scalar helper to the 1-3 leftover elements (outputTensor aliases
      // currentTensor's buffer, so both views see the bias-added values).
      var actFull4 = layer.bias.length ~/ 4;
      var actRemain = layer.bias.length % 4;
      if (actFull4 > 0) {
        if (actFull4 > 1) {
          stringBuffer.write("  for (int i = 0; i < ${actFull4}; ++i)\n");
          stringBuffer.write(
              "    currentTensor[i]=$currentX4Func(currentTensor[i]);\n");
        } else {
          stringBuffer
              .write("  currentTensor[0]=$currentX4Func(currentTensor[0]);\n");
        }
      }
      for (int i = 0; i < actRemain; ++i) {
        stringBuffer.write(
            "  outputTensor[${actFull4 * 4 + i}]=$currentFunc(outputTensor[${actFull4 * 4 + i}]);\n");
      }
    } else {
      // Scalar-only path for activation types without a SIMD variant.
      stringBuffer.write("  for (int i = 0; i < ${layer.bias.length}; ++i)\n");
      stringBuffer
          .write("    outputTensor[i]=$currentFunc(outputTensor[i]);\n");
    }
  });
  // Return only the real output elements, trimming the padding lanes.
  stringBuffer.write(
      "  return currentTensor.buffer.asFloat32List(0,${network.layers.last.bias.length}).toList();\n}\n\n");

  return stringBuffer.toString();
}