package edu.gatech.mln.infer.querydriven;

import edu.gatech.mln.GClause;
import edu.gatech.mln.util.Config;
import edu.gatech.mln.util.UIMan;
import gnu.trove.iterator.TIntObjectIterator;
import gnu.trove.map.TIntDoubleMap;
import gnu.trove.map.TIntObjectMap;
import gnu.trove.map.hash.TIntDoubleHashMap;
import gnu.trove.set.TIntSet;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;

/* loaded from: input_file:edu/gatech/mln/infer/querydriven/HornQGrounder.class */
/**
 * Query-driven grounder specialised for Horn-shaped ground clauses.
 *
 * <p>On top of {@link ContriFrontiersQGrounder} this class:
 * <ul>
 *   <li>strengthens the frontier of the grounded clause set by projecting every frontier clause
 *       onto the atoms already present in that set, and caps the weight of projected
 *       definite-Horn clauses by the goal weight their head can still reach;</li>
 *   <li>optionally ({@code Config.fowardBias}) biases expansion toward goal clauses and toward
 *       definite-Horn clauses whose head feeds a goal with positive remaining weight.</li>
 * </ul>
 *
 * <p>NOTE(review): the exact meaning of the index maps passed to the constructor is defined by the
 * caller; the field comments below only describe how this class reads them.
 */
public class HornQGrounder extends ContriFrontiersQGrounder {
    /**
     * Projected definite-Horn frontier clauses whose (capped) weight falls below this value are
     * skipped by the next {@link #nonBiasedExpand} call.
     */
    public static double cutOffValue = 0.1d;

    /** Key -> atom set; a key is considered fed by an atom when the atom is in the key's set. */
    private final TIntObjectMap<TIntSet> reverseGoalMap;
    /** Key -> clauses whose weights (excluding already grounded ones) form that key's goal weight. */
    private final TIntObjectMap<Set<GClause>> goalWeight;
    /** Not read anywhere in this class; kept for constructor compatibility. */
    private final TIntObjectMap<TIntSet> factMap;
    /** Key -> clauses summed by {@link #getFactWeight}. */
    private final TIntObjectMap<Set<GClause>> factWeight;
    /** Index handed to {@code MaxSATUtils.calReachableGoalsApproxReverse}. */
    private final TIntObjectMap<Set<GClause>> reverseMap;
    /**
     * Frontier clauses to skip in the next {@link #nonBiasedExpand}. Rebuilt by
     * {@link #findStrenthenedFrontiners} and reset to {@code null} after each non-biased expansion;
     * {@code null} means "nothing excluded".
     */
    private Set<GClause> excludeSet;

    /**
     * Creates a grounder over the given clause set and lookup indices.
     *
     * @param set            clause set forwarded to the superclass constructor
     * @param reverseGoalMap key -> atom set, scanned when looking for goals fed by a head atom
     * @param goalWeight     key -> clauses whose not-yet-grounded weights form the key's goal weight
     * @param factMap        currently unused by this class
     * @param factWeight     key -> clauses summed by {@link #getFactWeight}
     * @param reverseMap     index used to approximate which atoms reach the accumulated goals
     */
    public HornQGrounder(Set<GClause> set, TIntObjectMap<TIntSet> reverseGoalMap,
            TIntObjectMap<Set<GClause>> goalWeight, TIntObjectMap<TIntSet> factMap,
            TIntObjectMap<Set<GClause>> factWeight, TIntObjectMap<Set<GClause>> reverseMap) {
        super(set);
        this.reverseGoalMap = reverseGoalMap;
        this.goalWeight = goalWeight;
        this.factMap = factMap;
        this.factWeight = factWeight;
        this.reverseMap = reverseMap;
    }

    /**
     * Adds to {@code start} the weights of all clauses in {@code clauses} that are not in
     * {@code grounded}, in iteration order.
     *
     * @param start    initial accumulator value
     * @param clauses  candidate clauses; {@code null} contributes nothing
     * @param grounded clauses whose weight must not be counted
     * @return the accumulated weight
     */
    private static double sumUngroundedWeights(double start, Set<GClause> clauses, Set<GClause> grounded) {
        double total = start;
        if (clauses == null) {
            return total;
        }
        for (GClause clause : clauses) {
            if (!grounded.contains(clause)) {
                total += clause.weight;
            }
        }
        return total;
    }

    /**
     * Returns the summed weight of the clauses indexed under {@code goalKey} in {@code goalWeight}
     * that are not yet in {@code grounded}; 0 when the key is unknown.
     */
    private double getGoalWeight(int goalKey, Set<GClause> grounded) {
        return sumUngroundedWeights(0.0d, this.goalWeight.get(goalKey), grounded);
    }

    /**
     * Returns {@code extra.get(factKey)} plus the weights of the clauses indexed under
     * {@code factKey} in {@code factWeight} that are not in {@code grounded}.
     *
     * <p>NOTE(review): not called anywhere in this class.
     */
    private double getFactWeight(int factKey, Set<GClause> grounded, TIntDoubleMap extra) {
        // Trove returns the map's no-entry value (0.0 unless reconfigured) for missing keys.
        return sumUngroundedWeights(extra.get(factKey), this.factWeight.get(factKey), grounded);
    }

    /**
     * Returns whether {@code lit} is made true by the given assignment: a positive literal when its
     * atom is in {@code trueAtoms}, a negative literal when its atom is in {@code falseAtoms}.
     */
    private static boolean literalSatisfied(int lit, TIntSet trueAtoms, TIntSet falseAtoms) {
        int atom = Math.abs(lit);
        if (lit > 0) {
            return trueAtoms.contains(atom);
        }
        return lit < 0 && falseAtoms.contains(atom);
    }

    /**
     * Returns the part of {@code frontier} whose atoms lie outside {@code allAtoms}, with the
     * frontier's weight, or {@code null} if a literal over {@code allAtoms} is already satisfied or
     * if the outside part is empty or equals the whole clause.
     */
    private static GClause outsideResidual(GClause frontier, TIntSet allAtoms,
            TIntSet trueAtoms, TIntSet falseAtoms) {
        int[] lits = frontier.lits;
        int[] outside = new int[lits.length];
        int count = 0;
        for (int lit : lits) {
            if (allAtoms.contains(Math.abs(lit))) {
                if (literalSatisfied(lit, trueAtoms, falseAtoms)) {
                    return null;
                }
            } else {
                outside[count++] = lit;
            }
        }
        if (count == 0 || count == lits.length) {
            return null;
        }
        return new GClause(frontier.weight, Arrays.copyOf(outside, count));
    }

    /**
     * Returns the literals of {@code frontier} that are over {@code allAtoms} and not satisfied by
     * the given assignment, with the frontier's weight, or {@code null} if there are none.
     */
    private static GClause projectOntoGrounded(GClause frontier, TIntSet allAtoms,
            TIntSet trueAtoms, TIntSet falseAtoms) {
        int[] lits = frontier.lits;
        int[] kept = new int[lits.length];
        int count = 0;
        for (int lit : lits) {
            if (allAtoms.contains(Math.abs(lit)) && !literalSatisfied(lit, trueAtoms, falseAtoms)) {
                kept[count++] = lit;
            }
        }
        return count == 0 ? null : new GClause(frontier.weight, Arrays.copyOf(kept, count));
    }

    /**
     * For every frontier whose outside residual is a goal clause, adds the residual's weight to
     * {@code goalMass} under each of the residual's negative atoms.
     *
     * @return whether at least one residual was not a goal clause
     */
    private static boolean accumulateGoalMass(Set<GClause> frontiers, TIntSet allAtoms,
            TIntSet trueAtoms, TIntSet falseAtoms, TIntDoubleMap goalMass) {
        boolean sawNonGoal = false;
        for (GClause frontier : frontiers) {
            GClause residual = outsideResidual(frontier, allAtoms, trueAtoms, falseAtoms);
            if (residual == null) {
                continue;
            }
            if (MaxSATUtils.isGoal(residual)) {
                for (int atom : MaxSATUtils.getNegativeAts(residual)) {
                    goalMass.adjustOrPutValue(atom, residual.weight, residual.weight);
                }
            } else {
                sawNonGoal = true;
            }
        }
        return sawNonGoal;
    }

    /**
     * Computes the goal weight reachable from {@code head}: {@code goalTotal} if {@code head} is in
     * {@code goalReachable}, plus the goal weight of every {@code reverseGoalMap} key whose atom set
     * contains {@code head}. Stops early once the sum exceeds {@code limit}.
     */
    private double goalWeightCap(int head, double limit, double goalTotal, TIntSet goalReachable,
            Set<GClause> grounded) {
        double cap = 0.0d;
        if (goalTotal > 0.0d && goalReachable != null && goalReachable.contains(head)) {
            cap += goalTotal;
        }
        TIntObjectIterator<TIntSet> entries = this.reverseGoalMap.iterator();
        while (entries.hasNext()) {
            entries.advance();
            if (entries.value().contains(head)) {
                cap += getGoalWeight(entries.key(), grounded);
            }
            if (cap > limit) {
                break; // already above the clause weight; it will not be capped
            }
        }
        return cap;
    }

    /**
     * Adds {@code clause} to {@code result}. A soft clause equal to one already present replaces
     * it with twice its own weight.
     *
     * <p>NOTE(review): the doubling assumes duplicates carry equal weight and occur at most twice;
     * whether that holds depends on {@code GClause.equals}, which is not visible here.
     *
     * @throws RuntimeException if doubling turns a soft clause into a hard one
     */
    private static void mergeWeighted(Set<GClause> result, GClause clause) {
        if (!clause.isHardClause() && result.remove(clause)) {
            clause.weight *= 2.0d;
            if (clause.isHardClause()) {
                throw new RuntimeException("Clause " + clause + " becomes a hard constraint after weight merging.");
            }
        }
        result.add(clause);
    }

    /**
     * Builds the strengthened frontier of {@code set} under the given assignment.
     *
     * <p>Each frontier clause is projected onto the atoms of {@code set}, keeping only its
     * unsatisfied literals. For a definite-Horn frontier that lost literals, a projection left
     * without positive literals has its weight capped by the goal weight reachable from the
     * original head. A frontier whose projected weight ends below {@link #cutOffValue} is recorded
     * for exclusion from the next {@link #nonBiasedExpand}. Duplicate soft projections are merged
     * via {@link #mergeWeighted}.
     *
     * @param set        currently grounded clauses
     * @param trueAtoms  atoms that make positive literals true
     * @param falseAtoms atoms that make negative literals true
     * @return the strengthened (projected) frontier clauses
     * @throws RuntimeException if merging duplicate soft clauses yields a hard clause
     */
    @Override // edu.gatech.mln.infer.querydriven.QGrounder
    public Set<GClause> findStrenthenedFrontiners(Set<GClause> set, TIntSet trueAtoms, TIntSet falseAtoms) {
        TIntSet allAtoms = MaxSATUtils.getAllAtoms(set);
        Set<GClause> frontiers = findFrontiers(set, allAtoms);
        Set<GClause> result = new HashSet<GClause>();
        this.excludeSet = new HashSet<GClause>();
        UIMan.verbose(1, "Number of connected frontiers: " + frontiers.size());

        // Pass 1: weight flowing into goal atoms from the outside parts of the frontier.
        TIntDoubleHashMap goalMass = new TIntDoubleHashMap();
        boolean sawNonGoal = accumulateGoalMass(frontiers, allAtoms, trueAtoms, falseAtoms, goalMass);

        TIntSet goalReachable = null;
        double goalTotal = 0.0d;
        if (sawNonGoal && !goalMass.isEmpty()) {
            for (double mass : goalMass.values()) {
                goalTotal += mass;
            }
            goalReachable = MaxSATUtils.calReachableGoalsApproxReverse(goalMass.keySet(), set, this.reverseMap);
        }

        // Pass 2: project each frontier onto the grounded atoms and cap Horn-clause weights.
        for (GClause frontier : frontiers) {
            GClause projected = projectOntoGrounded(frontier, allAtoms, trueAtoms, falseAtoms);
            if (projected == null) {
                continue;
            }
            if (MaxSATUtils.isDefHorn(frontier) && frontier.lits.length > projected.lits.length) {
                if (MaxSATUtils.getPositiveAts(projected).isEmpty()) {
                    int head = MaxSATUtils.getPositiveAts(frontier).get(0);
                    double cap = goalWeightCap(head, projected.weight, goalTotal, goalReachable, set);
                    if (cap < projected.weight) {
                        projected.weight = cap;
                    }
                }
                if (projected.weight < cutOffValue) {
                    this.excludeSet.add(frontier);
                }
            }
            mergeWeighted(result, projected);
        }
        return result;
    }

    /** Returns whether some {@code reverseGoalMap} key containing {@code head} has positive goal weight. */
    private boolean headFeedsLiveGoal(int head, Set<GClause> grounded) {
        TIntObjectIterator<TIntSet> entries = this.reverseGoalMap.iterator();
        while (entries.hasNext()) {
            entries.advance();
            if (entries.value().contains(head) && getGoalWeight(entries.key(), grounded) > 0.0d) {
                return true;
            }
        }
        return false;
    }

    /**
     * Chooses frontier clauses to add to {@code set}.
     *
     * <p>Without {@code Config.fowardBias} this delegates to {@link #nonBiasedExpand}. Otherwise it
     * picks every frontier goal clause, plus every definite-Horn frontier clause whose head is not
     * yet grounded and feeds a goal with positive remaining weight. It falls back to the non-biased
     * expansion when nothing qualifies.
     */
    @Override // edu.gatech.mln.infer.querydriven.ContriFrontiersQGrounder, edu.gatech.mln.infer.querydriven.QGrounder
    public Set<GClause> expand(Set<GClause> set, TIntSet trueAtoms, TIntSet falseAtoms,
            TIntSet altTrueAtoms, TIntSet altFalseAtoms) {
        if (!Config.fowardBias) {
            return nonBiasedExpand(set, trueAtoms, falseAtoms, altTrueAtoms, altFalseAtoms);
        }
        Set<GClause> biased = new HashSet<GClause>();
        Set<GClause> frontiers = super.findFrontiers(set, null);
        TIntSet allAtoms = MaxSATUtils.getAllAtoms(set);
        for (GClause frontier : frontiers) {
            if (MaxSATUtils.isGoal(frontier)) {
                biased.add(frontier);
            }
            if (MaxSATUtils.isDefHorn(frontier)) {
                int head = MaxSATUtils.getPositiveAts(frontier).get(0);
                if (!allAtoms.contains(head) && headFeedsLiveGoal(head, set)) {
                    biased.add(frontier);
                }
            }
        }
        return biased.isEmpty()
                ? nonBiasedExpand(set, trueAtoms, falseAtoms, altTrueAtoms, altFalseAtoms)
                : biased;
    }

    /**
     * Returns the frontier clauses of {@code set} not excluded by the last
     * {@link #findStrenthenedFrontiners} call that are violated under the first assignment. Soft
     * clauses must additionally be satisfied under the second assignment. Clears the exclusion set
     * afterwards.
     *
     * <p>NOTE(review): the roles of the two assignments are inferred only from the
     * {@code MaxSATUtils.satisfy} calls; confirm with callers.
     */
    public Set<GClause> nonBiasedExpand(Set<GClause> set, TIntSet trueAtoms, TIntSet falseAtoms,
            TIntSet altTrueAtoms, TIntSet altFalseAtoms) {
        Set<GClause> selected = new HashSet<GClause>();
        // excludeSet is null before the first strengthening and after every non-biased expansion.
        Set<GClause> excluded = this.excludeSet;
        for (GClause frontier : super.findFrontiers(set, null)) {
            if (excluded != null && excluded.contains(frontier)) {
                continue;
            }
            boolean take;
            if (frontier.isHardClause()) {
                take = !MaxSATUtils.satisfy(frontier, trueAtoms, falseAtoms);
            } else {
                take = !MaxSATUtils.satisfy(frontier, trueAtoms, falseAtoms)
                        && MaxSATUtils.satisfy(frontier, altTrueAtoms, altFalseAtoms);
            }
            if (take) {
                selected.add(frontier);
            }
        }
        this.excludeSet = null;
        return selected;
    }
}
