package edu.gatech.mln.feedbackSelection;

import edu.gatech.mln.Atom;
import edu.gatech.mln.Clause;
import edu.gatech.mln.GClause;
import edu.gatech.mln.MarkovLogicNetwork;
import edu.gatech.mln.Predicate;
import edu.gatech.mln.db.RDB;
import edu.gatech.mln.infer.querydriven.MaxSATUtils;
import edu.gatech.mln.parser.CommandOptions;
import edu.gatech.mln.util.Config;
import edu.gatech.mln.util.FileMan;
import edu.gatech.mln.util.UIMan;
import gnu.trove.iterator.TIntIterator;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Scanner;
import java.util.Set;

/* loaded from: input_file:edu/gatech/mln/feedbackSelection/ProvenanceWeights.class */
/**
 * Computes provenance weights for the derived (non-evidence) atoms of a grounded Markov Logic
 * Network and writes several diagnostic reports about the grounding.
 *
 * <p>The constructor does all the work up front: it loads the MLN program and evidence into the
 * RDBMS described by {@link Config}, reads the grounded clauses from {@link Config#gcLoadFile},
 * builds the "parents" map (head atom -> clauses containing it positively), and writes the
 * violated-rule and tuple-distance reports. {@link #flushToFile()} then writes every head atom
 * together with its provenance weight (the sum of the weights of its parent clauses).
 *
 * <p>NOTE(review): the constructor calls overridable {@code protected} methods, so a subclass
 * overriding them runs before its own fields are initialized.
 */
public class ProvenanceWeights {
    /** Predicate whose ground atoms are the roots (distance 0) of the tuple-distance report. */
    private static final String DISTANCE_ROOT_PREDICATE = "racePairs_cs";

    private MarkovLogicNetwork mln;
    /** Writer opened on {@code options.fout}; this class itself never writes to it. */
    protected PrintWriter logOut;
    CommandOptions options;
    /** Grounded clauses loaded from {@link Config#gcLoadFile}. */
    protected Set<GClause> groundedClauses = new HashSet<>();
    /** Weight reported by {@link #getWeight(GClause)} for hard clauses. */
    protected double hardWeight = Config.hard_weight;
    /** Ids of hard evidence atoms whose truth value is true. */
    private Set<Integer> trueHardEvidence = new HashSet<>();
    /** Ids of hard evidence atoms whose truth value is false. */
    private Set<Integer> falseHardEvidence = new HashSet<>();
    /** Atom id -> (positive) prior of soft evidence with a positive prior. */
    private Map<Integer, Double> trueSoftEvidence = new HashMap<>();
    /** Atom id -> magnitude of the prior of soft evidence with a negative prior. */
    private Map<Integer, Double> falseSoftEvidence = new HashMap<>();
    /** Ids of every atom treated as an evidence (EDB) tuple. */
    private Set<Integer> evidenceTuples = new HashSet<>();
    /** Non-evidence atom id -> grounded clauses in which it occurs as a positive literal. */
    private Map<Integer, Set<GClause>> parents = new HashMap<>();

    /** Pairs an atom with its provenance weight; instances sort in ascending weight order. */
    class AtomWrapper implements Comparable<AtomWrapper> {
        private final Atom a;
        private final double weight;

        public AtomWrapper(Atom atom, double d) {
            this.a = atom;
            this.weight = d;
        }

        public Atom getAtom() {
            return this.a;
        }

        public double getWeight() {
            return this.weight;
        }

        /**
         * Orders by weight. Uses {@link Double#compare} instead of a difference truncated to
         * {@code int}, which made weights less than 1.0 apart compare as equal and violated the
         * {@link Comparable} contract (non-transitive ordering).
         */
        @Override
        public int compareTo(AtomWrapper other) {
            return Double.compare(this.weight, other.weight);
        }
    }

    /**
     * Loads the MLN described by {@code commandOptions}, reads its grounded clauses and writes the
     * violated-rule and tuple-distance reports.
     *
     * @param commandOptions parsed command line; {@code fprog}, {@code fevid} and {@code fout}
     *     are read here, {@code fOracleTuples} by {@link #generateViolatedRules()} and
     *     {@code fProvWeights} by {@link #flushToFile()}
     */
    public ProvenanceWeights(CommandOptions commandOptions) {
        this.options = commandOptions;
        Clause.mappingFromID2Const = new HashMap<>();
        Clause.mappingFromID2Desc = new HashMap<>();
        UIMan.println(">>> Connecting to RDBMS at " + Config.db_url);
        RDB db = RDB.getRDBbyConfig();
        db.resetSchema(Config.db_schema);
        // NOTE(review): generateViolatedRules() later opens a second writer on this same
        // options.fout path; confirm the two writers are not meant to coexist.
        this.logOut = FileMan.getPrintWriter(commandOptions.fout);
        this.mln = new MarkovLogicNetwork();
        this.mln.setDB(db);
        this.mln.loadPrograms(commandOptions.fprog.split(","));
        this.mln.loadEvidences(commandOptions.fevid.split(","));
        this.mln.materializeTables();
        this.mln.prepareDB(db);
        fillUniverse();          // grounded clauses + evidence sets
        generateParents();       // depends on evidenceTuples from fillUniverse()
        generateViolatedRules();
        generateTupleDistance();
    }

    public MarkovLogicNetwork getMLN() {
        return this.mln;
    }

    /**
     * Writes reports about grounded clauses whose head atom is absent from the oracle.
     *
     * <p>The oracle is the set of atoms listed in the comma-separated files of
     * {@code options.fOracleTuples}. A clause's "head" is a positive, non-evidence literal that is
     * absent from the oracle. Clauses with such a head are reported in one of two groups:
     * <ol>
     *   <li>every negative (body) atom is in the oracle and no positive evidence literal is in
     *       the oracle;</li>
     *   <li>otherwise, if the only disqualifying atoms are evidence (EDB) tuples.</li>
     * </ol>
     * Clauses are written to {@code options.fout}; head atoms and the offending EDB tuples are
     * written to {@code options.fout + "_tuples"}.
     *
     * @throws RuntimeException wrapping any I/O failure while reading or writing
     */
    protected void generateViolatedRules() {
        TIntHashSet oracle = loadOracleTuples();
        Set<GClause> fullBodyClauses = new HashSet<>();
        Set<GClause> edbGapClauses = new HashSet<>();
        Set<Integer> fullBodyHeads = new HashSet<>();
        Set<Integer> edbGapHeads = new HashSet<>();
        Set<Integer> missingEdbTuples = new HashSet<>();

        for (GClause clause : this.groundedClauses) {
            boolean allBodyInOracle = true;
            boolean onlyEdbMissing = true;
            boolean headAbsent = false;
            int head = -1;
            Set<Integer> clauseEdbTuples = new HashSet<>();
            for (int lit : clause.lits) {
                if (lit < 0) {
                    int atom = -lit;
                    if (!oracle.contains(atom)) {
                        allBodyInOracle = false;
                        if (this.evidenceTuples.contains(atom)) {
                            clauseEdbTuples.add(atom);
                        } else {
                            onlyEdbMissing = false;
                        }
                    }
                } else if (this.evidenceTuples.contains(lit)) {
                    // A positive EDB literal that holds in the oracle also rules out group 1
                    // and is reported alongside the missing EDB tuples.
                    if (oracle.contains(lit)) {
                        allBodyInOracle = false;
                        clauseEdbTuples.add(lit);
                    }
                } else if (!oracle.contains(lit)) {
                    // If several heads are absent, the last one seen is reported.
                    headAbsent = true;
                    head = lit;
                }
            }
            if (!headAbsent) {
                continue;
            }
            if (allBodyInOracle) {
                fullBodyClauses.add(clause);
                fullBodyHeads.add(head);
            } else if (onlyEdbMissing) {
                edbGapClauses.add(clause);
                edbGapHeads.add(head);
                missingEdbTuples.addAll(clauseEdbTuples);
            }
        }
        writeViolatedRuleReports(
                fullBodyClauses, edbGapClauses, fullBodyHeads, edbGapHeads, missingEdbTuples);
    }

    /** Reads the atom ids listed in {@code options.fOracleTuples}, skipping comments and blanks. */
    private TIntHashSet loadOracleTuples() {
        TIntHashSet oracle = new TIntHashSet();
        for (String path : this.options.fOracleTuples.split(",")) {
            try (Scanner scanner = new Scanner(new File(path))) {
                while (scanner.hasNextLine()) {
                    String line = scanner.nextLine();
                    if (line.startsWith("//") || line.trim().isEmpty()) {
                        continue;
                    }
                    oracle.add(this.mln.getAtomID(this.mln.parseAtom(line)).intValue());
                }
            } catch (FileNotFoundException e) {
                throw new RuntimeException(e);
            }
        }
        return oracle;
    }

    /** Writes the clause report to {@code options.fout} and the tuple report next to it. */
    private void writeViolatedRuleReports(Set<GClause> fullBodyClauses,
            Set<GClause> edbGapClauses, Set<Integer> fullBodyHeads, Set<Integer> edbGapHeads,
            Set<Integer> missingEdbTuples) {
        try {
            try (PrintWriter out = new PrintWriter(new File(this.options.fout))) {
                out.println("// Clauses with all body tuples present but head absent in oracle");
                printClauses(out, fullBodyClauses);
                out.println();
                out.println("// Clauses with all body tuples except EDB tuples present and head absent in oracle");
                printClauses(out, edbGapClauses);
                out.println();
            }
            try (PrintWriter out = new PrintWriter(new File(this.options.fout + "_tuples"))) {
                out.println("// Head tuples in clauses with all body tuples present but head absent.");
                printAtoms(out, fullBodyHeads);
                out.println();
                out.println("// Head tuples in clauses with all body tuples except EDB tuples present and head absent.");
                printAtoms(out, edbGapHeads);
                out.println();
                out.println("// Missing EDB tuples in clauses with all body tuples except EDB tuples present and head absent.");
                printAtoms(out, missingEdbTuples);
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Prints each clause's verbose form on its own line. */
    private void printClauses(PrintWriter out, Set<GClause> clauses) {
        for (GClause clause : clauses) {
            out.println(clause.toVerboseString(this.mln));
        }
    }

    /** Prints each atom's ground form on its own line. */
    private void printAtoms(PrintWriter out, Set<Integer> atomIds) {
        for (int id : atomIds) {
            out.println(this.mln.getAtom(Math.abs(id)).toGroundString(this.mln));
        }
    }

    /**
     * Fills {@link #parents}: every positive literal that is not an evidence tuple is mapped to
     * the grounded clauses containing it. With assertions enabled, fails if some clause has no
     * such literal.
     */
    protected void generateParents() {
        System.out.println("Generating Parents Map");
        for (GClause clause : this.groundedClauses) {
            boolean hasDerivedHead = false;
            for (int lit : clause.lits) {
                if (lit < 0 || this.evidenceTuples.contains(lit)) {
                    continue;
                }
                hasDerivedHead = true;
                Set<GClause> clauses = this.parents.get(lit);
                if (clauses == null) {
                    clauses = new HashSet<>();
                    this.parents.put(lit, clauses);
                }
                clauses.add(clause);
            }
            assert hasDerivedHead
                    : "grounded clause has no non-evidence positive literal: "
                            + clause.toVerboseString(this.mln);
        }
        System.out.println("Leave generating parents map");
    }

    /**
     * Writes {@code distance_tuple.txt} next to {@code options.fout}. Atoms of
     * {@value #DISTANCE_ROOT_PREDICATE} get distance 0. For every multi-literal clause, each
     * atom returned by {@code MaxSATUtils.getNegativeAts} gets at most one more than the
     * distance of any atom returned by {@code getPositiveAts}. The relaxation is repeated until
     * no distance changes. Atoms never reached are not written.
     *
     * @throws RuntimeException wrapping any I/O failure
     */
    protected void generateTupleDistance() {
        UIMan.verbose(0, "Generate tuple distance file.");
        Map<Integer, Integer> distances = new HashMap<>();
        TIntIterator atoms = MaxSATUtils.getAllAtoms(this.groundedClauses).iterator();
        while (atoms.hasNext()) {
            int atomId = atoms.next();
            if (this.mln.getAtom(atomId).pred.getName().equals(DISTANCE_ROOT_PREDICATE)) {
                distances.put(atomId, 0);
            }
        }

        boolean changed = true;
        while (changed) {
            changed = false;
            for (GClause clause : this.groundedClauses) {
                if (clause.lits.length == 1) {
                    continue;
                }
                List<Integer> heads = MaxSATUtils.getPositiveAts(clause);
                Iterator<Integer> bodies = MaxSATUtils.getNegativeAts(clause).iterator();
                while (bodies.hasNext()) {
                    int body = bodies.next().intValue();
                    for (Integer head : heads) {
                        Integer headDistance = distances.get(head);
                        if (headDistance != null) {
                            changed |= updateMapIfLess(distances, body, headDistance.intValue() + 1);
                        }
                    }
                }
            }
        }

        // File(File, String) tolerates a null parent (fout without a directory component),
        // which previously produced the bogus path "null/distance_tuple.txt".
        File distanceFile =
                new File(new File(this.options.fout).getParentFile(), "distance_tuple.txt");
        try (PrintWriter out = new PrintWriter(distanceFile)) {
            for (Map.Entry<Integer, Integer> entry : distances.entrySet()) {
                out.println(this.mln.getAtom(entry.getKey().intValue()).toGroundString(this.mln)
                        + " " + entry.getValue());
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Stores {@code value} under {@code key} if the key is absent or currently maps to a larger
     * value.
     *
     * @return {@code true} if the map was modified
     */
    private boolean updateMapIfLess(Map<Integer, Integer> map, int key, int value) {
        Integer current = map.get(key);
        if (current != null && current.intValue() <= value) {
            return false;
        }
        map.put(key, value);
        return true;
    }

    /**
     * Returns the sum of the {@code weight} fields of the clauses in which atom {@code i} occurs
     * as a non-evidence positive literal, or {@code 0.0} if it occurs in none.
     */
    protected double getAtomProvenanceWeight(int i) {
        Set<GClause> clauses = this.parents.get(i);
        if (clauses == null) {
            return 0.0d;
        }
        double total = 0.0d;
        for (GClause clause : clauses) {
            total += clause.weight;
        }
        return total;
    }

    /**
     * Writes every atom in {@link #parents} with its provenance weight to
     * {@code options.fProvWeights}, one {@code "<atom> <weight>"} line each, in ascending weight
     * order, followed by an empty line.
     *
     * @throws RuntimeException wrapping any I/O failure
     */
    public void flushToFile() {
        List<AtomWrapper> ranked = new ArrayList<>();
        for (int atomId : this.parents.keySet()) {
            ranked.add(new AtomWrapper(this.mln.getAtom(atomId), getAtomProvenanceWeight(atomId)));
        }
        Collections.sort(ranked);
        try (PrintWriter out = new PrintWriter(new File(this.options.fProvWeights))) {
            for (AtomWrapper wrapper : ranked) {
                out.print(wrapper.getAtom().toGroundString(this.mln) + " " + wrapper.getWeight() + "\n");
            }
            out.println();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Replaces {@link #groundedClauses} with the clauses read from {@link Config#gcLoadFile} and
     * then loads the evidence sets for the atoms they mention.
     *
     * @throws RuntimeException if no grounded-clause file is configured or it cannot be read
     */
    protected void fillUniverse() {
        this.groundedClauses = new HashSet<>();
        if (Config.gcLoadFile == null) {
            throw new RuntimeException("Need to implement constraints grounding without warm start!");
        }
        UIMan.println("Loading grounded clauses from file " + Config.gcLoadFile + ".");
        try (FileInputStream in = FileMan.getFileInputStream(Config.gcLoadFile)) {
            loadGroundedConstraints(in);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        loadEvidenceConstraints(MaxSATUtils.getAllAtoms(this.groundedClauses));
    }

    /** Returns {@link #hardWeight} for hard clauses and the clause's own weight otherwise. */
    public double getWeight(GClause gClause) {
        return gClause.isHardClause() ? this.hardWeight : gClause.weight;
    }

    /**
     * Populates the hard/soft evidence sets and {@link #evidenceTuples} from every predicate's
     * evidence. Then it marks as evidence each atom in {@code tIntSet} whose predicate is closed
     * world and that is not already recorded as hard or soft evidence.
     */
    private void loadEvidenceConstraints(TIntSet tIntSet) {
        Iterator<Predicate> preds = this.mln.getAllPred().iterator();
        while (preds.hasNext()) {
            Predicate pred = preds.next();
            for (Atom atom : pred.getHardEvidences()) {
                int id = this.mln.getAtomID(atom.base()).intValue();
                if (atom.truth.booleanValue()) {
                    this.trueHardEvidence.add(id);
                } else {
                    this.falseHardEvidence.add(id);
                }
                this.evidenceTuples.add(id);
            }
            for (Atom atom : pred.getSoftEvidences()) {
                int id = this.mln.getAtomID(atom.base()).intValue();
                double prior = atom.prior.doubleValue();
                if (prior > 0.0d) {
                    this.trueSoftEvidence.put(id, prior);
                    this.evidenceTuples.add(id);
                } else if (prior < 0.0d) {
                    this.falseSoftEvidence.put(id, 0.0d - prior);
                    this.evidenceTuples.add(id);
                }
                // A zero prior is not recorded as soft evidence.
            }
        }
        TIntIterator it = tIntSet.iterator();
        while (it.hasNext()) {
            int id = it.next();
            Atom atom = this.mln.getAtom(id);
            // Previously falseHardEvidence was queried with the Atom object instead of its id,
            // so that check was always false.
            boolean alreadyEvidence = this.trueHardEvidence.contains(id)
                    || this.falseHardEvidence.contains(id)
                    || this.trueSoftEvidence.containsKey(id)
                    || this.falseSoftEvidence.containsKey(id);
            if (atom.pred.isClosedWorld() && !alreadyEvidence) {
                this.evidenceTuples.add(id);
            }
        }
    }

    /**
     * Parses grounded clauses and adds those matched by
     * {@code MarkovLogicNetwork.matchGroundedClause} to {@link #groundedClauses}. Each line other
     * than {@code //} comments and blank lines has the form {@code <weight>: <lit>, <lit>, ...},
     * where {@code <weight>} is a number or {@code infi} (mapped to {@code Config.hard_weight}).
     * The stream is not closed; the caller owns it.
     */
    private void loadGroundedConstraints(InputStream inputStream) {
        BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream));
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                // Blank lines used to reach Double.parseDouble("") and crash.
                if (line.startsWith("//") || line.trim().isEmpty()) {
                    continue;
                }
                String[] parts = line.split(": ");
                double weight = parts[0].equals("infi")
                        ? Config.hard_weight
                        : Double.parseDouble(parts[0]);
                String[] literals = parts[1].split(", ");
                int[] lits = new int[literals.length];
                for (int k = 0; k < literals.length; k++) {
                    String[] tokens = literals[k].split(" ");
                    int id = this.mln.getAtomID(
                            this.mln.parseAtom(tokens[tokens.length - 1])).intValue();
                    // A literal with more than one token is negative (presumably a leading
                    // negation marker; verify against the file format).
                    lits[k] = tokens.length > 1 ? 0 - id : id;
                }
                GClause clause = this.mln.matchGroundedClause(weight, lits);
                if (clause != null) {
                    this.groundedClauses.add(clause);
                }
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }
}
