package org.elasticsearch.xpack.core.ml.dataframe.evaluation.classification;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.Version;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.BucketOrder;
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
import org.elasticsearch.search.aggregations.bucket.filter.Filters;
import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregator;
import org.elasticsearch.search.aggregations.bucket.terms.Terms;
import org.elasticsearch.search.aggregations.metrics.Cardinality;
import org.elasticsearch.xpack.core.ml.MachineLearningFeatureSetUsage;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.MlEvaluationNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

/* loaded from: input_file:org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.class */
public class MulticlassConfusionMatrix implements EvaluationMetric {
    public static final ParseField NAME = new ParseField("multiclass_confusion_matrix", new String[0]);
    public static final ParseField SIZE = new ParseField("size", new String[0]);
    public static final ParseField AGG_NAME_PREFIX = new ParseField("agg_name_prefix", new String[0]);
    private static final ConstructingObjectParser<MulticlassConfusionMatrix, Void> PARSER = createParser();
    static final String STEP_1_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_1_by_actual_class";
    static final String STEP_2_AGGREGATE_BY_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_by_actual_class";
    static final String STEP_2_AGGREGATE_BY_PREDICTED_CLASS = NAME.getPreferredName() + "_step_2_by_predicted_class";
    static final String STEP_2_CARDINALITY_OF_ACTUAL_CLASS = NAME.getPreferredName() + "_step_2_cardinality_of_actual_class";
    private static final String OTHER_BUCKET_KEY = "_other_";
    private static final String DEFAULT_AGG_NAME_PREFIX = "";
    private static final int DEFAULT_SIZE = 10;
    private static final int MAX_SIZE = 1000;
    private final int size;
    private final String aggNamePrefix;
    private final SetOnce<List<String>> topActualClassNames;
    private final SetOnce<Result> result;

    /* loaded from: input_file:org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix$ActualClass.class */
    public static class ActualClass implements ToXContentObject, Writeable {
        private static final ParseField ACTUAL_CLASS = new ParseField("actual_class", new String[0]);
        private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count", new String[0]);
        private static final ParseField PREDICTED_CLASSES = new ParseField("predicted_classes", new String[0]);
        private static final ParseField OTHER_PREDICTED_CLASS_DOC_COUNT = new ParseField("other_predicted_class_doc_count", new String[0]);
        private static final ConstructingObjectParser<ActualClass, Void> PARSER = new ConstructingObjectParser<>("multiclass_confusion_matrix_actual_class", true, objArr -> {
            return new ActualClass((String) objArr[0], ((Long) objArr[1]).longValue(), (List) objArr[2], ((Long) objArr[3]).longValue());
        });
        private final String actualClass;
        private final long actualClassDocCount;
        private final List<PredictedClass> predictedClasses;
        private final long otherPredictedClassDocCount;

        public ActualClass(String str, long j, List<PredictedClass> list, long j2) {
            this.actualClass = (String) ExceptionsHelper.requireNonNull(str, ACTUAL_CLASS);
            this.actualClassDocCount = MulticlassConfusionMatrix.requireNonNegative(j, ACTUAL_CLASS_DOC_COUNT);
            this.predictedClasses = Collections.unmodifiableList((List) ExceptionsHelper.requireNonNull(list, PREDICTED_CLASSES));
            this.otherPredictedClassDocCount = MulticlassConfusionMatrix.requireNonNegative(j2, OTHER_PREDICTED_CLASS_DOC_COUNT);
        }

        public ActualClass(StreamInput streamInput) throws IOException {
            this.actualClass = streamInput.readString();
            this.actualClassDocCount = streamInput.readVLong();
            this.predictedClasses = Collections.unmodifiableList(streamInput.readList(PredictedClass::new));
            this.otherPredictedClassDocCount = streamInput.readVLong();
        }

        public String getActualClass() {
            return this.actualClass;
        }

        public long getActualClassDocCount() {
            return this.actualClassDocCount;
        }

        public List<PredictedClass> getPredictedClasses() {
            return this.predictedClasses;
        }

        public long getOtherPredictedClassDocCount() {
            return this.otherPredictedClassDocCount;
        }

        public void writeTo(StreamOutput streamOutput) throws IOException {
            streamOutput.writeString(this.actualClass);
            streamOutput.writeVLong(this.actualClassDocCount);
            streamOutput.writeList(this.predictedClasses);
            streamOutput.writeVLong(this.otherPredictedClassDocCount);
        }

        public XContentBuilder toXContent(XContentBuilder xContentBuilder, ToXContent.Params params) throws IOException {
            xContentBuilder.startObject();
            xContentBuilder.field(ACTUAL_CLASS.getPreferredName(), this.actualClass);
            xContentBuilder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), this.actualClassDocCount);
            xContentBuilder.field(PREDICTED_CLASSES.getPreferredName(), this.predictedClasses);
            xContentBuilder.field(OTHER_PREDICTED_CLASS_DOC_COUNT.getPreferredName(), this.otherPredictedClassDocCount);
            xContentBuilder.endObject();
            return xContentBuilder;
        }

        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            ActualClass actualClass = (ActualClass) obj;
            return Objects.equals(this.actualClass, actualClass.actualClass) && this.actualClassDocCount == actualClass.actualClassDocCount && Objects.equals(this.predictedClasses, actualClass.predictedClasses) && this.otherPredictedClassDocCount == actualClass.otherPredictedClassDocCount;
        }

        public int hashCode() {
            return Objects.hash(this.actualClass, Long.valueOf(this.actualClassDocCount), this.predictedClasses, Long.valueOf(this.otherPredictedClassDocCount));
        }

        static {
            PARSER.declareString(ConstructingObjectParser.constructorArg(), ACTUAL_CLASS);
            PARSER.declareLong(ConstructingObjectParser.constructorArg(), ACTUAL_CLASS_DOC_COUNT);
            PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), PredictedClass.PARSER, PREDICTED_CLASSES);
            PARSER.declareLong(ConstructingObjectParser.constructorArg(), OTHER_PREDICTED_CLASS_DOC_COUNT);
        }
    }

    /* loaded from: input_file:org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix$PredictedClass.class */
    public static class PredictedClass implements ToXContentObject, Writeable {
        private static final ParseField PREDICTED_CLASS = new ParseField("predicted_class", new String[0]);
        private static final ParseField COUNT = new ParseField(MachineLearningFeatureSetUsage.COUNT, new String[0]);
        private static final ConstructingObjectParser<PredictedClass, Void> PARSER = new ConstructingObjectParser<>("multiclass_confusion_matrix_predicted_class", true, objArr -> {
            return new PredictedClass((String) objArr[0], ((Long) objArr[1]).longValue());
        });
        private final String predictedClass;
        private final long count;

        public PredictedClass(String str, long j) {
            this.predictedClass = (String) ExceptionsHelper.requireNonNull(str, PREDICTED_CLASS);
            this.count = MulticlassConfusionMatrix.requireNonNegative(j, COUNT);
        }

        public PredictedClass(StreamInput streamInput) throws IOException {
            this.predictedClass = streamInput.readString();
            this.count = streamInput.readVLong();
        }

        public String getPredictedClass() {
            return this.predictedClass;
        }

        public long getCount() {
            return this.count;
        }

        public void writeTo(StreamOutput streamOutput) throws IOException {
            streamOutput.writeString(this.predictedClass);
            streamOutput.writeVLong(this.count);
        }

        public XContentBuilder toXContent(XContentBuilder xContentBuilder, ToXContent.Params params) throws IOException {
            xContentBuilder.startObject();
            xContentBuilder.field(PREDICTED_CLASS.getPreferredName(), this.predictedClass);
            xContentBuilder.field(COUNT.getPreferredName(), this.count);
            xContentBuilder.endObject();
            return xContentBuilder;
        }

        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            PredictedClass predictedClass = (PredictedClass) obj;
            return Objects.equals(this.predictedClass, predictedClass.predictedClass) && this.count == predictedClass.count;
        }

        public int hashCode() {
            return Objects.hash(this.predictedClass, Long.valueOf(this.count));
        }

        static {
            PARSER.declareString(ConstructingObjectParser.constructorArg(), PREDICTED_CLASS);
            PARSER.declareLong(ConstructingObjectParser.constructorArg(), COUNT);
        }
    }

    /* loaded from: input_file:org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix$Result.class */
    public static class Result implements EvaluationMetricResult {
        private static final ParseField CONFUSION_MATRIX = new ParseField("confusion_matrix", new String[0]);
        private static final ParseField OTHER_ACTUAL_CLASS_COUNT = new ParseField("other_actual_class_count", new String[0]);
        private static final ConstructingObjectParser<Result, Void> PARSER = new ConstructingObjectParser<>("multiclass_confusion_matrix_result", true, objArr -> {
            return new Result((List) objArr[0], ((Long) objArr[1]).longValue());
        });
        private final List<ActualClass> actualClasses;
        private final long otherActualClassCount;

        public static Result fromXContent(XContentParser xContentParser) {
            return (Result) PARSER.apply(xContentParser, (Object) null);
        }

        public Result(List<ActualClass> list, long j) {
            this.actualClasses = Collections.unmodifiableList((List) ExceptionsHelper.requireNonNull(list, CONFUSION_MATRIX));
            this.otherActualClassCount = MulticlassConfusionMatrix.requireNonNegative(j, OTHER_ACTUAL_CLASS_COUNT);
        }

        public Result(StreamInput streamInput) throws IOException {
            this.actualClasses = Collections.unmodifiableList(streamInput.readList(ActualClass::new));
            this.otherActualClassCount = streamInput.readVLong();
        }

        public String getWriteableName() {
            return MlEvaluationNamedXContentProvider.registeredMetricName(Classification.NAME, MulticlassConfusionMatrix.NAME);
        }

        @Override // org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult
        public String getMetricName() {
            return MulticlassConfusionMatrix.NAME.getPreferredName();
        }

        public List<ActualClass> getConfusionMatrix() {
            return this.actualClasses;
        }

        public long getOtherActualClassCount() {
            return this.otherActualClassCount;
        }

        public void writeTo(StreamOutput streamOutput) throws IOException {
            streamOutput.writeList(this.actualClasses);
            streamOutput.writeVLong(this.otherActualClassCount);
        }

        public XContentBuilder toXContent(XContentBuilder xContentBuilder, ToXContent.Params params) throws IOException {
            xContentBuilder.startObject();
            xContentBuilder.field(CONFUSION_MATRIX.getPreferredName(), this.actualClasses);
            xContentBuilder.field(OTHER_ACTUAL_CLASS_COUNT.getPreferredName(), this.otherActualClassCount);
            xContentBuilder.endObject();
            return xContentBuilder;
        }

        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            Result result = (Result) obj;
            return Objects.equals(this.actualClasses, result.actualClasses) && this.otherActualClassCount == result.otherActualClassCount;
        }

        public int hashCode() {
            return Objects.hash(this.actualClasses, Long.valueOf(this.otherActualClassCount));
        }

        static {
            PARSER.declareObjectArray(ConstructingObjectParser.constructorArg(), ActualClass.PARSER, CONFUSION_MATRIX);
            PARSER.declareLong(ConstructingObjectParser.constructorArg(), OTHER_ACTUAL_CLASS_COUNT);
        }
    }

    private static ConstructingObjectParser<MulticlassConfusionMatrix, Void> createParser() {
        ConstructingObjectParser<MulticlassConfusionMatrix, Void> constructingObjectParser = new ConstructingObjectParser<>(NAME.getPreferredName(), true, objArr -> {
            return new MulticlassConfusionMatrix((Integer) objArr[0], (String) objArr[1]);
        });
        constructingObjectParser.declareInt(ConstructingObjectParser.optionalConstructorArg(), SIZE);
        constructingObjectParser.declareString(ConstructingObjectParser.optionalConstructorArg(), AGG_NAME_PREFIX);
        return constructingObjectParser;
    }

    public static MulticlassConfusionMatrix fromXContent(XContentParser xContentParser) {
        return (MulticlassConfusionMatrix) PARSER.apply(xContentParser, (Object) null);
    }

    public MulticlassConfusionMatrix() {
        this(null, null);
    }

    public MulticlassConfusionMatrix(@Nullable Integer num, @Nullable String str) {
        this.topActualClassNames = new SetOnce<>();
        this.result = new SetOnce<>();
        if (num != null && (num.intValue() <= 0 || num.intValue() > 1000)) {
            throw ExceptionsHelper.badRequestException("[{}] must be an integer in [1, {}]", SIZE.getPreferredName(), 1000);
        }
        this.size = num != null ? num.intValue() : 10;
        this.aggNamePrefix = str != null ? str : DEFAULT_AGG_NAME_PREFIX;
    }

    public MulticlassConfusionMatrix(StreamInput streamInput) throws IOException {
        this.topActualClassNames = new SetOnce<>();
        this.result = new SetOnce<>();
        this.size = streamInput.readVInt();
        if (streamInput.getVersion().onOrAfter(Version.V_7_6_0)) {
            this.aggNamePrefix = streamInput.readString();
        } else {
            this.aggNamePrefix = DEFAULT_AGG_NAME_PREFIX;
        }
    }

    public String getWriteableName() {
        return MlEvaluationNamedXContentProvider.registeredMetricName(Classification.NAME, NAME);
    }

    @Override // org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric
    public String getName() {
        return NAME.getPreferredName();
    }

    public int getSize() {
        return this.size;
    }

    @Override // org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric
    public final Tuple<List<AggregationBuilder>, List<PipelineAggregationBuilder>> aggs(String str, String str2) {
        return this.topActualClassNames.get() == null ? Tuple.tuple(Arrays.asList(AggregationBuilders.terms(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS)).field(str).order(Arrays.asList(BucketOrder.count(false), BucketOrder.key(true))).size(this.size)), Collections.emptyList()) : this.result.get() == null ? Tuple.tuple(Arrays.asList(AggregationBuilders.cardinality(aggName(STEP_2_CARDINALITY_OF_ACTUAL_CLASS)).field(str), AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS), (FiltersAggregator.KeyedFilter[]) ((List) this.topActualClassNames.get()).stream().map(str3 -> {
            return new FiltersAggregator.KeyedFilter(str3, QueryBuilders.termQuery(str, str3));
        }).toArray(i -> {
            return new FiltersAggregator.KeyedFilter[i];
        })).subAggregation(AggregationBuilders.filters(aggName(STEP_2_AGGREGATE_BY_PREDICTED_CLASS), (FiltersAggregator.KeyedFilter[]) ((List) this.topActualClassNames.get()).stream().map(str4 -> {
            return new FiltersAggregator.KeyedFilter(str4, QueryBuilders.termQuery(str2, str4));
        }).toArray(i2 -> {
            return new FiltersAggregator.KeyedFilter[i2];
        })).otherBucket(true).otherBucketKey(OTHER_BUCKET_KEY))), Collections.emptyList()) : Tuple.tuple(Collections.emptyList(), Collections.emptyList());
    }

    @Override // org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric
    public void process(Aggregations aggregations) {
        if (this.topActualClassNames.get() == null && aggregations.get(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS)) != null) {
            this.topActualClassNames.set((List) aggregations.get(aggName(STEP_1_AGGREGATE_BY_ACTUAL_CLASS)).getBuckets().stream().map((v0) -> {
                return v0.getKeyAsString();
            }).sorted().collect(Collectors.toList()));
        }
        if (this.result.get() != null || aggregations.get(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS)) == null) {
            return;
        }
        Cardinality cardinality = aggregations.get(aggName(STEP_2_CARDINALITY_OF_ACTUAL_CLASS));
        Filters filters = aggregations.get(aggName(STEP_2_AGGREGATE_BY_ACTUAL_CLASS));
        ArrayList arrayList = new ArrayList(filters.getBuckets().size());
        for (Filters.Bucket bucket : filters.getBuckets()) {
            String keyAsString = bucket.getKeyAsString();
            long docCount = bucket.getDocCount();
            Filters filters2 = bucket.getAggregations().get(aggName(STEP_2_AGGREGATE_BY_PREDICTED_CLASS));
            ArrayList arrayList2 = new ArrayList();
            long j = 0;
            for (Filters.Bucket bucket2 : filters2.getBuckets()) {
                String keyAsString2 = bucket2.getKeyAsString();
                long docCount2 = bucket2.getDocCount();
                if (OTHER_BUCKET_KEY.equals(keyAsString2)) {
                    j = docCount2;
                } else {
                    arrayList2.add(new PredictedClass(keyAsString2, docCount2));
                }
            }
            arrayList2.sort(Comparator.comparing((v0) -> {
                return v0.getPredictedClass();
            }));
            arrayList.add(new ActualClass(keyAsString, docCount, arrayList2, j));
        }
        this.result.set(new Result(arrayList, Math.max(cardinality.getValue() - this.size, 0L)));
    }

    private String aggName(String str) {
        return this.aggNamePrefix + str;
    }

    @Override // org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetric
    public Optional<Result> getResult() {
        return Optional.ofNullable((Result) this.result.get());
    }

    public void writeTo(StreamOutput streamOutput) throws IOException {
        streamOutput.writeVInt(this.size);
        if (streamOutput.getVersion().onOrAfter(Version.V_7_6_0)) {
            streamOutput.writeString(this.aggNamePrefix);
        }
    }

    public XContentBuilder toXContent(XContentBuilder xContentBuilder, ToXContent.Params params) throws IOException {
        xContentBuilder.startObject();
        xContentBuilder.field(SIZE.getPreferredName(), this.size);
        xContentBuilder.endObject();
        return xContentBuilder;
    }

    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        MulticlassConfusionMatrix multiclassConfusionMatrix = (MulticlassConfusionMatrix) obj;
        return this.size == multiclassConfusionMatrix.size && Objects.equals(this.aggNamePrefix, multiclassConfusionMatrix.aggNamePrefix);
    }

    public int hashCode() {
        return Objects.hash(Integer.valueOf(this.size), this.aggNamePrefix);
    }

    /* JADX INFO: Access modifiers changed from: private */
    public static long requireNonNegative(long j, ParseField parseField) {
        if (j < 0) {
            throw ExceptionsHelper.serverError("[" + parseField.getPreferredName() + "] must be >= 0, was: " + j);
        }
        return j;
    }
}
