package org.jpmml.evaluator;

import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.BiMap;
import com.google.common.collect.ImmutableBiMap;
import java.util.ArrayDeque;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.dmg.pmml.CompoundPredicate;
import org.dmg.pmml.DataType;
import org.dmg.pmml.EmbeddedModel;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningFunctionType;
import org.dmg.pmml.MissingValueStrategyType;
import org.dmg.pmml.NoTrueChildStrategyType;
import org.dmg.pmml.Node;
import org.dmg.pmml.PMML;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.ScoreDistribution;
import org.dmg.pmml.TreeModel;
import org.jpmml.evaluator.PredicateUtil;
import org.jpmml.manager.InvalidFeatureException;
import org.jpmml.manager.UnsupportedFeatureException;

/* loaded from: input_file:BOOT-INF/lib/pmml-evaluator-1.1.14.jar:org/jpmml/evaluator/TreeModelEvaluator.class */
/**
 * Evaluator for PMML {@code TreeModel} elements.
 *
 * <p>Supports the {@code regression} and {@code classification} mining functions. Traversal
 * starts at the root {@link Node} and descends into the first child whose predicate evaluates to
 * {@code true}. A child predicate that evaluates to unknown ({@code null}) is resolved by the
 * model's {@code missingValueStrategy}. A node none of whose children match is resolved by the
 * model's {@code noTrueChildStrategy}.
 */
public class TreeModelEvaluator extends ModelEvaluator<TreeModel> implements HasEntityRegistry<Node> {

    /**
     * Per-model registry of nodes, built lazily by walking the whole tree. Keys are held weakly
     * ({@link CacheBuilder#weakKeys()}).
     */
    private static final LoadingCache<TreeModel, BiMap<String, Node>> entityCache = CacheBuilder.newBuilder().weakKeys().build(new CacheLoader<TreeModel, BiMap<String, Node>>() {

        @Override
        public BiMap<String, Node> load(TreeModel treeModel) {
            return collectNodes(treeModel.getNode(), new ImmutableBiMap.Builder<>()).build();
        }

        /**
         * Registers {@code node} via {@link EntityUtil#put}, then recursively all of its
         * descendants in pre-order.
         */
        private ImmutableBiMap.Builder<String, Node> collectNodes(Node node, ImmutableBiMap.Builder<String, Node> builder) {
            ImmutableBiMap.Builder<String, Node> result = EntityUtil.put(node, builder);
            for (Node child : node.getNodes()) {
                result = collectNodes(child, result);
            }
            return result;
        }
    });

    /**
     * Outcome of a tree traversal: the selected node, or {@code null} to signal a null prediction.
     */
    public static class NodeResult {

        private final Node node;

        public NodeResult(Node node) {
            this.node = node;
        }

        /**
         * @return the selected node, or {@code null} for a null prediction
         */
        public Node getNode() {
            return this.node;
        }
    }

    /**
     * Stack of the non-leaf nodes entered during traversal (most recently entered first). It also
     * counts the missing-value fallbacks taken along the way.
     */
    public static class Trail extends ArrayDeque<Node> {

        private int missingLevels = 0;

        /**
         * Returns the most recently pushed node, regardless of whether it carries a score.
         *
         * @throws java.util.NoSuchElementException if the trail is empty
         */
        public Node getLastPrediction() {
            return getFirst();
        }

        /**
         * Returns the most recently pushed node that carries a {@code score}, or {@code null} if no
         * node on the trail does.
         */
        public Node getLastScoredPrediction() {
            // push() inserts at the head and ArrayDeque iterates head-first, so this walks from
            // the deepest entered node back towards the root.
            for (Node node : this) {
                if (node.getScore() != null) {
                    return node;
                }
            }
            return null;
        }

        /**
         * Records one more missing-value fallback; used to scale confidences by the penalty.
         */
        public void addMissingLevel() {
            setMissingLevels(getMissingLevels() + 1);
        }

        public int getMissingLevels() {
            return this.missingLevels;
        }

        private void setMissingLevels(int missingLevels) {
            this.missingLevels = missingLevels;
        }
    }

    /**
     * Creates an evaluator for the first {@code TreeModel} found in {@code pmml}.
     */
    public TreeModelEvaluator(PMML pmml) {
        this(pmml, (TreeModel) find(pmml.getModels(), TreeModel.class));
    }

    public TreeModelEvaluator(PMML pmml, TreeModel treeModel) {
        super(pmml, treeModel);
    }

    @Override
    public String getSummary() {
        return "Tree model";
    }

    /**
     * @return the registry of this model's nodes, built once per {@code TreeModel} and then cached
     */
    @Override
    public BiMap<String, Node> getEntityRegistry() {
        return (BiMap) getValue(entityCache);
    }

    /**
     * Scores the tree against {@code context} and applies the model's output fields.
     *
     * @throws InvalidResultException if the model is marked as not scorable
     * @throws UnsupportedFeatureException if the mining function is neither regression nor
     *     classification
     */
    @Override
    public Map<FieldName, ?> evaluate(ModelEvaluationContext context) {
        TreeModel treeModel = (TreeModel) getModel();
        if (!treeModel.isScorable()) {
            throw new InvalidResultException(treeModel);
        }

        // Wildcard value type: the regression and classification branches produce different
        // value types (Number vs. ClassificationMap), so a narrower type cannot hold both.
        Map<FieldName, ?> predictions;

        MiningFunctionType functionName = treeModel.getFunctionName();
        switch (functionName) {
            case REGRESSION:
                predictions = evaluateRegression(context);
                break;
            case CLASSIFICATION:
                predictions = evaluateClassification(context);
                break;
            default:
                throw new UnsupportedFeatureException(treeModel, functionName);
        }

        return OutputUtil.evaluate(predictions, context);
    }

    /**
     * Regression: the selected node's {@code score} parsed as a double, or {@code null} when the
     * traversal produced no node.
     */
    private Map<FieldName, ? extends Number> evaluateRegression(ModelEvaluationContext context) {
        Node node = evaluateTree(new Trail(), context);

        Double value = null;
        if (node != null) {
            value = (Double) TypeUtil.parseOrCast(DataType.DOUBLE, ensureScore(node));
        }

        return TargetUtil.evaluateRegression(value, context);
    }

    /**
     * Classification: builds the class distribution of the selected node. Confidences are
     * multiplied by {@code missingValuePenalty} once for every missing-value fallback recorded on
     * the trail.
     */
    private Map<FieldName, ? extends ClassificationMap<?>> evaluateClassification(ModelEvaluationContext context) {
        TreeModel treeModel = (TreeModel) getModel();

        Trail trail = new Trail();
        Node node = evaluateTree(trail, context);

        NodeClassificationMap result = null;
        if (node != null) {
            ensureScore(node);

            double confidenceMultiplier = 1.0d;
            for (int i = 0, levels = trail.getMissingLevels(); i < levels; i++) {
                confidenceMultiplier *= treeModel.getMissingValuePenalty();
            }

            result = createNodeClassificationMap(node, confidenceMultiplier);
        }

        return TargetUtil.evaluateClassification((ClassificationMap<?>) result, context);
    }

    /**
     * Walks the tree from the root.
     *
     * @return the selected node, or {@code null} if the root predicate is not {@code true} or the
     *     strategies chose a null prediction
     * @throws InvalidFeatureException if the model has no root node
     */
    private Node evaluateTree(Trail trail, ModelEvaluationContext context) {
        TreeModel treeModel = (TreeModel) getModel();

        Node root = treeModel.getNode();
        if (root == null) {
            throw new InvalidFeatureException(treeModel);
        }

        Boolean status = evaluateNode(root, trail, context);
        if (!Boolean.TRUE.equals(status)) {
            return null;
        }

        return handleTrue(root, trail, context).getNode();
    }

    /**
     * Evaluates the predicate of a single node.
     *
     * @return {@code TRUE}/{@code FALSE}, or {@code null} when the predicate is unknown
     * @throws UnsupportedFeatureException if the node embeds a model
     * @throws InvalidFeatureException if the node has no predicate
     */
    private Boolean evaluateNode(Node node, Trail trail, EvaluationContext context) {
        EmbeddedModel embeddedModel = node.getEmbeddedModel();
        if (embeddedModel != null) {
            throw new UnsupportedFeatureException(embeddedModel);
        }

        Predicate predicate = node.getPredicate();
        if (predicate == null) {
            throw new InvalidFeatureException(node);
        }

        if (!(predicate instanceof CompoundPredicate)) {
            return PredicateUtil.evaluate(predicate, context);
        }

        PredicateUtil.CompoundPredicateResult compoundResult = PredicateUtil.evaluateCompoundPredicateInternal((CompoundPredicate) predicate, context);
        // Falling back to an alternative operand counts as a missing-value level for the penalty.
        if (compoundResult.isAlternative()) {
            trail.addMissingLevel();
        }
        return compoundResult.getResult();
    }

    /**
     * Continues traversal below a node whose predicate held. A leaf is returned as-is. Otherwise
     * the node is pushed on the trail and the first matching child is followed.
     */
    private NodeResult handleTrue(Node node, Trail trail, EvaluationContext context) {
        List<Node> children = node.getNodes();
        if (children.isEmpty()) {
            return new NodeResult(node);
        }

        trail.push(node);

        for (Node child : children) {
            Boolean status = evaluateNode(child, trail, context);

            if (status == null) {
                NodeResult missingResult = handleMissingValue(node, child, trail, context);
                // A null result (strategy NONE) means "treat as not matching, try the next child".
                if (missingResult != null) {
                    return missingResult;
                }
            } else if (status) {
                return handleTrue(child, trail, context);
            }
        }

        return handleNoTrueChild(node, trail, context);
    }

    /**
     * Follows the {@code defaultChild} of {@code node}. The default child's own predicate is not
     * evaluated. Counts as one missing-value level.
     *
     * @throws InvalidFeatureException if {@code defaultChild} is absent or names no child
     */
    private NodeResult handleDefaultChild(Node node, Trail trail, EvaluationContext context) {
        String defaultChild = node.getDefaultChild();
        if (defaultChild == null) {
            throw new InvalidFeatureException(node);
        }

        trail.addMissingLevel();

        for (Node child : node.getNodes()) {
            String id = child.getId();
            if (id != null && id.equals(defaultChild)) {
                return handleTrue(child, trail, context);
            }
        }

        throw new InvalidFeatureException(node);
    }

    /**
     * Applies {@code noTrueChildStrategy} when no child of {@code node} matched.
     *
     * @throws UnsupportedFeatureException for strategies other than the two handled here
     */
    private NodeResult handleNoTrueChild(Node node, Trail trail, EvaluationContext context) {
        TreeModel treeModel = (TreeModel) getModel();

        NoTrueChildStrategyType noTrueChildStrategy = treeModel.getNoTrueChildStrategy();
        switch (noTrueChildStrategy) {
            case RETURN_NULL_PREDICTION:
                return new NodeResult(null);
            case RETURN_LAST_PREDICTION:
                // The trail head is the node whose children were just scanned; it is used only
                // if it carries a score.
                if (trail.size() > 0) {
                    Node lastPrediction = trail.getLastPrediction();
                    if (lastPrediction.getScore() != null) {
                        return new NodeResult(lastPrediction);
                    }
                }
                return new NodeResult(null);
            default:
                throw new UnsupportedFeatureException(treeModel, noTrueChildStrategy);
        }
    }

    /**
     * Applies {@code missingValueStrategy} when the predicate of {@code child} (a child of
     * {@code node}) is unknown.
     *
     * @return the final result, or {@code null} to continue with the next sibling
     *     (strategy {@code none})
     * @throws UnsupportedFeatureException for strategies other than the four handled here
     */
    private NodeResult handleMissingValue(Node node, Node child, Trail trail, EvaluationContext context) {
        TreeModel treeModel = (TreeModel) getModel();

        MissingValueStrategyType missingValueStrategy = treeModel.getMissingValueStrategy();
        switch (missingValueStrategy) {
            case NULL_PREDICTION:
                return new NodeResult(null);
            case LAST_PREDICTION:
                // Use the deepest entered node that actually has a prediction. Returning an
                // unscored node would make ensureScore() reject an otherwise valid evaluation.
                // If no node on the trail has a score, fall back to a null prediction.
                return new NodeResult(trail.getLastScoredPrediction());
            case DEFAULT_CHILD:
                if (node == null) {
                    throw new EvaluationException();
                }
                return handleDefaultChild(node, trail, context);
            case NONE:
                return null;
            default:
                throw new UnsupportedFeatureException(treeModel, missingValueStrategy);
        }
    }

    /**
     * @return the node's {@code score}
     * @throws InvalidFeatureException if the node has no score
     */
    private static String ensureScore(Node node) {
        String score = node.getScore();
        if (score == null) {
            throw new InvalidFeatureException(node);
        }
        return score;
    }

    /**
     * Builds the class distribution of a node from its {@code ScoreDistribution} elements.
     *
     * <p>An explicit {@code probability} is used when present. Otherwise the probability is the
     * record count divided by the node's total record count. Explicit confidences are multiplied
     * by {@code confidenceMultiplier}.
     */
    private static NodeClassificationMap createNodeClassificationMap(Node node, double confidenceMultiplier) {
        NodeClassificationMap result = new NodeClassificationMap(node);

        List<ScoreDistribution> scoreDistributions = node.getScoreDistributions();

        double totalRecordCount = 0.0d;
        for (ScoreDistribution scoreDistribution : scoreDistributions) {
            totalRecordCount += scoreDistribution.getRecordCount();
        }

        for (ScoreDistribution scoreDistribution : scoreDistributions) {
            Double probability = scoreDistribution.getProbability();
            if (probability == null) {
                // NOTE(review): a zero total record count yields NaN here (0.0 / 0.0).
                probability = Double.valueOf(scoreDistribution.getRecordCount() / totalRecordCount);
            }
            result.put(scoreDistribution.getValue(), probability);

            Double confidence = scoreDistribution.getConfidence();
            if (confidence != null) {
                result.putConfidence(scoreDistribution.getValue(), Double.valueOf(confidence.doubleValue() * confidenceMultiplier));
            }
        }

        return result;
    }
}
