fit method
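A minimal adversarial training stub. Each epoch performs one discriminator update, with real samples labeled 1.0 and generated samples labeled 0.0. It then performs one proxy generator update, which fits the generator's ANN toward the "real" label 1.0 instead of using discriminator gradients.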
Implementation
/// Runs a minimal adversarial training loop on [realX].
///
/// Each of the [epochs] iterations performs one discriminator update followed
/// by one generator update:
///
/// 1. The discriminator is fit on a balanced batch. It contains up to
///    [batchSize] generated samples labeled `0.0` (fake) and the same number
///    of real samples from [realX] labeled `1.0` (real). Successive epochs
///    walk through [realX] cyclically, so every real sample is eventually
///    used rather than only the first batch.
/// 2. The generator is fit on [batchSize] fresh latent vectors, each paired
///    with the target `[1.0]`. This is a supervised proxy, not a true GAN
///    update: no discriminator gradients flow into the generator.
///
/// Throws [ArgumentError] if [realX] is empty or [batchSize] is less than 1.
void fit(List<List<double>> realX, {int epochs = 1, int batchSize = 16}) {
  if (realX.isEmpty) {
    throw ArgumentError.value(realX, 'realX', 'must not be empty');
  }
  if (batchSize < 1) {
    throw ArgumentError.value(batchSize, 'batchSize', 'must be at least 1');
  }
  final dataSize = realX.length;
  // Per-class batch size: equal numbers of fake and real samples keep the
  // discriminator's classes balanced even when realX is smaller than
  // batchSize.
  final half = min(batchSize, dataSize);
  // Start index of the next real batch; advancing it each epoch makes the
  // discriminator see all of realX over time instead of only realX[0..half).
  var cursor = 0;
  for (var e = 0; e < epochs; e++) {
    // Discriminator step: generated -> 0.0 (fake), real -> 1.0 (real).
    final fakes = generate(half);
    final xs = <List<double>>[
      ...fakes,
      for (var i = 0; i < half; i++) realX[(cursor + i) % dataSize],
    ];
    final ys = <List<double>>[
      for (var i = 0; i < fakes.length; i++) [0.0],
      for (var i = 0; i < half; i++) [1.0],
    ];
    discriminator.fit(xs, ys);
    cursor = (cursor + half) % dataSize;

    // Generator step (proxy): a real GAN would back-propagate discriminator
    // gradients into the generator. Here the generator is fit directly
    // toward the "real" label instead.
    // NOTE(review): each target is a 1-element list; confirm this matches the
    // generator's output width, otherwise this fit is dimensionally wrong.
    final latents = List.generate(batchSize, (_) => _sampleLatent());
    final realTargets = List.generate(batchSize, (_) => [1.0]);
    generator.fit(latents, realTargets);
  }
}