package edu.gatech.mln.feedbackSelection;

import edu.gatech.mln.CardinalityConstr;
import edu.gatech.mln.GClause;
import edu.gatech.mln.MarkovLogicNetwork;
import edu.gatech.mln.infer.LazySolver;
import edu.gatech.mln.infer.LazySolverILP;
import edu.gatech.mln.infer.LazySolverLBX;
import edu.gatech.mln.infer.LazySolverLBXMCS;
import edu.gatech.mln.infer.LazySolverMCSls;
import edu.gatech.mln.infer.LazySolverMaxSAT;
import edu.gatech.mln.infer.LazySolverT;
import edu.gatech.mln.infer.LazySolverTwoStage;
import edu.gatech.mln.infer.LazySolverWalkSAT;
import edu.gatech.mln.infer.querydriven.MaxSATUtils;
import edu.gatech.mln.parser.CommandOptions;
import edu.gatech.mln.util.Config;
import edu.gatech.mln.util.UIMan;
import gnu.trove.iterator.TIntIterator;
import gnu.trove.list.array.TIntArrayList;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import java.io.File;
import java.io.FileNotFoundException;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Scanner;
import java.util.Set;
import org.apache.commons.lang3.Pair;

/* loaded from: input_file:edu/gatech/mln/feedbackSelection/AbductionPicker.class */
/**
 * Drives an iterative feedback-picking loop over the clauses produced by
 * {@link AbductionMaxSATConvertor}.
 *
 * <p>The constructor converts the command options into a {@link MarkovLogicNetwork} and selects
 * a {@link LazySolver} implementation according to {@code Config.solver}. {@link #runLoop}
 * then repeatedly picks auxiliary variables, labels the tuples they map to using an oracle file,
 * adds the resulting clauses to the converter's clause sets, and re-solves, logging report
 * counts through {@link UIMan#verbose}.
 */
public class AbductionPicker {
    private final LazySolver solver;
    private final MarkovLogicNetwork mln;
    private final AbductionMaxSATConvertor conv = new AbductionMaxSATConvertor();

    /**
     * Runs the converter on {@code commandOptions} and instantiates the configured solver.
     *
     * @param commandOptions options handed to {@link AbductionMaxSATConvertor#convert}
     */
    public AbductionPicker(CommandOptions commandOptions) {
        this.conv.convert(commandOptions);
        this.mln = this.conv.getMLN();
        this.solver = createSolver(this.mln);
    }

    /**
     * Runs the picking loop for at most {@code Config.pickBudget} iterations, or until every
     * tuple in {@code conv.spurTuples} has received feedback.
     *
     * <p>Before the loop, and again at the end of every iteration, two solves are performed and
     * logged: an "original" run and an "abduction" run (see {@link #reportOriginalRun} and
     * {@link #solveAbduction}).
     *
     * @param commandOptions supplies {@code fOracle} (atoms counted as true reports) and
     *                       {@code oracleSpurTupleFile} (tuples that receive positive feedback)
     * @throws RuntimeException wrapping {@link FileNotFoundException} if either oracle file
     *                          cannot be opened
     */
    public void runLoop(CommandOptions commandOptions) {
        TIntHashSet positiveFeedback = new TIntHashSet();
        TIntHashSet negativeFeedback = new TIntHashSet();
        // Tuples that have not received feedback yet; the loop stops once this is empty.
        TIntHashSet pendingTuples = new TIntHashSet(this.conv.spurTuples.keySet());
        // Auxiliary variables that have not been fixed by feedback yet.
        TIntHashSet openAuxVars = new TIntHashSet(this.conv.revContrAuxiVars.keySet());
        Random random = new Random(System.currentTimeMillis());

        TIntHashSet oracleReports = loadAtomIds(new File(commandOptions.fOracle));
        TIntHashSet oracleSpurTuples = loadAtomIds(new File(commandOptions.oracleSpurTupleFile));

        reportOriginalRun(openAuxVars, oracleReports);
        List<Pair<Double, Set<Integer>>> abduction = solveAbduction(openAuxVars);
        reportRun("abduction", abduction.get(0).right, oracleReports);

        int iteration = 0;
        while (iteration < Config.pickBudget && !pendingTuples.isEmpty()) {
            UIMan.verbose(0, "PICKER: Iteration " + iteration);
            iteration++;

            TIntHashSet picked = pickAuxVars(openAuxVars, abduction.get(0).right, random);
            TIntIterator pickedIt = picked.iterator();
            while (pickedIt.hasNext()) {
                int auxVar = pickedIt.next();
                int tuple = this.conv.revContrAuxiVars.get(auxVar);
                if (!pendingTuples.contains(tuple)) {
                    // Either already labelled or not a tracked tuple: nothing to record.
                    continue;
                }
                openAuxVars.remove(auxVar);
                if (oracleSpurTuples.contains(tuple)) {
                    positiveFeedback.add(tuple);
                    this.conv.transformedClauses.add(new GClause(Config.hard_weight, auxVar));
                    this.conv.transformedClausesWithoutQ.add(new GClause(Config.hard_weight, auxVar));
                    UIMan.verbose(0, "PICKER: " + this.mln.getAtom(tuple).toGroundString(this.mln) + " added as positive feedback.");
                } else {
                    negativeFeedback.add(tuple);
                    this.conv.transformedClauses.add(new GClause(Config.hard_weight, -auxVar));
                    this.conv.transformedClausesWithoutQ.add(new GClause(Config.hard_weight, -auxVar));
                    UIMan.verbose(0, "PICKER: " + this.mln.getAtom(tuple).toGroundString(this.mln) + " added as negative feedback.");
                }
            }
            pendingTuples.removeAll(negativeFeedback);
            pendingTuples.removeAll(positiveFeedback);
            UIMan.verbose(0, "PICKER: Num positive feedback: " + positiveFeedback.size() + ", Num negative feedback: " + negativeFeedback.size());

            reportOriginalRun(openAuxVars, oracleReports);
            // Kept across iterations: the MAXCROSS strategy reads the latest abduction solution.
            abduction = solveAbduction(openAuxVars);
            reportRun("abduction", abduction.get(0).right, oracleReports);
        }
    }

    /** Returns the solver implementation named by {@code Config.solver}; MaxSAT is the fallback. */
    private static LazySolver createSolver(MarkovLogicNetwork mln) {
        if (Config.solver.equals(Config.ILP_SOLVER)) {
            return new LazySolverILP(mln);
        }
        if (Config.solver.equals(Config.LBX_SOLVER)) {
            return new LazySolverLBX(mln);
        }
        if (Config.solver.equals(Config.MCSLS_SOLVER)) {
            return new LazySolverMCSls(mln);
        }
        if (Config.solver.equals(Config.WALK_SOLVER)) {
            return new LazySolverWalkSAT(mln);
        }
        if (Config.solver.equals(Config.TUFFY_SOLVER)) {
            return new LazySolverT(mln);
        }
        if (Config.solver.equals(Config.TWO_STAGE_SOLVER)) {
            return new LazySolverTwoStage(mln);
        }
        if (Config.solver.equals(Config.LBX_MCS_SOLVER)) {
            return new LazySolverLBXMCS(mln);
        }
        return new LazySolverMaxSAT(mln);
    }

    /**
     * Reads one atom per line from {@code file}, skipping lines that start with {@code //}, and
     * returns the MLN atom ids of the parsed atoms.
     *
     * <p>The scanner is closed even when parsing a line throws.
     *
     * @throws RuntimeException wrapping {@link FileNotFoundException} if the file cannot be opened
     */
    private TIntHashSet loadAtomIds(File file) {
        TIntHashSet atomIds = new TIntHashSet();
        try (Scanner scanner = new Scanner(file)) {
            while (scanner.hasNextLine()) {
                String line = scanner.nextLine();
                if (!line.startsWith("//")) {
                    atomIds.add(this.mln.getAtomID(this.mln.parseAtom(line)).intValue());
                }
            }
        } catch (FileNotFoundException e) {
            throw new RuntimeException(e);
        }
        return atomIds;
    }

    /**
     * Chooses the auxiliary variables to ask feedback about in this iteration, according to
     * {@code Config.pickStrategy}.
     *
     * <ul>
     *   <li>{@code RANDOM_ALL_STRAT}: draws up to {@code Config.numFeedback} distinct variables
     *       (all of them when it is {@code -1}) uniformly at random from {@code openAuxVars}.</li>
     *   <li>{@code MAXCROSS_ALL_STRAT}: every key of {@code conv.revContrAuxiVars} that does not
     *       appear in {@code lastSolution}.</li>
     *   <li>anything else: logs "Incorrect strategy" and returns an empty set.</li>
     * </ul>
     */
    private TIntHashSet pickAuxVars(TIntHashSet openAuxVars, Set<Integer> lastSolution, Random random) {
        TIntHashSet picked = new TIntHashSet();
        if (Config.pickStrategy.equals(Config.RANDOM_ALL_STRAT)) {
            TIntArrayList candidates = new TIntArrayList(openAuxVars);
            int requested = Config.numFeedback == -1 ? openAuxVars.size() : Config.numFeedback;
            int count = Math.min(requested, openAuxVars.size());
            for (int k = 0; k < count; k++) {
                // removeAt keeps the draw without replacement.
                picked.add(candidates.removeAt(random.nextInt(candidates.size())));
            }
        } else if (Config.pickStrategy.equals(Config.MAXCROSS_ALL_STRAT)) {
            TIntSet allAuxVars = this.conv.revContrAuxiVars.keySet();
            TIntIterator it = allAuxVars.iterator();
            while (it.hasNext()) {
                int auxVar = it.next();
                if (!lastSolution.contains(Integer.valueOf(auxVar))) {
                    picked.add(auxVar);
                }
            }
        } else {
            UIMan.verbose(0, "PICKER: Incorrect strategy");
        }
        return picked;
    }

    /**
     * Solves {@code conv.transformedClausesWithoutQ} plus one {@code Config.hard_weight} unit
     * clause per still-open auxiliary variable, and logs the result as the "original" run.
     */
    private void reportOriginalRun(TIntHashSet openAuxVars, TIntHashSet oracleReports) {
        Set<GClause> clauses = new HashSet<>(this.conv.transformedClausesWithoutQ);
        TIntIterator it = openAuxVars.iterator();
        while (it.hasNext()) {
            clauses.add(new GClause(Config.hard_weight, it.next()));
        }
        reportRun("original", this.solver.solve(clauses).get(0).right, oracleReports);
    }

    /**
     * Solves {@code conv.transformedClauses} for the "abduction" run.
     *
     * <p>When {@code Config.numFeedback} is not {@code -1} and is smaller than the number of open
     * auxiliary variables, an at-least constraint over those variables with bound
     * {@code openAuxVars.size() - Config.numFeedback} is added: as a {@link CardinalityConstr}
     * for the ILP solver, or as clauses from {@link MaxSATUtils#atLeast} otherwise.
     *
     * @return the solver's result list, unmodified
     */
    private List<Pair<Double, Set<Integer>>> solveAbduction(TIntHashSet openAuxVars) {
        if (Config.numFeedback == -1 || openAuxVars.size() <= Config.numFeedback) {
            return this.solver.solve(this.conv.transformedClauses);
        }
        int lowerBound = openAuxVars.size() - Config.numFeedback;
        if (Config.solver.equals(Config.ILP_SOLVER)) {
            Set<CardinalityConstr> constraints = new HashSet<>();
            constraints.add(new CardinalityConstr(CardinalityConstr.Kind.AT_LEAST, openAuxVars.toArray(), lowerBound));
            return ((LazySolverILP) this.solver).solve(this.conv.transformedClauses, constraints);
        }
        Set<GClause> clauses = new HashSet<>(this.conv.transformedClauses);
        clauses.addAll(MaxSATUtils.atLeast(openAuxVars, lowerBound));
        return this.solver.solve(clauses);
    }

    /**
     * Logs how many positive literals of {@code solution} belong to {@code conv.queries}
     * ("reports generated") and how many of those also appear in {@code oracleReports}
     * ("true reports").
     *
     * @param label run name inserted into the log line ("original" or "abduction")
     */
    private void reportRun(String label, Set<Integer> solution, TIntHashSet oracleReports) {
        TIntHashSet reports = new TIntHashSet();
        for (Integer literal : solution) {
            int value = literal.intValue();
            if (value > 0 && this.conv.queries.contains(value)) {
                reports.add(value);
            }
        }
        TIntHashSet trueReports = new TIntHashSet(reports);
        trueReports.retainAll(oracleReports);
        UIMan.verbose(0, "PICKER: For " + label + " run " + reports.size() + " reports generated. " + trueReports.size() + " are true reports.");
    }
}
