package edu.gatech.mln.infer;

import edu.gatech.mln.GClause;
import edu.gatech.mln.util.Config;
import edu.gatech.mln.util.NamedThreadFactory;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.net.Socket;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.apache.commons.lang3.Pair;

/* loaded from: input_file:edu/gatech/mln/infer/LazySolverParallel.class */
public class LazySolverParallel {
    private Map<Worker, Integer> workLoadMap = new HashMap();
    public static final int MAX_WORK_LOAD = 4;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:edu/gatech/mln/infer/LazySolverParallel$MaxSATTask.class */
    public class MaxSATTask implements Callable<Pair<Double, Set<Integer>>> {
        private Worker worker;
        private Set<GClause> problem;

        public MaxSATTask(Worker worker, Set<GClause> set) {
            this.worker = worker;
            this.problem = set;
        }

        /* JADX WARN: Can't rename method to resolve collision */
        /* JADX WARN: Multi-variable type inference failed */
        /* JADX WARN: Type inference failed for: r0v33, types: [java.util.Map] */
        /* JADX WARN: Type inference failed for: r0v34, types: [java.lang.Throwable] */
        /* JADX WARN: Type inference failed for: r0v45 */
        /* JADX WARN: Type inference failed for: r0v54, types: [java.util.Map] */
        /* JADX WARN: Type inference failed for: r0v55, types: [java.lang.Throwable] */
        /* JADX WARN: Type inference failed for: r0v66 */
        @Override // java.util.concurrent.Callable
        public Pair<Double, Set<Integer>> call() {
            try {
                Socket socket = new Socket(this.worker.addr, this.worker.port);
                PrintWriter printWriter = new PrintWriter(socket.getOutputStream());
                BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(socket.getInputStream()));
                printWriter.println(this.problem.size());
                for (GClause gClause : this.problem) {
                    if (gClause.isHardClause()) {
                        printWriter.print(-1);
                    } else {
                        printWriter.print(gClause.weight);
                    }
                    for (int i : gClause.lits) {
                        printWriter.print(" " + i);
                    }
                    printWriter.println(" 0");
                }
                printWriter.flush();
                double parseDouble = Double.parseDouble(bufferedReader.readLine().trim());
                if (parseDouble < 0.0d) {
                    printWriter.close();
                    bufferedReader.close();
                    socket.close();
                    ?? r0 = LazySolverParallel.this.workLoadMap;
                    synchronized (r0) {
                        LazySolverParallel.this.workLoadMap.put(this.worker, Integer.valueOf(((Integer) LazySolverParallel.this.workLoadMap.get(this.worker)).intValue() - 1));
                        r0 = r0;
                        return null;
                    }
                }
                int parseInt = Integer.parseInt(bufferedReader.readLine().trim());
                HashSet hashSet = new HashSet();
                for (int i2 = 0; i2 < parseInt; i2++) {
                    hashSet.add(Integer.valueOf(Integer.parseInt(bufferedReader.readLine().trim())));
                }
                printWriter.close();
                bufferedReader.close();
                socket.close();
                ?? r02 = LazySolverParallel.this.workLoadMap;
                synchronized (r02) {
                    LazySolverParallel.this.workLoadMap.put(this.worker, Integer.valueOf(((Integer) LazySolverParallel.this.workLoadMap.get(this.worker)).intValue() - 1));
                    r02 = r02;
                    return new Pair<>(Double.valueOf(parseDouble), hashSet);
                }
            } catch (Exception e) {
                e.printStackTrace();
                throw new RuntimeException(e);
            }
        }
    }

    public void registerWorker(String str, int i) {
        Worker worker = new Worker();
        worker.addr = str;
        worker.port = i;
        if (this.workLoadMap.containsKey(worker)) {
            return;
        }
        this.workLoadMap.put(worker, 0);
    }

    public int getNumWorkers() {
        return this.workLoadMap.size();
    }

    /* JADX WARN: Type inference failed for: r0v2, types: [java.lang.Throwable, java.util.Map<edu.gatech.mln.infer.Worker, java.lang.Integer>] */
    public Future<Pair<Double, Set<Integer>>> solve(Set<GClause> set) {
        Map.Entry<Worker, Integer> entry = null;
        synchronized (this.workLoadMap) {
            for (Map.Entry<Worker, Integer> entry2 : this.workLoadMap.entrySet()) {
                if (entry == null) {
                    entry = entry2;
                } else if (entry2.getValue().intValue() < entry.getValue().intValue()) {
                    entry = entry2;
                }
            }
            if (entry == null) {
                throw new RuntimeException("No MaxSAT workers!");
            }
            if (entry.getValue().intValue() >= 4) {
                return null;
            }
            this.workLoadMap.put(entry.getKey(), Integer.valueOf(entry.getValue().intValue() + 1));
            return Config.executor.submit(new MaxSATTask(entry.getKey(), set));
        }
    }

    public static void main(String[] strArr) throws Exception {
        Future<Pair<Double, Set<Integer>>> solve;
        Config.executor = Executors.newCachedThreadPool(new NamedThreadFactory("MLN thread pool", true));
        LazySolverParallel lazySolverParallel = new LazySolverParallel();
        lazySolverParallel.registerWorker("127.0.0.1", 8888);
        lazySolverParallel.registerWorker("127.0.0.1", 9999);
        HashSet hashSet = new HashSet();
        hashSet.add(new GClause(1.0d, 1, -2, 3));
        hashSet.add(new GClause(Config.hard_weight, 1));
        hashSet.add(new GClause(2.0d, 3));
        ArrayList arrayList = new ArrayList();
        for (int i = 0; i < 100; i++) {
            do {
                solve = lazySolverParallel.solve(hashSet);
            } while (solve == null);
            arrayList.add(solve);
        }
        Iterator it = arrayList.iterator();
        while (it.hasNext()) {
            System.out.println(((Future) it.next()).get());
        }
    }
}
