package edu.gatech.mln.feedbackSelection;

import edu.gatech.mln.GClause;
import edu.gatech.mln.MarkovLogicNetwork;
import edu.gatech.mln.infer.LazySolver;
import edu.gatech.mln.infer.LazySolverILP;
import edu.gatech.mln.infer.LazySolverLBX;
import edu.gatech.mln.infer.LazySolverLBXMCS;
import edu.gatech.mln.infer.LazySolverMCSls;
import edu.gatech.mln.infer.LazySolverMaxSAT;
import edu.gatech.mln.infer.LazySolverParallel;
import edu.gatech.mln.infer.LazySolverT;
import edu.gatech.mln.infer.LazySolverTwoStage;
import edu.gatech.mln.infer.LazySolverWalkSAT;
import edu.gatech.mln.parser.CommandOptions;
import edu.gatech.mln.util.Config;
import edu.gatech.mln.util.NamedThreadFactory;
import edu.gatech.mln.util.UIMan;
import gnu.trove.TIntCollection;
import gnu.trove.iterator.TIntIterator;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.set.hash.TIntHashSet;
import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Scanner;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import org.apache.commons.lang3.Pair;

/* loaded from: input_file:edu/gatech/mln/feedbackSelection/FeedbackPicker.class */
/**
 * Interactive feedback-selection loop over a grounded Markov logic network.
 *
 * <p>Each iteration picks one not-yet-labelled query atom, labels it with the oracle (positive
 * if the oracle file lists it, negative otherwise), adds that label as a unit clause of weight
 * {@code Config.feedbackWeight}, and re-solves. The "reports" are the query atoms that occur as
 * positive literals in the solver's first solution.
 *
 * <p>{@code Config.pickStrategy} selects how the atom is chosen:
 * <ul>
 *   <li>random strategies draw uniformly from a candidate pool;</li>
 *   <li>"maxcross" strategies simulate feedback on every candidate in the pool and keep the one
 *       whose simulated report set differs most (symmetric-difference size) from the current
 *       report set.</li>
 * </ul>
 * In pessimistic mode the pool is the currently reported queries (simulated as negative
 * feedback); in optimistic mode it is the unreported queries (simulated as positive feedback).
 * The "oracle" strategy variants further restrict the pool using the oracle.
 */
public class FeedbackPicker {
    /** Sequential solver chosen from {@code Config.solver}; used for all non-parallel solves. */
    private final LazySolver solver;
    /** Network produced by the convertor; used to parse oracle atoms and render atoms in logs. */
    private final MarkovLogicNetwork mln;
    /** Distributed solver front-end; only created by {@link #runLoop} when {@code Config.isParallel}. */
    private LazySolverParallel master = null;
    /** Mode sentinel: "no mode forced" in the loop, "simulate both polarities" in simulation. */
    int NO_MODE = -1;
    /** Pessimistic mode: pick among currently reported queries, simulated as negative feedback. */
    int PESS_MODE = 0;
    /** Optimistic mode: pick among unreported queries, simulated as positive feedback. */
    int OPT_MODE = 1;
    // NOTE(review): not read in this class; runLoop uses Config.pessimisticRate instead.
    double PESSI_RATE = 0.75d;
    private final MaxSATConvertor conv = new MaxSATConvertor();

    /**
     * Converts the program described by {@code commandOptions} via {@link MaxSATConvertor} and
     * instantiates the solver named by {@code Config.solver}.
     *
     * @param commandOptions parsed command-line options passed to the convertor
     */
    public FeedbackPicker(CommandOptions commandOptions) {
        this.conv.convert(commandOptions);
        this.mln = this.conv.getMLN();
        this.solver = createSolver(this.mln);
    }

    /** Maps {@code Config.solver} to a solver; unrecognized names fall back to {@link LazySolverMaxSAT}. */
    private static LazySolver createSolver(MarkovLogicNetwork network) {
        if (Config.solver.equals(Config.ILP_SOLVER)) {
            return new LazySolverILP(network);
        }
        if (Config.solver.equals(Config.LBX_SOLVER)) {
            return new LazySolverLBX(network);
        }
        if (Config.solver.equals(Config.MCSLS_SOLVER)) {
            return new LazySolverMCSls(network);
        }
        if (Config.solver.equals(Config.WALK_SOLVER)) {
            return new LazySolverWalkSAT(network);
        }
        if (Config.solver.equals(Config.TUFFY_SOLVER)) {
            return new LazySolverT(network);
        }
        if (Config.solver.equals(Config.TWO_STAGE_SOLVER)) {
            return new LazySolverTwoStage(network);
        }
        if (Config.solver.equals(Config.LBX_MCS_SOLVER)) {
            return new LazySolverLBXMCS(network);
        }
        return new LazySolverMaxSAT(network);
    }

    /**
     * Runs {@code Config.pickBudget} feedback iterations.
     *
     * <p>The budget counts iterations, not labelled atoms. An iteration that finds nothing to
     * pick in its mode labels nothing, and it forces the opposite mode on the next iteration.
     *
     * @param commandOptions supplies the oracle file ({@code fOracle}) and, in parallel mode, the
     *     {@code "##"}-separated {@code host:port} worker list ({@code workerAddr})
     * @throws RuntimeException if the oracle file cannot be opened, or parallel mode is enabled
     *     without (well-formed) worker addresses
     */
    public void runLoop(CommandOptions commandOptions) {
        TIntHashSet positiveFeedback = new TIntHashSet();
        TIntHashSet negativeFeedback = new TIntHashSet();
        // Query atoms that have not been labelled yet; shrinks as feedback accumulates.
        TIntHashSet candidates = new TIntHashSet(this.conv.queries);
        Random random = new Random(System.currentTimeMillis());
        TIntHashSet oracle = readOracle(commandOptions);
        TIntHashSet reports = reportedQueries(this.solver.solve(this.conv.groundedClauses).get(0).right);
        if (Config.isParallel) {
            initParallelMaster(commandOptions);
        }

        int forcedMode = this.NO_MODE;
        for (int iteration = 0; iteration < Config.pickBudget; iteration++) {
            int mode;
            if (forcedMode != this.NO_MODE) {
                mode = forcedMode;
                forcedMode = this.NO_MODE;
            } else {
                mode = random.nextDouble() < Config.pessimisticRate ? this.PESS_MODE : this.OPT_MODE;
            }
            UIMan.verbose(0, "PICKER: Iteration " + iteration);
            UIMan.verbose(0, "PICKER: Num positive feedback: " + positiveFeedback.size() + ", Num negative feedback: " + negativeFeedback.size());
            candidates.removeAll(negativeFeedback);
            candidates.removeAll(positiveFeedback);

            int picked = pickAtom(mode, candidates, reports, oracle, random);
            if (picked == -1) {
                // Nothing to pick in this mode: try the opposite mode next iteration.
                forcedMode = mode == this.PESS_MODE ? this.OPT_MODE : this.PESS_MODE;
                continue;
            }

            if (oracle.contains(picked)) {
                positiveFeedback.add(picked);
                this.conv.groundedClauses.add(new GClause(Config.feedbackWeight, picked));
                UIMan.verbose(0, "PICKER: " + this.mln.getAtom(picked).toGroundString(this.mln) + " added as positive feedback.");
            } else {
                negativeFeedback.add(picked);
                this.conv.groundedClauses.add(new GClause(Config.feedbackWeight, -picked));
                UIMan.verbose(0, "PICKER: " + this.mln.getAtom(picked).toGroundString(this.mln) + " added as negative feedback.");
            }
            reports = reportedQueries(this.solver.solve(this.conv.groundedClauses).get(0).right);
            TIntHashSet trueReports = new TIntHashSet(reports);
            trueReports.retainAll(oracle);
            UIMan.verbose(0, "PICKER: " + reports.size() + " reports generated. " + trueReports.size() + " are true reports.");
        }
    }

    /**
     * Reads the oracle file: every line not starting with {@code //} is parsed as a ground atom,
     * and its id is added to the returned set. Atoms in this set receive positive feedback.
     */
    private TIntHashSet readOracle(CommandOptions commandOptions) {
        TIntHashSet trueAtoms = new TIntHashSet();
        try (Scanner scanner = new Scanner(new File(commandOptions.fOracle))) {
            while (scanner.hasNextLine()) {
                String line = scanner.nextLine();
                if (!line.startsWith("//")) {
                    trueAtoms.add(this.mln.getAtomID(this.mln.parseAtom(line)).intValue());
                }
            }
        } catch (FileNotFoundException e) {
            throw new RuntimeException(e);
        }
        return trueAtoms;
    }

    /** Creates the shared executor and the parallel master, and registers each {@code host:port} worker. */
    private void initParallelMaster(CommandOptions commandOptions) {
        // Validate before allocating the thread pool so a config error leaves no pool behind.
        if (commandOptions.workerAddr == null) {
            throw new RuntimeException("Specify worker addresses in parallel mode");
        }
        Config.executor = Executors.newCachedThreadPool(new NamedThreadFactory("MLN thread pool", true));
        this.master = new LazySolverParallel();
        for (String address : commandOptions.workerAddr.split("##")) {
            String[] hostPort = address.split(":");
            if (hostPort.length < 2) {
                throw new IllegalArgumentException("Malformed worker address (expected host:port): '" + address + "'");
            }
            this.master.registerWorker(hostPort[0], Integer.parseInt(hostPort[1]));
        }
    }

    /**
     * Chooses the next atom to label according to {@code Config.pickStrategy}.
     *
     * @return the chosen atom id, or -1 if the strategy's pool for {@code mode} is empty, the
     *     strategy is unknown, or (for maxcross strategies) no simulation succeeded
     */
    private int pickAtom(int mode, TIntHashSet candidates, TIntHashSet reports, TIntHashSet oracle, Random random) {
        if (Config.pickStrategy.equals(Config.RANDOM_ALL_STRAT)) {
            // Mode-agnostic: any unlabelled query atom.
            return pickRandom(candidates, random);
        }
        boolean randomAlt = Config.pickStrategy.equals(Config.RANDOM_ALT_STRAT)
                || Config.pickStrategy.equals(Config.RANDOM_ORACLE_STRAT);
        if (!randomAlt && Config.pickStrategy.equals(Config.MAXCROSS_ALL_STRAT)) {
            // Simulate both polarities over every unlabelled query atom.
            return simulate(new TIntHashSet(candidates), this.NO_MODE, reports, oracle);
        }
        if (!randomAlt
                && !Config.pickStrategy.equals(Config.MAXCROSS_ALT_STRAT)
                && !Config.pickStrategy.equals(Config.MAXCROSS_ORACLE_STRAT)) {
            UIMan.verbose(0, "PICKER: Incorrect strategy");
            return -1;
        }
        boolean useOracle = Config.pickStrategy.equals(
                randomAlt ? Config.RANDOM_ORACLE_STRAT : Config.MAXCROSS_ORACLE_STRAT);

        TIntHashSet pool;
        if (mode == this.PESS_MODE) {
            UIMan.verbose(0, "PICKER: Pessimistic Mode");
            pool = new TIntHashSet(reports);
            pool.retainAll(candidates);
            if (useOracle) {
                // Oracle variant: only reported atoms the oracle says are false.
                pool.removeAll(oracle);
            }
        } else if (mode == this.OPT_MODE) {
            UIMan.verbose(0, "PICKER: Optimistic Mode");
            pool = new TIntHashSet(candidates);
            pool.removeAll(reports);
            if (useOracle) {
                // Oracle variant: only unreported atoms the oracle says are true.
                pool.retainAll(oracle);
            }
        } else {
            return -1;
        }
        return randomAlt ? pickRandom(pool, random) : simulate(pool, mode, reports, oracle);
    }

    /** Draws one atom uniformly from {@code pool}, or returns -1 if the pool is empty. */
    private static int pickRandom(TIntHashSet pool, Random random) {
        if (pool.isEmpty()) {
            return -1;
        }
        TIntArrayList atoms = new TIntArrayList(pool);
        return atoms.get(random.nextInt(atoms.size()));
    }

    /** Dispatches to the parallel or sequential simulator according to {@code Config.isParallel}. */
    private int simulate(TIntHashSet pool, int mode, TIntHashSet reports, TIntHashSet oracle) {
        return Config.isParallel
                ? parallelSimulateFeedback(pool, mode, reports, oracle)
                : simulateFeedback(pool, mode, reports, oracle);
    }

    /**
     * Sequentially simulates feedback on each atom of {@code candidates}.
     *
     * @param candidates atoms to try
     * @param mode {@code OPT_MODE} simulates positive feedback; {@code NO_MODE} simulates
     *     negative then positive feedback for every atom; any other value simulates negative
     *     feedback
     * @param currentReports the report set the distance is measured against
     * @param oracle true atoms, used only for logging
     * @return the atom with the largest distance (the last one seen on ties), or -1 if
     *     {@code candidates} is empty
     */
    private int simulateFeedback(TIntHashSet candidates, int mode, TIntHashSet currentReports, TIntHashSet oracle) {
        int bestAtom = -1;
        int bestDistance = 0;
        int[] polarities = mode == this.NO_MODE
                ? new int[] {this.PESS_MODE, this.OPT_MODE}
                : new int[] {mode};
        for (int polarity : polarities) {
            TIntIterator it = candidates.iterator();
            while (it.hasNext()) {
                int atom = it.next();
                int literal = polarity == this.OPT_MODE ? atom : -atom;
                HashSet<GClause> clauses = new HashSet<>(this.conv.groundedClauses);
                clauses.add(new GClause(Config.feedbackWeight, literal));
                Set<Integer> solution = this.solver.solve(clauses).get(0).right;
                int distance = scoreOutcome(atom, literal, solution, currentReports, oracle);
                if (distance >= bestDistance) {
                    bestDistance = distance;
                    bestAtom = atom;
                }
            }
        }
        return bestAtom;
    }

    /**
     * Parallel counterpart of {@link #simulateFeedback}.
     *
     * <p>Candidates are submitted to the master in batches of {@code 4 * numWorkers} jobs, and
     * each batch is awaited before the next is submitted. With {@code NO_MODE}, every atom takes
     * two consecutive jobs (negative, then positive). A job that fails or exceeds the 1800 s
     * timeout is logged and skipped. If this thread is interrupted, the interrupt flag is
     * restored and the best atom found so far is returned.
     *
     * @throws IllegalStateException if there are candidates but the master reports no workers
     */
    private int parallelSimulateFeedback(TIntHashSet candidates, int mode, TIntHashSet currentReports, TIntHashSet oracle) {
        int bestAtom = -1;
        int bestDistance = 0;
        int batchSize = this.master.getNumWorkers() * 4;
        TIntArrayList atoms = new TIntArrayList(candidates);
        if (batchSize <= 0 && atoms.size() > 0) {
            // With no workers the batching loop below could never advance.
            throw new IllegalStateException("PICKER: no parallel workers registered");
        }
        int next = 0;
        // Polarity of the next job when mode == NO_MODE: each atom gets PESS, then OPT.
        int nextPolarity = this.PESS_MODE;
        while (next < atoms.size()) {
            List<Future<Pair<Double, Set<Integer>>>> futures = new ArrayList<>();
            TIntArrayList literals = new TIntArrayList();
            for (int slot = 0; slot < batchSize && next < atoms.size(); slot++) {
                int atom = atoms.get(next);
                int polarity = mode != this.NO_MODE ? mode : nextPolarity;
                int literal = polarity == this.OPT_MODE ? atom : -atom;
                HashSet<GClause> clauses = new HashSet<>(this.conv.groundedClauses);
                clauses.add(new GClause(Config.feedbackWeight, literal));
                Future<Pair<Double, Set<Integer>>> submitted;
                // NOTE(review): busy-waits until the master accepts the job (it returns null
                // otherwise, presumably while no worker is free).
                do {
                    submitted = this.master.solve(clauses);
                } while (submitted == null);
                futures.add(submitted);
                literals.add(literal);
                if (mode != this.NO_MODE) {
                    next++;
                } else if (nextPolarity == this.OPT_MODE) {
                    next++;
                    nextPolarity = this.PESS_MODE;
                } else {
                    nextPolarity = this.OPT_MODE;
                }
            }

            for (int k = 0; k < futures.size(); k++) {
                Future<Pair<Double, Set<Integer>>> pending = futures.get(k);
                int literal = literals.get(k);
                int atom = Math.abs(literal);
                try {
                    Pair<Double, Set<Integer>> outcome = pending.get(1800L, TimeUnit.SECONDS);
                    if (outcome != null) {
                        int distance = scoreOutcome(atom, literal, outcome.right, currentReports, oracle);
                        if (distance >= bestDistance) {
                            bestDistance = distance;
                            bestAtom = atom;
                        }
                    }
                } catch (TimeoutException e) {
                    // Do not leave an abandoned job running.
                    pending.cancel(true);
                    logSkipped(atom, literal);
                } catch (InterruptedException e) {
                    // Preserve the interrupt and stop waiting on the rest of the batch.
                    Thread.currentThread().interrupt();
                    logSkipped(atom, literal);
                    for (int rest = k + 1; rest < futures.size(); rest++) {
                        futures.get(rest).cancel(true);
                    }
                    return bestAtom;
                } catch (ExecutionException | RuntimeException e) {
                    logSkipped(atom, literal);
                }
            }
        }
        return bestAtom;
    }

    /**
     * Scores one simulated solution and logs it.
     *
     * @return the size of the symmetric difference between {@code currentReports} and the
     *     reports implied by {@code solution}
     */
    private int scoreOutcome(int atom, int literal, Set<Integer> solution, TIntHashSet currentReports, TIntHashSet oracle) {
        TIntHashSet reports = reportedQueries(solution);
        int distance = reportDistance(currentReports, reports);
        TIntHashSet trueReports = new TIntHashSet(reports);
        trueReports.retainAll(oracle);
        UIMan.verbose(0, "PICKER: " + reports.size() + " reports generated. " + trueReports.size() + " are true reports.");
        UIMan.verbose(0, "PICKER: " + this.mln.getAtom(atom).toGroundString(this.mln) + " as " + polarityLabel(literal) + " feedback has distance " + distance);
        return distance;
    }

    /** Logs that simulating {@code literal} on {@code atom} was skipped because of an error. */
    private void logSkipped(int atom, int literal) {
        UIMan.verbose(0, "PICKER: Error. Skipped simulating " + this.mln.getAtom(atom).toGroundString(this.mln) + " as " + polarityLabel(literal) + " feedback");
    }

    /** Log label for the feedback polarity carried by the sign of {@code literal}. */
    private static String polarityLabel(int literal) {
        return literal > 0 ? " positive" : " negative";
    }

    /** Collects the query atoms that appear as positive literals in a solver solution. */
    private TIntHashSet reportedQueries(Set<Integer> solution) {
        TIntHashSet reports = new TIntHashSet();
        for (Integer literal : solution) {
            int value = literal.intValue();
            if (value > 0 && this.conv.queries.contains(value)) {
                reports.add(value);
            }
        }
        return reports;
    }

    /** Size of the symmetric difference between two report sets. */
    private static int reportDistance(TIntHashSet before, TIntHashSet after) {
        TIntHashSet kept = new TIntHashSet(after);
        kept.retainAll(before);
        return (before.size() - kept.size()) + (after.size() - kept.size());
    }
}
