package edu.gatech.mln.feedbackSelection;

import edu.gatech.mln.Atom;
import edu.gatech.mln.Clause;
import edu.gatech.mln.GClause;
import edu.gatech.mln.MarkovLogicNetwork;
import edu.gatech.mln.Predicate;
import edu.gatech.mln.db.RDB;
import edu.gatech.mln.infer.querydriven.MaxSATUtils;
import edu.gatech.mln.parser.CommandOptions;
import edu.gatech.mln.util.Config;
import edu.gatech.mln.util.ContainerMan;
import edu.gatech.mln.util.FileMan;
import edu.gatech.mln.util.UIMan;
import gnu.trove.iterator.TIntDoubleIterator;
import gnu.trove.iterator.TIntIterator;
import gnu.trove.map.TIntDoubleMap;
import gnu.trove.map.TIntIntMap;
import gnu.trove.map.hash.TIntDoubleHashMap;
import gnu.trove.map.hash.TIntIntHashMap;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Scanner;
import java.util.Set;

/* loaded from: input_file:edu/gatech/mln/feedbackSelection/AbductionMaxSATConvertor.class */
/**
 * Converts a grounded Markov Logic Network, a set of "spurious" tuples and a set of query atoms
 * into a weighted partial MaxSAT instance (WCNF).
 *
 * <p>{@link #convert(CommandOptions)} loads the program, evidence, queries, spurious tuples and
 * grounded clauses, and builds the transformed clause sets. {@link #flushToFile(CommandOptions)}
 * then writes the WCNF instance plus a variable map.
 *
 * <p>Each spurious tuple {@code t} gets two fresh auxiliary variables: a control variable
 * {@code c_t} and a negative variable {@code n_t}. Negative occurrences of {@code t} in the
 * grounded clauses are rewritten to {@code -n_t}, and the hard clause {@code (-t v -c_t v n_t)} is
 * added. Each {@code c_t} gets a soft unit clause weighted by the tuple's weight, and each query
 * {@code q} gets a soft unit clause {@code -q}.
 *
 * <p>Not thread-safe: all state lives in mutable instance fields.
 */
public class AbductionMaxSATConvertor {
    private MarkovLogicNetwork mln;
    /** Writer for the WCNF output file ({@code commandOptions.fout}); closed by flushToFile. */
    protected PrintWriter logOut;
    /** Atom IDs of the query atoms. */
    protected TIntSet queries;
    /** Spurious tuple atom ID -> weight read from the spurious-tuple file. */
    protected TIntDoubleMap spurTuples;
    /** Transformed clauses; query clauses weigh one more than the summed tuple weights. */
    protected Set<GClause> transformedClauses;
    /** Transformed clauses; every query clause has weight 1. */
    protected Set<GClause> transformedClausesWithoutQ;
    /** Spurious tuple ID -> control variable ID. */
    protected TIntIntMap contrAuxiVars;
    /** Control variable ID -> spurious tuple ID. */
    protected TIntIntMap revContrAuxiVars;
    /** Spurious tuple ID -> negative variable ID. */
    protected TIntIntMap negAuxiVars;
    /** Negative variable ID -> spurious tuple ID. */
    protected TIntIntMap revNegAuxiVars;
    /** Grounded clauses loaded from file plus the evidence clauses derived from the MLN. */
    protected Set<GClause> groundedClauses = new HashSet<>();
    protected double hardWeight = Config.hard_weight;
    private final Set<Integer> trueHardEvidence = new HashSet<>();
    private final Set<Integer> falseHardEvidence = new HashSet<>();
    private final Map<Integer, Double> trueSoftEvidence = new HashMap<>();
    private final Map<Integer, Double> falseSoftEvidence = new HashMap<>();

    /**
     * Runs the whole pipeline: resets the static {@link Clause} ID mappings and the DB schema
     * {@code Config.db_schema}, loads the MLN program and evidence, queries, spurious tuples and
     * grounded clauses, and finally builds the transformed clause sets.
     *
     * @param commandOptions parsed command line (program, evidence, query, spurious-tuple and
     *     output file names)
     */
    public void convert(CommandOptions commandOptions) {
        Clause.mappingFromID2Const = new HashMap<>();
        Clause.mappingFromID2Desc = new HashMap<>();
        UIMan.println(">>> Connecting to RDBMS at " + Config.db_url);
        RDB db = RDB.getRDBbyConfig();
        db.resetSchema(Config.db_schema);
        this.logOut = FileMan.getPrintWriter(commandOptions.fout);
        this.mln = new MarkovLogicNetwork();
        this.mln.setDB(db);
        this.mln.loadPrograms(commandOptions.fprog.split(","));
        this.mln.loadEvidences(commandOptions.fevid.split(","));
        this.mln.materializeTables();
        this.mln.prepareDB(db);
        loadQueries(commandOptions.fquery);
        loadSpuriousTuples(commandOptions.baseSpurTupleFile);
        fillUniverse();
        transform();
    }

    /** Returns the MLN built by {@link #convert(CommandOptions)} (null before that). */
    public MarkovLogicNetwork getMLN() {
        return this.mln;
    }

    /**
     * Builds {@link #transformedClauses} and {@link #transformedClausesWithoutQ} from
     * {@link #groundedClauses}, {@link #spurTuples} and {@link #queries}, and allocates the
     * auxiliary variables.
     */
    private void transform() {
        // Fresh auxiliary IDs start above every atom that can appear in the output (grounded
        // clauses, spurious tuples used in the linking clauses, and queries) so they never collide.
        int nextVar = maxAtomId();
        this.contrAuxiVars = new TIntIntHashMap();
        this.negAuxiVars = new TIntIntHashMap();
        this.revContrAuxiVars = new TIntIntHashMap();
        this.revNegAuxiVars = new TIntIntHashMap();
        TIntDoubleIterator tupleIt = this.spurTuples.iterator();
        while (tupleIt.hasNext()) {
            tupleIt.advance();
            int tuple = tupleIt.key();
            int control = ++nextVar;
            int negative = ++nextVar;
            this.contrAuxiVars.put(tuple, control);
            this.revContrAuxiVars.put(control, tuple);
            this.negAuxiVars.put(tuple, negative);
            this.revNegAuxiVars.put(negative, tuple);
        }

        this.transformedClauses = new HashSet<>();
        this.transformedClausesWithoutQ = new HashSet<>();

        // 1. Every grounded clause is emitted as hard, whatever its original weight. Negative
        //    occurrences of a spurious tuple t are redirected to its negative variable n_t.
        for (GClause clause : this.groundedClauses) {
            int[] lits = new int[clause.lits.length];
            for (int k = 0; k < lits.length; k++) {
                int lit = clause.lits[k];
                lits[k] = (lit < 0 && this.spurTuples.containsKey(-lit))
                        ? -this.negAuxiVars.get(-lit)
                        : lit;
            }
            this.transformedClauses.add(new GClause(Config.hard_weight, lits));
            this.transformedClausesWithoutQ.add(new GClause(Config.hard_weight, lits));
        }

        // 2. Hard linking clause (-t v -c_t v n_t) for every spurious tuple t.
        tupleIt = this.spurTuples.iterator();
        while (tupleIt.hasNext()) {
            tupleIt.advance();
            int tuple = tupleIt.key();
            int[] link = {-tuple, -this.contrAuxiVars.get(tuple), this.negAuxiVars.get(tuple)};
            this.transformedClauses.add(new GClause(Config.hard_weight, link));
            this.transformedClausesWithoutQ.add(new GClause(Config.hard_weight, link));
        }

        // 3. Soft unit clause c_t weighted by the tuple's weight; accumulate the total weight.
        double tupleWeightSum = 0.0d;
        tupleIt = this.spurTuples.iterator();
        while (tupleIt.hasNext()) {
            tupleIt.advance();
            int tuple = tupleIt.key();
            double weight = tupleIt.value();
            this.transformedClauses.add(new GClause(weight, this.contrAuxiVars.get(tuple)));
            this.transformedClausesWithoutQ.add(new GClause(weight, this.contrAuxiVars.get(tuple)));
            tupleWeightSum += weight;
        }

        // 4. Soft clause -q for every query. In transformedClauses its weight is one above the
        //    total tuple weight (presumably so each query dominates all control clauses; this
        //    assumes non-negative tuple weights). transformedClausesWithoutQ uses weight 1.
        double queryWeight = tupleWeightSum + 1.0d;
        TIntIterator queryIt = this.queries.iterator();
        while (queryIt.hasNext()) {
            int query = queryIt.next();
            this.transformedClauses.add(new GClause(queryWeight, -query));
            this.transformedClausesWithoutQ.add(new GClause(1.0d, -query));
        }
    }

    /** Returns the largest atom ID among grounded clauses, spurious tuples and queries (0 if none). */
    private int maxAtomId() {
        int max = 0;
        TIntIterator atomIt = MaxSATUtils.getAllAtoms(this.groundedClauses).iterator();
        while (atomIt.hasNext()) {
            max = Math.max(max, atomIt.next());
        }
        TIntDoubleIterator tupleIt = this.spurTuples.iterator();
        while (tupleIt.hasNext()) {
            tupleIt.advance();
            max = Math.max(max, tupleIt.key());
        }
        TIntIterator queryIt = this.queries.iterator();
        while (queryIt.hasNext()) {
            max = Math.max(max, queryIt.next());
        }
        return max;
    }

    /**
     * Writes {@link #transformedClauses} in WCNF format to {@link #logOut} and then closes it.
     * Also writes a variable map to {@code commandOptions.fout + "_map"}.
     *
     * <p>Hard clauses receive the "top" weight, which is the sum of the absolute soft weights
     * plus 1. If that sum overflows, {@code Config.hard_weight} is used instead. Soft weights are
     * printed truncated to {@code int}.
     *
     * @param commandOptions supplies the output file name {@code fout}
     */
    public void flushToFile(CommandOptions commandOptions) {
        this.logOut.println("c");
        this.logOut.println("c Weighted Partial Max-SAT generated from user-guided program analysis.");
        this.logOut.println("c");
        int clauseCount = this.transformedClauses.size();
        long softSum = 0;
        for (GClause clause : this.transformedClauses) {
            if (!clause.isHardClause()) {
                softSum = (long) (softSum + Math.abs(clause.weight));
            }
        }
        TIntSet allAtoms = MaxSATUtils.getAllAtoms(this.transformedClauses);
        // The "p wcnf" header must carry the largest variable index, not a (reduced) count.
        int maxVar = 0;
        TIntIterator varIt = allAtoms.iterator();
        while (varIt.hasNext()) {
            maxVar = Math.max(maxVar, Math.abs(varIt.next()));
        }
        long top = softSum + 1;
        if (top < 0) {
            UIMan.verbose(1, "The sum of weights of soft constraints overflows, set it to " + Config.hard_weight);
            top = (long) Config.hard_weight;
        }
        this.logOut.println("p wcnf " + maxVar + " " + clauseCount + " " + top);
        for (GClause clause : this.transformedClauses) {
            if (clause.isHardClause()) {
                this.logOut.print(top);
            } else {
                this.logOut.print((int) clause.weight);
            }
            for (int lit : clause.lits) {
                this.logOut.print(" " + lit);
            }
            this.logOut.println(" 0");
        }
        this.logOut.flush();
        this.logOut.close();

        try (PrintWriter mapOut = FileMan.getPrintWriter(String.valueOf(commandOptions.fout) + "_map")) {
            TIntIterator atomIt = allAtoms.iterator();
            while (atomIt.hasNext()) {
                int atomId = atomIt.next();
                if (this.revContrAuxiVars.containsKey(atomId)) {
                    mapOut.println(String.valueOf(atomId) + " control var for "
                            + this.mln.getAtom(this.revContrAuxiVars.get(atomId)).toGroundString(this.mln));
                } else if (this.revNegAuxiVars.containsKey(atomId)) {
                    mapOut.println(String.valueOf(atomId) + " negative var for "
                            + this.mln.getAtom(this.revNegAuxiVars.get(atomId)).toGroundString(this.mln));
                } else {
                    mapOut.println(String.valueOf(atomId) + " " + this.mln.getAtom(atomId).toGroundString(this.mln));
                }
            }
            mapOut.flush();
        }
    }

    /**
     * Rebuilds {@link #groundedClauses}. It loads clauses from {@code Config.gcLoadFile}, then
     * reverted clauses from {@code Config.revLoadFile} if that is set, and finally adds the
     * evidence clauses.
     *
     * @throws RuntimeException if {@code Config.gcLoadFile} is null or a file cannot be read
     */
    protected void fillUniverse() {
        this.groundedClauses = new HashSet<>();
        if (Config.gcLoadFile == null) {
            throw new RuntimeException("Need to implement constraints grounding without warm start!");
        }
        UIMan.println("Loading grounded clauses from file " + Config.gcLoadFile + ".");
        try (FileInputStream in = FileMan.getFileInputStream(Config.gcLoadFile)) {
            loadGroundedConstraints(in);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        if (Config.revLoadFile != null) {
            UIMan.println("Loading grounded reverted constrains from file " + Config.revLoadFile + ".");
            try (FileInputStream in = FileMan.getFileInputStream(Config.revLoadFile)) {
                loadRevertedConstraints(in);
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }
        loadEvidenceConstraints(MaxSATUtils.getAllAtoms(this.groundedClauses));
    }

    /** Sets {@link #hardWeight} to the sum of the soft grounded-clause weights plus 1. */
    public void calculateHardWeight() {
        double softSum = 0.0d;
        for (GClause clause : this.groundedClauses) {
            if (!clause.isHardClause()) {
                softSum += clause.weight;
            }
        }
        this.hardWeight = softSum + 1.0d;
    }

    /** Returns {@link #hardWeight} for hard clauses, otherwise the clause's own weight. */
    public double getWeight(GClause gClause) {
        return gClause.isHardClause() ? this.hardWeight : gClause.weight;
    }

    /**
     * Reads query atoms, one per line, into {@link #queries}. Blank lines and lines starting
     * with {@code //} are skipped.
     *
     * @param str path of the query file
     * @throws RuntimeException wrapping {@link FileNotFoundException} if the file is missing
     */
    public void loadQueries(String str) {
        this.queries = new TIntHashSet();
        try (Scanner scanner = new Scanner(new File(str))) {
            while (scanner.hasNextLine()) {
                String line = scanner.nextLine();
                if (isCommentOrBlank(line)) {
                    continue;
                }
                this.queries.add(this.mln.getAtomID(this.mln.parseAtom(line.trim())).intValue());
            }
        } catch (FileNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Reads {@code <atom> <weight>} lines into {@link #spurTuples}. Blank lines and lines starting
     * with {@code //} are skipped.
     *
     * @param str path of the spurious-tuple file
     * @throws IllegalArgumentException if a line has no weight column
     */
    private void loadSpuriousTuples(String str) {
        this.spurTuples = new TIntDoubleHashMap();
        try (Scanner scanner = new Scanner(new File(str))) {
            while (scanner.hasNextLine()) {
                String line = scanner.nextLine();
                if (isCommentOrBlank(line)) {
                    continue;
                }
                String[] fields = line.trim().split("\\s+");
                if (fields.length < 2) {
                    throw new IllegalArgumentException(
                            "Malformed spurious tuple line (expected '<atom> <weight>'): " + line);
                }
                Atom atom = this.mln.parseAtom(fields[0]);
                this.spurTuples.put(this.mln.getAtomID(atom).intValue(), Double.parseDouble(fields[1]));
            }
        } catch (FileNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Adds evidence as unit clauses to {@link #groundedClauses}.
     *
     * <ul>
     *   <li>Hard evidence becomes a hard unit clause.</li>
     *   <li>Soft evidence becomes a soft unit clause whose weight is |prior|, with the polarity
     *       taken from the sign of the prior.</li>
     *   <li>Every closed-world atom in {@code tIntSet} that has no evidence gets a hard negative
     *       unit clause.</li>
     * </ul>
     *
     * @param tIntSet atom IDs appearing in the grounded clauses
     */
    private void loadEvidenceConstraints(TIntSet tIntSet) {
        Iterator<Predicate> predIt = this.mln.getAllPred().iterator();
        while (predIt.hasNext()) {
            Predicate pred = predIt.next();
            for (Atom atom : pred.getHardEvidences()) {
                int id = this.mln.getAtomID(atom.base()).intValue();
                if (atom.truth.booleanValue()) {
                    this.trueHardEvidence.add(id);
                    this.groundedClauses.add(new GClause(Config.hard_weight, id));
                } else {
                    this.falseHardEvidence.add(id);
                    this.groundedClauses.add(new GClause(Config.hard_weight, -id));
                }
            }
            for (Atom atom : pred.getSoftEvidences()) {
                int id = this.mln.getAtomID(atom.base()).intValue();
                double prior = atom.prior.doubleValue();
                if (prior > 0.0d) {
                    this.trueSoftEvidence.put(id, prior);
                    this.groundedClauses.add(new GClause(prior, id));
                } else if (prior < 0.0d) {
                    this.falseSoftEvidence.put(id, 0.0d - prior);
                    this.groundedClauses.add(new GClause(0.0d - prior, -id));
                }
            }
        }
        TIntIterator atomIt = tIntSet.iterator();
        while (atomIt.hasNext()) {
            int id = atomIt.next();
            Atom atom = this.mln.getAtom(id);
            // The evidence sets are keyed by atom ID; the lookup must use the ID, not the Atom.
            boolean hasEvidence = this.trueHardEvidence.contains(id)
                    || this.falseHardEvidence.contains(id)
                    || this.trueSoftEvidence.containsKey(id)
                    || this.falseSoftEvidence.containsKey(id);
            if (atom.pred.isClosedWorld() && !hasEvidence) {
                this.groundedClauses.add(new GClause(Config.hard_weight, -id));
            }
        }
    }

    /**
     * Reads grounded clause lines ({@code <weight|infi>: <lit>, <lit>, ...}). Each parsed clause
     * is kept only if {@code mln.matchGroundedClause} returns a non-null clause. Blank lines and
     * {@code //} lines are skipped. The caller owns and closes {@code inputStream}.
     */
    private void loadGroundedConstraints(InputStream inputStream) {
        try {
            BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream));
            String line;
            while ((line = reader.readLine()) != null) {
                if (isCommentOrBlank(line)) {
                    continue;
                }
                String[] parts = splitClauseLine(line);
                GClause matched = this.mln.matchGroundedClause(parseClauseWeight(parts[0]), parseClauseLiterals(parts[1]));
                if (matched != null) {
                    this.groundedClauses.add(matched);
                }
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Reads reverted clause lines, which use the same format as the grounded clauses, and adds
     * them directly to {@link #groundedClauses}. Blank lines and {@code //} lines are skipped.
     * The caller owns and closes {@code inputStream}.
     */
    public void loadRevertedConstraints(InputStream inputStream) {
        try {
            BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream));
            String line;
            while ((line = reader.readLine()) != null) {
                if (isCommentOrBlank(line)) {
                    continue;
                }
                String[] parts = splitClauseLine(line);
                this.groundedClauses.add(new GClause(parseClauseWeight(parts[0]), parseClauseLiterals(parts[1])));
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** True for blank lines and lines whose first non-blank characters are {@code //}. */
    private static boolean isCommentOrBlank(String line) {
        String trimmed = line.trim();
        return trimmed.isEmpty() || trimmed.startsWith("//");
    }

    /** Splits a clause line on {@code ": "}; fails clearly if the literal part is missing. */
    private static String[] splitClauseLine(String line) {
        String[] parts = line.split(": ");
        if (parts.length < 2) {
            throw new IllegalArgumentException("Malformed clause line (expected '<weight>: <literals>'): " + line);
        }
        return parts;
    }

    /** Parses a clause weight; {@code "infi"} maps to {@code Config.hard_weight}. */
    private static double parseClauseWeight(String field) {
        return field.equals("infi") ? Config.hard_weight : Double.parseDouble(field);
    }

    /**
     * Parses {@code ", "}-separated literals into signed atom IDs. A literal written as more than
     * one space-separated token (presumably a negation marker followed by the atom) is negated.
     * The atom is always the last token.
     */
    private int[] parseClauseLiterals(String field) {
        String[] tokens = field.split(", ");
        int[] lits = new int[tokens.length];
        for (int k = 0; k < tokens.length; k++) {
            String[] words = tokens[k].split(" ");
            int id = this.mln.getAtomID(this.mln.parseAtom(words[words.length - 1])).intValue();
            lits[k] = words.length > 1 ? -id : id;
        }
        return lits;
    }

    /**
     * Simplifies {@code set} against the atoms forced by its hard unit clauses.
     *
     * <ul>
     *   <li>Clauses satisfied by a forced atom are dropped. This includes the unit clauses
     *       themselves.</li>
     *   <li>Literals falsified by a forced atom are removed from the remaining clauses.</li>
     *   <li>Clauses left with no literals are dropped silently.</li>
     * </ul>
     *
     * @return the simplified clauses, or {@code null} if some atom is forced both true and false
     */
    protected Set<GClause> removeEDB(Set<GClause> set) {
        Set<Integer> forcedTrue = new HashSet<>();
        Set<Integer> forcedFalse = new HashSet<>();
        for (GClause clause : set) {
            if (clause.isHardClause() && clause.lits.length == 1) {
                int lit = clause.lits[0];
                if (lit > 0) {
                    if (forcedFalse.contains(lit)) {
                        return null;
                    }
                    forcedTrue.add(lit);
                } else {
                    if (forcedTrue.contains(-lit)) {
                        return null;
                    }
                    forcedFalse.add(-lit);
                }
            }
        }
        Set<GClause> simplified = new HashSet<>();
        for (GClause clause : set) {
            ArrayList<Integer> kept = new ArrayList<>();
            boolean satisfied = false;
            for (int lit : clause.lits) {
                int atom = Math.abs(lit);
                boolean litTrue = lit > 0 ? forcedTrue.contains(atom) : forcedFalse.contains(atom);
                if (litTrue) {
                    satisfied = true;
                    break;
                }
                boolean litFalse = lit > 0 ? forcedFalse.contains(atom) : forcedTrue.contains(atom);
                if (!litFalse) {
                    kept.add(lit);
                }
            }
            if (!satisfied && !kept.isEmpty()) {
                simplified.add(new GClause(clause.weight, ContainerMan.convertIntegers(kept)));
            }
        }
        return simplified;
    }

    /** Returns the transformed clauses built by {@link #convert(CommandOptions)}. */
    public Set<GClause> getTransformedClauses() {
        return this.transformedClauses;
    }

    /** Returns the spurious tuple ID -> control variable ID map. */
    public TIntIntMap getContrAuxiVars() {
        return this.contrAuxiVars;
    }

    /** Returns the control variable ID -> spurious tuple ID map. */
    public TIntIntMap getRevContrAuxiVars() {
        return this.revContrAuxiVars;
    }

    /** Returns the spurious tuple ID -> negative variable ID map. */
    public TIntIntMap getNegAuxiVars() {
        return this.negAuxiVars;
    }

    /** Returns the negative variable ID -> spurious tuple ID map. */
    public TIntIntMap getRevNegAuxiVars() {
        return this.revNegAuxiVars;
    }
}
