package org.wikipedia.miner.annotation.weighting;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import org.apache.log4j.Logger;
import org.wikipedia.miner.annotation.Topic;
import org.wikipedia.miner.annotation.TopicDetector;
import org.wikipedia.miner.comparison.ArticleComparer;
import org.wikipedia.miner.model.Wikipedia;
import org.wikipedia.miner.util.ProgressTracker;
import org.wikipedia.miner.util.RelatednessCache;
import org.wikipedia.miner.util.Result;
import org.wikipedia.miner.util.TopicIndexingSet;
import weka.classifiers.Classifier;
import weka.classifiers.meta.Bagging;
import weka.core.Instance;
import weka.core.Utils;
import weka.core.WekaException;
import weka.wrapper.Dataset;
import weka.wrapper.Decider;
import weka.wrapper.DeciderBuilder;
import weka.wrapper.InstanceBuilder;

/* loaded from: input_file:org/wikipedia/miner/annotation/weighting/TopicIndexer.class */
/**
 * Weights candidate topics by how likely they are to be key topics of a document, using a
 * Weka classifier trained on per-topic features (see {@link Attributes}).
 *
 * <p>Typical workflow: {@link #train} (or {@link #loadTrainingData}), then
 * {@link #buildDefaultClassifier()} / {@link #buildClassifier(Classifier)} (or
 * {@link #loadClassifier}), and finally {@link #getTopicWeights(Collection)} or {@link #test}.
 *
 * <p>Instances hold mutable training state and perform no synchronization.
 */
public class TopicIndexer extends TopicWeighter {

    private static final Logger LOG = Logger.getLogger(TopicIndexer.class);

    private final Wikipedia wikipedia;

    /** Training instances; {@code null} until {@link #train} or {@link #loadTrainingData} runs. */
    private Dataset<Attributes, Boolean> dataset;

    /** Running count of topics scored by {@link #getTopicWeights(Collection)}. */
    int candidatesConsidered = 0;

    // NOTE(review): the decider is named "LinkDisambiguator", which looks copied from the
    // disambiguator. Left unchanged in case the name matters when loading saved classifiers.
    private final Decider<Attributes, Boolean> decider = new DeciderBuilder("LinkDisambiguator", Attributes.class)
            .setDefaultAttributeTypeNumeric()
            .setClassAttributeTypeBoolean("isKeyTopic")
            .build();

    /** Features extracted from each candidate {@link Topic} for the key-topic classifier. */
    public enum Attributes {
        occurances,
        maxDisambigConfidence,
        avgDisambigConfidence,
        relatednessToContext,
        relatednessToOtherTopics,
        maxLinkProbability,
        avgLinkProbability,
        generality,
        firstOccurance,
        lastOccurance,
        spread
    }

    /**
     * Creates an indexer with no training data and an untrained classifier.
     *
     * @param wikipedia the Wikipedia instance used to build article comparers during training/testing
     */
    public TopicIndexer(Wikipedia wikipedia) throws Exception {
        this.wikipedia = wikipedia;
    }

    /** Returns how many topics have been scored by {@link #getTopicWeights(Collection)} so far. */
    public int getCandidatesConsidered() {
        return candidatesConsidered;
    }

    /**
     * Scores each topic with the classifier.
     *
     * @param topics candidate topics to weight
     * @return map from topic id to the decision-distribution value the classifier assigns to
     *     {@code true} (i.e. "is a key topic")
     * @throws WekaException if no classifier has been built or loaded
     */
    @Override // org.wikipedia.miner.annotation.weighting.TopicWeighter
    public HashMap<Integer, Double> getTopicWeights(Collection<Topic> topics) throws Exception {
        requireClassifier();
        HashMap<Integer, Double> weights = new HashMap<>();
        for (Topic topic : topics) {
            // Unboxing deliberately fails fast if the distribution has no entry for true.
            double weight = (Double) decider.getDecisionDistribution(getInstance(topic, null)).get(true);
            weights.put(topic.getId(), weight);
            candidatesConsidered++;
        }
        return weights;
    }

    /**
     * Builds a fresh training dataset from the topics detected in every item of the set, then
     * re-weights the instances to balance the two classes.
     *
     * @param trainingSet documents paired with their known key topics
     * @param datasetName currently unused; kept for backward compatibility
     * @param topicDetector detector used to find candidate topics in each document
     */
    public void train(TopicIndexingSet trainingSet, String datasetName, TopicDetector topicDetector) throws Exception {
        dataset = decider.createNewDataset();
        ProgressTracker progress = new ProgressTracker(trainingSet.size(), "training", TopicIndexer.class);
        Iterator<TopicIndexingSet.Item> items = trainingSet.iterator();
        while (items.hasNext()) {
            addTrainingInstances(items.next(), topicDetector);
            progress.update();
        }
        weightTrainingInstances();
    }

    /**
     * Evaluates the classifier against every item of the set, printing worst-case and
     * perfect-score statistics to standard output.
     *
     * @param testSet documents paired with their known key topics
     * @param topicDetector detector used to find candidate topics in each document
     * @return the aggregated result over all items
     * @throws WekaException if no classifier has been built or loaded
     */
    public Result<Integer> test(TopicIndexingSet testSet, TopicDetector topicDetector) throws Exception {
        requireClassifier();
        double worstRecall = 1.0d;
        double worstPrecision = 1.0d;
        int tested = 0;
        int perfectRecall = 0;
        int perfectPrecision = 0;
        Result<Integer> overall = new Result<>();
        ProgressTracker progress = new ProgressTracker(testSet.size(), "Testing", TopicIndexer.class);
        Iterator<TopicIndexingSet.Item> items = testSet.iterator();
        while (items.hasNext()) {
            tested++;
            Result<Integer> itemResult = testItem(items.next(), topicDetector);
            if (itemResult.getRecall() == 1.0d) {
                perfectRecall++;
            }
            if (itemResult.getPrecision() == 1.0d) {
                perfectPrecision++;
            }
            worstRecall = Math.min(worstRecall, itemResult.getRecall());
            worstPrecision = Math.min(worstPrecision, itemResult.getPrecision());
            overall.addIntermediateResult(itemResult);
            progress.update();
        }
        System.out.println("worstR:" + worstRecall + ", worstP:" + worstPrecision);
        System.out.println("tested:" + tested + ", perfectR:" + perfectRecall + ", perfectP:" + perfectPrecision);
        return overall;
    }

    /**
     * Writes the current training data to a file.
     *
     * @throws WekaException if there is no training data to save
     */
    public void saveTrainingData(File file) throws Exception {
        LOG.info("saving training data");
        requireTrainingData();
        dataset.save(file);
    }

    /** Loads training data from a file and re-weights it to balance the two classes. */
    public void loadTrainingData(File file) throws Exception {
        LOG.info("loading training data");
        // No dataset exists before train() or after clearTrainingData(); create one to load into.
        if (dataset == null) {
            dataset = decider.createNewDataset();
        }
        dataset.load(file);
        weightTrainingInstances();
    }

    /** Discards the training data (e.g. to free memory once the classifier is built). */
    public void clearTrainingData() {
        dataset = null;
    }

    /** Writes the classifier to a file. */
    public void saveClassifier(File file) throws IOException {
        LOG.info("saving classifier");
        decider.save(file);
    }

    /** Loads a previously saved classifier from a file. */
    public void loadClassifier(File file) throws Exception {
        LOG.info("loading classifier");
        decider.load(file);
    }

    /**
     * Trains the given classifier on the current training data.
     *
     * @throws WekaException if there is no training data
     */
    public void buildClassifier(Classifier classifier) throws Exception {
        LOG.info("building classifier");
        requireTrainingData();
        decider.train(classifier, dataset);
    }

    /**
     * Trains a bagged, unpruned J48 classifier on the current training data.
     *
     * @throws WekaException if there is no training data
     */
    public void buildDefaultClassifier() throws Exception {
        LOG.info("building classifier");
        requireTrainingData();
        Bagging bagging = new Bagging();
        bagging.setOptions(Utils.splitOptions("-P 10 -S 1 -I 10 -W weka.classifiers.trees.J48 -- -U -M 2"));
        decider.train(bagging, dataset);
    }

    /** Throws unless a classifier has been built or loaded. */
    private void requireClassifier() throws WekaException {
        if (!decider.isReady()) {
            throw new WekaException("You must build (or load) classifier first.");
        }
    }

    /** Throws unless training data has been built or loaded. */
    private void requireTrainingData() throws WekaException {
        if (dataset == null) {
            throw new WekaException("You must build (or load) training data first.");
        }
    }

    /** Adds one labelled instance per topic detected in the item's document. */
    private void addTrainingInstances(TopicIndexingSet.Item item, TopicDetector topicDetector) throws Exception {
        RelatednessCache cache = new RelatednessCache(new ArticleComparer(wikipedia));
        for (Topic topic : topicDetector.getTopics(item.getDocument(), cache)) {
            dataset.add(getInstance(topic, item.isTopic(topic)));
        }
    }

    /**
     * Classifies the topics detected in one item's document. Topics with weight above 0.5 are
     * treated as predicted key topics and compared against the item's known topic ids.
     */
    private Result<Integer> testItem(TopicIndexingSet.Item item, TopicDetector topicDetector) throws Exception {
        RelatednessCache cache = new RelatednessCache(new ArticleComparer(wikipedia));
        HashSet<Integer> predicted = new HashSet<>();
        for (Topic topic : getWeightedTopics(topicDetector.getTopics(item.getDocument(), cache))) {
            if (topic.getWeight() > 0.5d) {
                predicted.add(topic.getId());
            }
        }
        Result<Integer> result = new Result<>(predicted, item.getTopicIds());
        System.out.println(" - " + result);
        return result;
    }

    /**
     * Builds a classifier instance from the topic's features.
     *
     * @param topic the topic to describe
     * @param isKeyTopic the class label, or {@code null} to leave it unset (used at prediction time)
     */
    private Instance getInstance(Topic topic, Boolean isKeyTopic) throws Exception {
        InstanceBuilder builder = decider.getInstanceBuilder()
                .setAttribute(Attributes.occurances, Integer.valueOf(topic.getOccurances()))
                .setAttribute(Attributes.maxDisambigConfidence, Double.valueOf(topic.getMaxDisambigConfidence()))
                .setAttribute(Attributes.avgDisambigConfidence, Double.valueOf(topic.getAverageDisambigConfidence()))
                .setAttribute(Attributes.relatednessToContext, Double.valueOf(topic.getRelatednessToContext()))
                .setAttribute(Attributes.relatednessToOtherTopics, Double.valueOf(topic.getRelatednessToOtherTopics()))
                .setAttribute(Attributes.maxLinkProbability, Double.valueOf(topic.getMaxLinkProbability()))
                .setAttribute(Attributes.avgLinkProbability, Double.valueOf(topic.getAverageLinkProbability()))
                .setAttribute(Attributes.generality, topic.getGenerality())
                .setAttribute(Attributes.firstOccurance, Double.valueOf(topic.getFirstOccurance()))
                .setAttribute(Attributes.lastOccurance, Double.valueOf(topic.getLastOccurance()))
                .setAttribute(Attributes.spread, Double.valueOf(topic.getSpread()));
        if (isKeyTopic != null) {
            builder = builder.setClassAttribute(isKeyTopic);
        }
        return builder.build();
    }

    /**
     * Re-weights the training instances so that each class contributes the same total weight
     * (half the instance count each). This counters the imbalance between key and non-key topics.
     */
    private void weightTrainingInstances() {
        double zeroClassCount = 0.0d;
        double otherClassCount = 0.0d;
        Enumeration<?> counting = dataset.enumerateInstances();
        while (counting.hasMoreElements()) {
            if (hasZeroClassValue((Instance) counting.nextElement())) {
                zeroClassCount += 1.0d;
            } else {
                otherClassCount += 1.0d;
            }
        }
        double total = zeroClassCount + otherClassCount;
        if (total == 0.0d) {
            return; // nothing to weight
        }
        double zeroFraction = zeroClassCount / total;
        // At most one of these is infinite (when a class is absent), and that one is never assigned.
        double zeroClassWeight = 0.5d * (1.0d / zeroFraction);
        double otherClassWeight = 0.5d * (1.0d / (1.0d - zeroFraction));
        Enumeration<?> weighting = dataset.enumerateInstances();
        while (weighting.hasMoreElements()) {
            Instance instance = (Instance) weighting.nextElement();
            instance.setWeight(hasZeroClassValue(instance) ? zeroClassWeight : otherClassWeight);
        }
    }

    /**
     * Returns whether the instance's class attribute (isKeyTopic) has internal value 0.
     * The previous code read attribute index 3, which is the relatednessToContext feature,
     * not the class.
     */
    private static boolean hasZeroClassValue(Instance instance) {
        int classIndex = instance.classIndex();
        if (classIndex < 0) {
            // NOTE(review): assumes the class attribute follows the feature attributes (last
            // position) when no class index is set — confirm against DeciderBuilder.
            classIndex = instance.numAttributes() - 1;
        }
        return instance.value(classIndex) == 0.0d;
    }
}
