fit method

void fit(
  List<List<double>> realX, {
  int epochs = 1,
  int batchSize = 16,
})

Minimal adversarial-training stub: each epoch performs one paired update of the discriminator and the generator by building supervised targets and fitting the underlying ANN heads directly.

Implementation

/// Runs a minimal adversarial-training loop over [realX].
///
/// Each of the [epochs] iterations performs one discriminator update and one
/// generator update. Both ANN heads are fitted with plain supervised targets:
///
/// * The discriminator sees a batch of generated samples labelled `0.0` and a
///   batch of real samples labelled `1.0`. Real batches are taken cyclically
///   from [realX], so successive epochs walk through the whole dataset instead
///   of reusing its first rows every time.
/// * The generator is fitted to map freshly sampled latent vectors onto real
///   samples. Its targets therefore have the same dimensionality as its
///   outputs, which are fed to the discriminator next to rows of [realX]. This
///   is a gradient-free proxy for "produce samples the discriminator
///   classifies as real". A true GAN generator update would backpropagate the
///   discriminator's gradient through the generator, which this stub does not
///   do.
///
/// Does nothing when [realX] is empty. Throws an [ArgumentError] if
/// [batchSize] is not positive.
void fit(List<List<double>> realX, {int epochs = 1, int batchSize = 16}) {
  if (batchSize < 1) {
    throw ArgumentError.value(batchSize, 'batchSize', 'must be positive');
  }
  final dataSize = realX.length;
  // Nothing to learn from; this also guards the modulo-by-dataSize below.
  if (dataSize == 0) return;

  // The discriminator batch never exceeds the dataset size.
  final n = min(batchSize, dataSize);

  for (var e = 0; e < epochs; e++) {
    // Start of this epoch's real batch; wraps around the end of the dataset.
    final offset = (e * n) % dataSize;
    List<double> realAt(int i) => realX[(offset + i) % dataSize];

    // Discriminator step: fake -> 0.0, real -> 1.0.
    final fakes = generate(n);
    final xs = <List<double>>[
      ...fakes,
      for (var i = 0; i < n; i++) realAt(i),
    ];
    final ys = <List<double>>[
      for (var i = 0; i < fakes.length; i++) [0.0],
      for (var i = 0; i < n; i++) [1.0],
    ];
    discriminator.fit(xs, ys);

    // Generator step (proxy): map each latent to a real sample. The real
    // samples are exactly what the discriminator was just taught to label
    // 1.0, and unlike a scalar label they match the generator's output shape.
    final latents = List.generate(batchSize, (_) => _sampleLatent());
    final targets = <List<double>>[
      for (var i = 0; i < latents.length; i++) realAt(i),
    ];
    generator.fit(latents, targets);
  }
}