package edu.gatech.mln.infer.querydriven;

import edu.gatech.mln.Atom;
import edu.gatech.mln.Clause;
import edu.gatech.mln.GClause;
import edu.gatech.mln.MarkovLogicNetwork;
import edu.gatech.mln.Predicate;
import edu.gatech.mln.db.RDB;
import edu.gatech.mln.parser.CommandOptions;
import edu.gatech.mln.util.Config;
import edu.gatech.mln.util.FileMan;
import edu.gatech.mln.util.UIMan;
import gnu.trove.iterator.TIntIterator;
import gnu.trove.set.TIntSet;
import gnu.trove.set.hash.TIntHashSet;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Scanner;
import java.util.Set;

/* loaded from: input_file:edu/gatech/mln/infer/querydriven/QInferer.class */
/**
 * Driver for query-driven MAP inference over a Markov logic network.
 *
 * <p>{@link #infer} loads the program and evidence into the configured RDBMS, reads the query
 * atoms, loads pre-grounded clauses from {@code Config.gcLoadFile} (plus optional "reverted"
 * clauses from {@code Config.revLoadFile}), adds unit clauses for evidence and for closed-world
 * atoms that have no evidence, optionally applies a least-fixpoint bias, and then runs the
 * {@link QMaxSAT} variant selected by the command-line flags. Query atoms the solver sets to
 * true are written to the output file.
 *
 * <p>Literals are represented as atom IDs; a negative ID denotes the negated atom.
 */
public class QInferer {
    private MarkovLogicNetwork mln;
    /** Output writer for the inference result; opened and closed by {@link #infer}. */
    protected PrintWriter logOut;
    /** IDs of the query atoms whose truth values are reported. */
    protected TIntSet queries;
    /** Solver instance created by {@link #solve()}. */
    protected QMaxSAT solver;
    /** Solver-selection flags copied from the command options; see {@link #solve()}. */
    protected boolean isEager;
    protected boolean isCompo;
    protected boolean isHorn;
    /** Atom IDs read from the least-fixpoint file, or {@code null} if no such file was given. */
    private Set<Integer> lfpVars;
    /** All ground clauses handed to the solver. */
    protected Set<GClause> groundedClauses = new HashSet<>();
    /** Weight substituted for hard clauses by {@link #getWeight}; see {@link #calculateHardWeight}. */
    protected double hardWeight = Config.hard_weight;
    private final Set<Integer> trueHardEvidence = new HashSet<>();
    private final Set<Integer> falseHardEvidence = new HashSet<>();
    /** Soft evidence: atom ID mapped to the (positive) weight of its evidence unit clause. */
    private final Map<Integer, Double> trueSoftEvidence = new HashMap<>();
    private final Map<Integer, Double> falseSoftEvidence = new HashMap<>();

    /**
     * Runs the complete query-driven inference pipeline and writes the result to
     * {@code commandOptions.fout}.
     *
     * <p>The output writer is always flushed and closed, even when a step fails, so partial
     * output is not lost.
     *
     * @param commandOptions parsed command-line options (program, evidence, query, output files
     *     and solver flags)
     * @throws RuntimeException if a query atom is left unassigned by the solver, or if any input
     *     file cannot be read or parsed
     */
    public void infer(CommandOptions commandOptions) {
        Clause.mappingFromID2Const = new HashMap<>();
        Clause.mappingFromID2Desc = new HashMap<>();
        UIMan.println(">>> Connecting to RDBMS at " + Config.db_url);
        RDB rDBbyConfig = RDB.getRDBbyConfig();
        rDBbyConfig.resetSchema(Config.db_schema);
        this.logOut = FileMan.getPrintWriter(commandOptions.fout);
        try {
            this.mln = new MarkovLogicNetwork();
            this.mln.setDB(rDBbyConfig);
            this.mln.loadPrograms(commandOptions.fprog.split(","));
            this.mln.loadEvidences(commandOptions.fevid.split(","));
            this.mln.materializeTables();
            this.mln.prepareDB(rDBbyConfig);
            this.isEager = commandOptions.isQueryDrivenEager;
            this.isCompo = commandOptions.isQueryCompo;
            this.isHorn = commandOptions.isQueryHorn;
            loadQueries(commandOptions.fquery);
            fillUniverse();
            if (commandOptions.lfpFile != null) {
                enforceLfp(commandOptions.lfpFile);
            }
            if (solve()) {
                printTrueQueries();
                if (commandOptions.printViolation) {
                    printViolations();
                }
            } else {
                this.logOut.println("// 0 UNSAT");
            }
        } finally {
            this.logOut.flush();
            this.logOut.close();
        }
    }

    /**
     * Writes every query atom the solver assigned true.
     *
     * @throws RuntimeException if a query atom is in neither the solver's final true set nor its
     *     final false set
     */
    private void printTrueQueries() {
        this.logOut.println("// Queries that are set to true");
        TIntIterator it = this.queries.iterator();
        while (it.hasNext()) {
            int atomId = it.next();
            if (this.solver.getFinalTa().contains(atomId)) {
                this.logOut.println(this.mln.getAtom(atomId).toGroundString(this.mln));
            } else if (!this.solver.getFinalFa().contains(atomId)) {
                throw new RuntimeException("Query unanswered: " + this.mln.getAtom(atomId));
            }
        }
    }

    /** Writes the violated and possibly-violated ground clauses reported by the solver. */
    private void printViolations() {
        this.logOut.println("The following grounded constraints are violated: ");
        printClauses(this.solver.getViolatedConstraints().iterator());
        this.logOut.println("\nThe following grounded constraints could be possibility violated: ");
        printClauses(this.solver.getPossiViolatedonstraints().iterator());
    }

    /** Writes each clause in verbose form, one per line. */
    private void printClauses(Iterator<GClause> clauses) {
        while (clauses.hasNext()) {
            this.logOut.println(clauses.next().toVerboseString(this.mln));
        }
    }

    /**
     * Applies a least-fixpoint bias using the atoms listed in {@code str}.
     *
     * <p>Every existing clause weight is multiplied by (number of listed atoms + 1). Then a
     * weight-1 unit clause preferring each listed atom to be false is added. Presumably the
     * scaling is meant to make these new preferences act only as tie-breakers (this holds when
     * the original weights are at least 1). TODO confirm with the authors. The same factor is
     * also published to {@code HornQGrounder.cutOffValue}.
     *
     * @param str path of a file with one ground atom per line; lines starting with {@code //} and
     *     blank lines are ignored
     */
    private void enforceLfp(String str) {
        this.lfpVars = new HashSet<>();
        try (Scanner scanner = new Scanner(new File(str))) {
            while (scanner.hasNextLine()) {
                String line = scanner.nextLine();
                if (isDataLine(line)) {
                    this.lfpVars.add(atomIdOf(line));
                }
            }
        } catch (FileNotFoundException e) {
            throw new RuntimeException(e);
        }
        double size = this.lfpVars.size() + 1;
        HornQGrounder.cutOffValue = size;
        for (GClause clause : this.groundedClauses) {
            clause.weight *= size;
        }
        for (int atomId : this.lfpVars) {
            this.groundedClauses.add(new GClause(1.0d, -atomId));
        }
    }

    /**
     * Rebuilds {@link #groundedClauses} from the warm-start files named in {@link Config}, then
     * adds the evidence and closed-world unit clauses for every atom they mention.
     *
     * @throws RuntimeException if {@code Config.gcLoadFile} is unset (grounding from scratch is
     *     not implemented) or a file cannot be read
     */
    protected void fillUniverse() {
        this.groundedClauses = new HashSet<>();
        if (Config.gcLoadFile == null) {
            throw new RuntimeException("Need to implement constraints grounding without warm start!");
        }
        UIMan.println("Loading grounded clauses from file " + Config.gcLoadFile + ".");
        try {
            try (FileInputStream in = FileMan.getFileInputStream(Config.gcLoadFile)) {
                loadGroundedConstraints(in);
            }
            if (Config.revLoadFile != null) {
                UIMan.println("Loading grounded reverted constrains from file " + Config.revLoadFile + ".");
                try (FileInputStream in = FileMan.getFileInputStream(Config.revLoadFile)) {
                    loadRevertedConstraints(in);
                }
            }
            loadEvidenceConstraints(MaxSATUtils.getAllAtoms(this.groundedClauses));
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Sets {@link #hardWeight} to one more than the total weight of all soft clauses. Under this
     * weight a single violated hard clause costs more than violating every soft clause.
     */
    public void calculateHardWeight() {
        double softTotal = 0.0d;
        for (GClause gClause : this.groundedClauses) {
            if (!gClause.isHardClause()) {
                softTotal += gClause.weight;
            }
        }
        this.hardWeight = softTotal + 1.0d;
    }

    /** Returns {@link #hardWeight} for hard clauses and the clause's own weight otherwise. */
    public double getWeight(GClause gClause) {
        return gClause.isHardClause() ? this.hardWeight : gClause.weight;
    }

    /**
     * Replaces {@link #queries} with the atoms listed in the given file.
     *
     * @param str path of a file with one ground atom per line; lines starting with {@code //} and
     *     blank lines are ignored
     * @throws RuntimeException if the file does not exist
     * @throws IllegalArgumentException if a listed atom has no ID in the network
     */
    public void loadQueries(String str) {
        this.queries = new TIntHashSet();
        try (Scanner scanner = new Scanner(new File(str))) {
            while (scanner.hasNextLine()) {
                String line = scanner.nextLine();
                if (isDataLine(line)) {
                    this.queries.add(atomIdOf(line));
                }
            }
        } catch (FileNotFoundException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Adds unit clauses for all hard and soft evidence in the program. It also adds a hard
     * negative unit clause for every closed-world atom in {@code tIntSet} that has no evidence
     * of any kind.
     *
     * @param tIntSet IDs of the atoms mentioned by the loaded ground clauses
     */
    private void loadEvidenceConstraints(TIntSet tIntSet) {
        Iterator<Predicate> preds = this.mln.getAllPred().iterator();
        while (preds.hasNext()) {
            Predicate pred = preds.next();
            for (Atom atom : pred.getHardEvidences()) {
                int atomId = this.mln.getAtomID(atom.base()).intValue();
                if (atom.truth.booleanValue()) {
                    this.trueHardEvidence.add(atomId);
                    this.groundedClauses.add(new GClause(Config.hard_weight, atomId));
                } else {
                    this.falseHardEvidence.add(atomId);
                    this.groundedClauses.add(new GClause(Config.hard_weight, -atomId));
                }
            }
            for (Atom atom : pred.getSoftEvidences()) {
                int atomId = this.mln.getAtomID(atom.base()).intValue();
                double prior = atom.prior.doubleValue();
                // The sign of the prior selects the literal; its magnitude becomes the weight.
                // A prior of exactly 0 adds nothing.
                if (prior > 0.0d) {
                    this.trueSoftEvidence.put(atomId, prior);
                    this.groundedClauses.add(new GClause(prior, atomId));
                } else if (prior < 0.0d) {
                    this.falseSoftEvidence.put(atomId, -prior);
                    this.groundedClauses.add(new GClause(-prior, -atomId));
                }
            }
        }
        TIntIterator atomIds = tIntSet.iterator();
        while (atomIds.hasNext()) {
            int atomId = atomIds.next();
            Atom atom = this.mln.getAtom(atomId);
            if (atom.pred.isClosedWorld() && !hasEvidence(atomId)) {
                this.groundedClauses.add(new GClause(Config.hard_weight, -atomId));
            }
        }
    }

    /** Returns whether any hard or soft evidence, in either polarity, was recorded for the atom. */
    private boolean hasEvidence(int atomId) {
        return this.trueHardEvidence.contains(atomId)
                || this.falseHardEvidence.contains(atomId)
                || this.trueSoftEvidence.containsKey(atomId)
                || this.falseSoftEvidence.containsKey(atomId);
    }

    /**
     * Reads clause lines ({@code <weight>: <lit>, <lit>, ...}). For each line, adds the non-null
     * result of {@code mln.matchGroundedClause} to {@link #groundedClauses}. The stream is not
     * closed.
     */
    private void loadGroundedConstraints(InputStream inputStream) {
        BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream));
        try {
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                if (!isDataLine(line)) {
                    continue;
                }
                String[] parts = splitClauseLine(line);
                GClause matched =
                        this.mln.matchGroundedClause(parseClauseWeight(parts[0]), parseClauseLiterals(parts[1]));
                if (matched != null) {
                    this.groundedClauses.add(matched);
                }
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Reads clause lines ({@code <weight>: <lit>, <lit>, ...}) and adds each one to
     * {@link #groundedClauses} as-is, without matching against the program. Lines starting with
     * {@code //} and blank lines are skipped. The stream is not closed.
     *
     * @param inputStream source of clause lines; ownership stays with the caller
     */
    public void loadRevertedConstraints(InputStream inputStream) {
        BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream));
        try {
            for (String line = reader.readLine(); line != null; line = reader.readLine()) {
                if (!isDataLine(line)) {
                    continue;
                }
                String[] parts = splitClauseLine(line);
                double weight = parseClauseWeight(parts[0]);
                this.groundedClauses.add(new GClause(weight, parseClauseLiterals(parts[1])));
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Returns false for {@code //} comment lines and blank lines, which carry no data. */
    private static boolean isDataLine(String line) {
        return !line.startsWith("//") && !line.trim().isEmpty();
    }

    /**
     * Splits a clause line at {@code ": "} into its weight token and literal list.
     *
     * @throws IllegalArgumentException if the line has no {@code ": "} separator
     */
    private static String[] splitClauseLine(String line) {
        String[] parts = line.split(": ");
        if (parts.length < 2) {
            throw new IllegalArgumentException("Malformed clause line: " + line);
        }
        return parts;
    }

    /** Parses a weight token; {@code "infi"} maps to {@code Config.hard_weight}. */
    private static double parseClauseWeight(String token) {
        return "infi".equals(token) ? Config.hard_weight : Double.parseDouble(token);
    }

    /**
     * Converts a {@code ", "}-separated literal list into signed atom IDs. The last
     * space-separated token of each literal is the atom. Any preceding token (presumably a
     * negation marker; confirm against the file writer) makes the literal negative.
     */
    private int[] parseClauseLiterals(String literalList) {
        String[] literals = literalList.split(", ");
        int[] signedIds = new int[literals.length];
        for (int i = 0; i < literals.length; i++) {
            String[] tokens = literals[i].split(" ");
            int atomId = atomIdOf(tokens[tokens.length - 1]);
            signedIds[i] = tokens.length > 1 ? -atomId : atomId;
        }
        return signedIds;
    }

    /**
     * Parses a ground atom and returns its ID in the network.
     *
     * @throws IllegalArgumentException if the network returns no ID for the atom
     */
    private int atomIdOf(String atomText) {
        Integer id = this.mln.getAtomID(this.mln.parseAtom(atomText));
        if (id == null) {
            throw new IllegalArgumentException("Unknown ground atom: " + atomText);
        }
        return id;
    }

    /**
     * Builds the solver variant selected by the flags and runs it on {@link #queries}.
     * {@code isEager} takes precedence over {@code isCompo}, and {@code isCompo} over
     * {@code isHorn}.
     *
     * @return the value returned by the solver's {@code solve}; {@link #infer} treats
     *     {@code false} as UNSAT
     */
    public boolean solve() {
        if (this.isEager) {
            this.solver = new QMaxSATEager(this.groundedClauses, this.mln);
        } else if (this.isCompo) {
            this.solver = this.isHorn
                    ? new CompoHornQMaxSAT(this.groundedClauses, this.mln)
                    : new CompoQMaxSAT(this.groundedClauses, this.mln);
        } else if (this.isHorn) {
            this.solver = new HornQMaxSAT(this.groundedClauses, this.mln);
        } else {
            this.solver = new QMaxSAT(this.groundedClauses, this.mln);
        }
        return this.solver.solve(this.queries);
    }
}
