package pama1234.util.neat.raimannma.architecture;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.IntConsumer;
import java.util.function.Predicate;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import pama1234.math.UtilMath;
import pama1234.util.neat.raimannma.architecture.Connection;
import pama1234.util.neat.raimannma.architecture.Node;
import pama1234.util.neat.raimannma.methods.Utils;

/* loaded from: classes3.dex */
/**
 * Shared core of a NEAT-style network: owns the node list, the ordinary connections, the
 * self-connections and the gated connections, and implements crossover between two parents.
 *
 * <p>After construction the first {@code inputSize} entries of {@link #nodes} are input nodes,
 * followed by {@code outputSize} output nodes.
 *
 * @param <N> concrete node type
 * @param <C> concrete connection type
 */
public abstract class NetworkCore<N extends Node, C extends Connection> {
    /** Not read by this class; presumably scratch storage for subclasses (TODO confirm). */
    public float[] floatData;
    /** Number of input nodes created by the constructor. */
    public final int inputSize;
    /** Number of output nodes created by the constructor. */
    public final int outputSize;
    /** Fitness score; initially NaN, which {@link #crossover} treats as the lowest possible score. */
    public float score = Float.NaN;
    /** All nodes; list positions are copied into {@code Node.index} by {@link #setNodeIndices()}. */
    public List<N> nodes = new ArrayList<>();
    /** Connections whose endpoints are two different nodes. */
    public Set<C> connections = new HashSet<>();
    /** Connections that have been given a gate node through {@link #gate}. */
    public Set<C> gates = new HashSet<>();
    /** Connections whose source and target are the same node. */
    public Set<C> selfConnections = new HashSet<>();
    /** Not read by this class; presumably a dropout rate used by subclasses (TODO confirm). */
    protected float dropout = 0.0f;

    /** Factory that creates a fresh network with the given input and output sizes. */
    @FunctionalInterface
    public interface NewNetwork<T extends NetworkCore<?, ?>> {
        T get(int inputSize, int outputSize);
    }

    /** Factory that creates a fresh node of the requested type. */
    @FunctionalInterface
    public interface NewNode<N extends Node> {
        /** Creates a node of type {@link Node.NodeType#HIDDEN}. */
        default N get() {
            return get(Node.NodeType.HIDDEN);
        }

        N get(Node.NodeType nodeType);
    }

    /**
     * Creates {@code inputSize} input nodes followed by {@code outputSize} output nodes, then
     * calls {@link #createConnection()} to wire every input to every output.
     *
     * <p>NOTE(review): this constructor invokes the overridable {@link #newNode} and
     * {@link #createConnection()}; subclass fields are not yet initialised when they run.
     *
     * @param inputSize number of input nodes
     * @param outputSize number of output nodes
     */
    public NetworkCore(int inputSize, int outputSize) {
        this.inputSize = inputSize;
        this.outputSize = outputSize;
        for (int i = 0; i < inputSize; i++) {
            nodes.add(newNode(Node.NodeType.INPUT));
        }
        for (int i = 0; i < outputSize; i++) {
            nodes.add(newNode(Node.NodeType.OUTPUT));
        }
        createConnection();
    }

    /**
     * Produces an offspring network from two parents.
     *
     * <p>Node count: if {@code equal} is set or both scores are the same, it is drawn with
     * {@code Utils.randInt} between the two parents' node counts; otherwise it is the node count
     * of the higher-scoring parent. A NaN score counts as {@code -Float.MAX_VALUE}.
     *
     * <p>Connection genes are keyed by the innovation id of their endpoint indices. A gene
     * present in both parents is taken from either at random. A gene only parent 1 has is kept
     * when parent 1 scores at least as well as parent 2 (or {@code equal}); symmetrically for
     * parent 2. Genes pointing at node indices outside the offspring's node range are dropped.
     *
     * <p>Side effect: {@link #setNodeIndices()} is called on both parents.
     *
     * @param parent1 first parent
     * @param parent2 second parent
     * @param equal treat both parents as equally fit regardless of their scores
     * @param newNetwork factory for the offspring network
     * @param newNode factory for the offspring's nodes
     * @return the offspring network
     * @throws IllegalStateException if the parents' input or output sizes differ
     */
    public static <N extends Node, C extends Connection, T extends NetworkCore<N, C>> T crossover(
            NetworkCore<N, C> parent1, NetworkCore<N, C> parent2, boolean equal,
            NewNetwork<T> newNetwork, NewNode<N> newNode) {
        if (parent1.inputSize != parent2.inputSize || parent1.outputSize != parent2.outputSize) {
            throw new IllegalStateException("Networks don't have the same input/output size!");
        }
        final T offspring = newNetwork.get(parent1.inputSize, parent1.outputSize);
        // The offspring is rebuilt entirely from inherited genes, so drop everything the
        // factory may have created (the base constructor wires inputs to outputs).
        offspring.connections.clear();
        offspring.selfConnections.clear();
        offspring.gates.clear();
        offspring.nodes.clear();

        final float score1 = Float.isNaN(parent1.score) ? -Float.MAX_VALUE : parent1.score;
        final float score2 = Float.isNaN(parent2.score) ? -Float.MAX_VALUE : parent2.score;
        final int size1 = parent1.nodes.size();
        final int size2 = parent2.nodes.size();
        final int offspringSize;
        if (equal || score1 == score2) {
            offspringSize = Utils.randInt(Math.min(size1, size2), Math.max(size1, size2));
        } else {
            offspringSize = score1 > score2 ? size1 : size2;
        }

        // Innovation ids below are derived from Node.index, so the parents must be indexed.
        parent1.setNodeIndices();
        parent2.setNodeIndices();

        for (int i = 0; i < offspringSize; i++) {
            final N template = pickNodeTemplate(parent1, parent2, i, offspringSize);
            final N child = newNode.get();
            child.bias = template.bias;
            child.activationType = template.activationType;
            child.type = template.type;
            offspring.nodes.add(child);
        }

        final Map<Integer, int[]> genes1 = makeConnections(parent1);
        final Map<Integer, int[]> genes2 = makeConnections(parent2);
        final List<int[]> inherited = new ArrayList<>();
        for (Map.Entry<Integer, int[]> entry : genes1.entrySet()) {
            // Removing consumes matching genes so only parent 2's unmatched ones remain below.
            final int[] matching = genes2.remove(entry.getKey());
            if (matching != null) {
                inherited.add(Utils.randBoolean() ? entry.getValue() : matching);
            } else if (score1 >= score2 || equal) {
                inherited.add(entry.getValue());
            }
        }
        if (score2 >= score1 || equal) {
            for (int[] data : genes2.values()) {
                if (data != null) {
                    inherited.add(data);
                }
            }
        }

        for (int[] data : inherited) {
            addInheritedConnection(offspring, data, offspringSize);
        }
        offspring.setNodeIndices();
        return offspring;
    }

    /**
     * Chooses which parent node the offspring node at {@code index} copies its properties from.
     *
     * <p>Non-output slots pick a random parent, falling back to the other parent when the chosen
     * one has no node there or its node is an output. The last {@code outputSize} slots are
     * aligned with the end of a randomly chosen parent's node list. NOTE(review): this assumes
     * output nodes stay at the tail of {@link #nodes}, as the constructor places them.
     */
    private static <N extends Node, C extends Connection> N pickNodeTemplate(
            NetworkCore<N, C> parent1, NetworkCore<N, C> parent2, int index, int offspringSize) {
        final List<N> nodes1 = parent1.nodes;
        final List<N> nodes2 = parent2.nodes;
        if (index < offspringSize - parent1.outputSize) {
            if (Utils.randBoolean()) {
                final boolean unusable =
                        index >= nodes1.size() || nodes1.get(index).type == Node.NodeType.OUTPUT;
                return unusable ? nodes2.get(index) : nodes1.get(index);
            }
            final boolean unusable =
                    index >= nodes2.size() || nodes2.get(index).type == Node.NodeType.OUTPUT;
            return unusable ? nodes1.get(index) : nodes2.get(index);
        }
        return Utils.randBoolean()
                ? nodes1.get(nodes1.size() - offspringSize + index)
                : nodes2.get(nodes2.size() - offspringSize + index);
    }

    /**
     * Recreates one inherited connection gene inside {@code network}.
     *
     * <p>The gene is skipped if its source or target index falls outside {@code [0, nodeCount)}.
     * The gate is applied only if the gate index is within the same range. NOTE(review):
     * {@code ConnectionData.gain(...)} is used as the source-node index here, as in the original
     * code; confirm that accessor name against {@code ConnectionData}.
     */
    private static <N extends Node, C extends Connection> void addInheritedConnection(
            NetworkCore<N, C> network, int[] data, int nodeCount) {
        final int fromIndex = ConnectionData.gain(data);
        if (fromIndex < 0 || fromIndex >= nodeCount) {
            return;
        }
        final int toIndex = ConnectionData.toIndex(data);
        if (toIndex < 0 || toIndex >= nodeCount) {
            return;
        }
        final C connection = network.connect(network.nodes.get(fromIndex), network.nodes.get(toIndex));
        connection.weight = ConnectionData.weight(data);
        final int gateIndex = ConnectionData.gateNodeIndex(data);
        if (gateIndex >= 0 && gateIndex < nodeCount) {
            network.gate(network.nodes.get(gateIndex), connection);
        }
    }

    /**
     * Maps innovation id to serialized connection data for every ordinary connection and every
     * self-connection of {@code network}. Ids are computed from {@code Node.index}, so the
     * network's node indices must be current. Later entries with the same id overwrite earlier
     * ones.
     */
    private static <N extends Node, C extends Connection> Map<Integer, int[]> makeConnections(
            NetworkCore<N, C> network) {
        final Map<Integer, int[]> genes = new HashMap<>();
        Stream.concat(network.connections.stream(), network.selfConnections.stream())
                .forEach(connection -> genes.put(
                        Connection.getInnovationID(connection.from.index, connection.to.index),
                        connection.getConnectionDataAsIntArray()));
        return genes;
    }

    /**
     * Connects {@code from} to {@code to} with weight 0.
     *
     * @return the newly created connection
     */
    public C connect(N from, N to) {
        return connect(from, to, 0.0f);
    }

    /**
     * Creates a connection via {@link #newConnection} and stores it in {@link #selfConnections}
     * when both endpoints are equal, otherwise in {@link #connections}.
     *
     * @return the newly created connection
     */
    public C connect(N from, N to, float weight) {
        final C connection = newConnection(from, to, weight);
        if (from.equals(to)) {
            selfConnections.add(connection);
        } else {
            connections.add(connection);
        }
        return connection;
    }

    /**
     * Connects each of the first {@code inputSize} nodes to each of the following
     * {@code outputSize} nodes, with weights from {@code Utils.randFloat(bound)}.
     *
     * <p>The bound is {@code inputSize * sqrt(2 / inputSize)}. NOTE(review): that equals
     * {@code sqrt(2 * inputSize)}; confirm it is the intended initialisation scale. This method
     * relies on the node layout produced by the constructor.
     */
    public void createConnection() {
        final float weightBound = inputSize * UtilMath.sqrt(2.0f / inputSize);
        for (int from = 0; from < inputSize; from++) {
            final N input = nodes.get(from);
            for (int to = inputSize; to < inputSize + outputSize; to++) {
                connect(input, nodes.get(to), Utils.randFloat(weightBound));
            }
        }
    }

    /**
     * Gates {@code connection} with {@code node} by delegating to {@code Node.gate} and records
     * the connection in {@link #gates}. Does nothing if the connection already has a gate node.
     *
     * @throws ArrayIndexOutOfBoundsException if {@code node} is not part of this network
     */
    public void gate(N node, C connection) {
        if (!nodes.contains(node)) {
            throw new ArrayIndexOutOfBoundsException("This node is not part of the network!");
        }
        if (connection.gateNode == null) {
            node.gate(connection);
            gates.add(connection);
        }
    }

    /** Creates a connection of the concrete type between two nodes with the given weight. */
    public abstract C newConnection(Node from, Node to, float weight);

    /** Creates a node of the concrete type with the given node type. */
    public abstract N newNode(Node.NodeType nodeType);

    /** Sets every node's {@code index} field to its position in {@link #nodes}. */
    public void setNodeIndices() {
        for (int i = 0; i < nodes.size(); i++) {
            nodes.get(i).index = i;
        }
    }
}
