/*
 * Decompiled with CFR 0.152.
 */
package org.apache.commons.math3.ml.neuralnet.sofm;

import java.util.Collection;
import java.util.HashSet;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.commons.math3.analysis.function.Gaussian;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.ml.distance.DistanceMeasure;
import org.apache.commons.math3.ml.neuralnet.MapUtils;
import org.apache.commons.math3.ml.neuralnet.Network;
import org.apache.commons.math3.ml.neuralnet.Neuron;
import org.apache.commons.math3.ml.neuralnet.UpdateAction;
import org.apache.commons.math3.ml.neuralnet.sofm.LearningFactorFunction;
import org.apache.commons.math3.ml.neuralnet.sofm.NeighbourhoodSizeFunction;

/**
 * Update action implementing the training step of a Kohonen self-organizing map.
 * <p>
 * Each call to {@link #update(Network,double[])} performs one training step:
 * <ol>
 *  <li>the call counter is incremented; its value <em>before</em> the increment
 *   is passed to the learning-factor and neighbourhood-size functions;</li>
 *  <li>the neuron closest to the sample (according to the configured distance)
 *   is moved towards the sample: {@code w <- w + rate * (sample - w)};</li>
 *  <li>neurons in successive rings around the winner (ring 1 = its immediate
 *   neighbours, ring 2 = their not-yet-visited neighbours, ...) up to the current
 *   neighbourhood size are moved towards the sample with a rate given by a
 *   Gaussian of the ring index (norm = current learning rate, mean = 0,
 *   standard deviation = current neighbourhood size).</li>
 * </ol>
 * Neuron features are written through
 * {@link Neuron#compareAndSetFeatures(double[],double[])} in retry loops, and the
 * call counter is an {@link AtomicLong}; presumably this is meant to let several
 * threads train the same network concurrently (confirm against {@code Neuron}).
 */
public class KohonenUpdateAction
implements UpdateAction {
    /** Distance function used to select the best-matching neuron. */
    private final DistanceMeasure distance;
    /** Learning rate, as a function of the number of previous calls. */
    private final LearningFactorFunction learningFactor;
    /** Neighbourhood size (number of rings), as a function of the number of previous calls. */
    private final NeighbourhoodSizeFunction neighbourhoodSize;
    /** Number of calls made to {@link #update(Network,double[])}. */
    private final AtomicLong numberOfCalls = new AtomicLong(0L);

    /**
     * Creates an update action.
     *
     * @param distance Distance used to find the best-matching neuron.
     * @param learningFactor Learning rate schedule.
     * @param neighbourhoodSize Neighbourhood size schedule.
     */
    public KohonenUpdateAction(DistanceMeasure distance,
                               LearningFactorFunction learningFactor,
                               NeighbourhoodSizeFunction neighbourhoodSize) {
        this.distance = distance;
        this.learningFactor = learningFactor;
        this.neighbourhoodSize = neighbourhoodSize;
    }

    /**
     * Performs one training step of the map with the given sample.
     *
     * @param net Network to train.
     * @param features Training sample.
     */
    public void update(Network net, double[] features) {
        // Counter value before this call: the first call uses 0.
        final long numCalls = numberOfCalls.incrementAndGet() - 1L;
        final double currentLearning = learningFactor.value(numCalls);
        final Neuron best = findAndUpdateBestNeuron(net, features, currentLearning);

        final int currentNeighbourhood = neighbourhoodSize.value(numCalls);
        if (currentNeighbourhood <= 0) {
            // Only the winner is trained. The Gaussian must not be created in this
            // case: its constructor rejects a non-positive standard deviation, which
            // would abort the step after the winner had already been modified.
            return;
        }

        // The farther a neighbour is from the winner (in rings), the smaller its
        // learning rate.
        final Gaussian neighbourhoodDecay =
            new Gaussian(currentLearning, 0.0, currentNeighbourhood);

        // Ring 0 contains only the winner, which has already been updated.
        Collection<Neuron> ring = new HashSet<Neuron>();
        ring.add(best);
        // Neurons already handled during this step; excluded so that no neuron
        // is updated more than once.
        final Collection<Neuron> visited = new HashSet<Neuron>();
        visited.add(best);

        for (int radius = 1; radius <= currentNeighbourhood; radius++) {
            // Immediate neighbours of the previous ring, minus anything already visited.
            ring = net.getNeighbours(ring, visited);
            final double rate = neighbourhoodDecay.value(radius);
            for (Neuron neighbour : ring) {
                updateNeighbouringNeuron(neighbour, features, rate);
            }
            visited.addAll(ring);
        }
    }

    /**
     * Retrieves the number of calls to the {@link #update(Network,double[]) update}
     * method.
     *
     * @return the current number of calls.
     */
    public long getNumberOfCalls() {
        return numberOfCalls.get();
    }

    /**
     * Tries once to move a neuron's features towards the sample.
     *
     * @param neuron Neuron to update.
     * @param features Training sample.
     * @param learningRate Learning rate.
     * @return {@code true} if {@code compareAndSetFeatures} accepted the new
     * features, {@code false} if the neuron's features were not the ones read.
     */
    private boolean attemptNeuronUpdate(Neuron neuron,
                                        double[] features,
                                        double learningRate) {
        final double[] expect = neuron.getFeatures();
        final double[] update = computeFeatures(expect, features, learningRate);
        return neuron.compareAndSetFeatures(expect, update);
    }

    /**
     * Moves a neighbouring neuron towards the sample, retrying until the
     * compare-and-set succeeds.
     *
     * @param neuron Neuron to update.
     * @param features Training sample.
     * @param learningRate Learning rate.
     */
    private void updateNeighbouringNeuron(Neuron neuron,
                                          double[] features,
                                          double learningRate) {
        while (!attemptNeuronUpdate(neuron, features, learningRate)) {
            // Features changed between read and write: recompute from the new
            // values and try again.
        }
    }

    /**
     * Searches for the neuron closest to the sample and moves it towards the
     * sample.
     *
     * @param net Network.
     * @param features Training sample.
     * @param learningRate Learning rate.
     * @return the winning neuron (as found by the successful attempt).
     */
    private Neuron findAndUpdateBestNeuron(Network net,
                                           double[] features,
                                           double learningRate) {
        while (true) {
            final Neuron best = MapUtils.findBest(features, net, distance);
            if (attemptNeuronUpdate(best, features, learningRate)) {
                return best;
            }
            // The winner's features changed between the search and the update,
            // so it may no longer be the best match: search again.
        }
    }

    /**
     * Computes the new value of the features:
     * {@code current + learningRate * (sample - current)}.
     *
     * @param current Current values of the features.
     * @param sample Training sample.
     * @param learningRate Learning rate.
     * @return the new values for the features.
     */
    private double[] computeFeatures(double[] current,
                                     double[] sample,
                                     double learningRate) {
        // Wrap without copying; subtract() and add() return new vectors, and
        // mapMultiplyToSelf only mutates that freshly created difference vector.
        final ArrayRealVector c = new ArrayRealVector(current, false);
        final ArrayRealVector s = new ArrayRealVector(sample, false);
        return s.subtract(c).mapMultiplyToSelf(learningRate).add(c).toArray();
    }
}

