package org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.utils.Statistics;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

/* loaded from: input_file:org/elasticsearch/xpack/core/ml/inference/trainedmodel/langident/LangIdentNeuralNetwork.class */
/**
 * Language-identification model made of two {@link LangNetLayer}s: a hidden layer followed by a
 * softmax layer.
 *
 * <p>Inference reads one {@code double[]} feature, named by {@link #EMBEDDED_VECTOR_FEATURE_NAME}.
 * The feature must have exactly {@code 80} entries. It is fed through both layers and then through
 * {@link Statistics#softMax}. The resulting probabilities are mapped onto a fixed list of language
 * codes.
 */
public class LangIdentNeuralNetwork implements StrictlyParsedTrainedModel, LenientlyParsedTrainedModel {

    // Parse fields are initialized first: createParser(...) below reads them eagerly.
    public static final ParseField NAME = new ParseField("lang_ident_neural_network");
    public static final ParseField EMBEDDED_VECTOR_FEATURE_NAME = new ParseField("embedded_vector_feature_name");
    public static final ParseField HIDDEN_LAYER = new ParseField("hidden_layer");
    public static final ParseField SOFTMAX_LAYER = new ParseField("softmax_layer");

    public static final ConstructingObjectParser<LangIdentNeuralNetwork, Void> STRICT_PARSER = createParser(false);
    public static final ConstructingObjectParser<LangIdentNeuralNetwork, Void> LENIENT_PARSER = createParser(true);

    /**
     * Language codes indexed by predicted class. The predicted index is used directly with
     * {@code get(index)}, so this order presumably has to match the softmax layer's output order.
     * Do not reorder.
     */
    private static final List<String> LANGUAGE_NAMES = Arrays.asList(
        "eo", "co", "eu", "ta", "de", "mt", "ps", "te", "su", "uz", "zh-Latn", "ne",
        "nl", "sw", "sq", "hmn", "ja", "no", "mn", "so", "ko", "kk", "sl", "ig",
        "mr", "th", "zu", "ml", "hr", "bs", "lo", "sd", "cy", "hy", "uk", "pt",
        "lv", "iw", "cs", "vi", "jv", "be", "km", "mk", "tr", "fy", "am", "zh",
        "da", "sv", "fi", "ht", "af", "la", "id", "fil", "sm", "ca", "el", "ka",
        "sr", "it", "sk", "ru", "ru-Latn", "bg", "ny", "fa", "haw", "gl", "et", "ms",
        "gd", "bg-Latn", "ha", "is", "ur", "mi", "hi", "bn", "hi-Latn", "fr", "yi", "hu",
        "xh", "my", "tg", "ro", "ar", "lb", "el-Latn", "st", "ceb", "kn", "az", "si",
        "ky", "mg", "en", "gu", "es", "pl", "ja-Latn", "ga", "lt", "sn", "yo", "pa",
        "ku");

    private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(LangIdentNeuralNetwork.class);

    /** Required length of the input embedding vector. */
    private static final int EMBEDDING_VECTOR_LENGTH = 80;

    private final LangNetLayer hiddenLayer;
    private final LangNetLayer softmaxLayer;
    private final String embeddedVectorFeatureName;

    /**
     * Builds the object parser for this model.
     *
     * @param lenient whether unknown fields are ignored; also selects the lenient or strict
     *                {@link LangNetLayer} parser for the nested layers
     */
    private static ConstructingObjectParser<LangIdentNeuralNetwork, Void> createParser(boolean lenient) {
        ConstructingObjectParser<LangIdentNeuralNetwork, Void> parser = new ConstructingObjectParser<>(
            NAME.getPreferredName(),
            lenient,
            args -> new LangIdentNeuralNetwork((String) args[0], (LangNetLayer) args[1], (LangNetLayer) args[2]));
        parser.declareString(ConstructingObjectParser.constructorArg(), EMBEDDED_VECTOR_FEATURE_NAME);
        parser.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> parseLayer(lenient, p, c), HIDDEN_LAYER);
        parser.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> parseLayer(lenient, p, c), SOFTMAX_LAYER);
        return parser;
    }

    /** Parses one nested layer with the lenient or strict {@link LangNetLayer} parser. */
    private static LangNetLayer parseLayer(boolean lenient, XContentParser parser, Void context) {
        return lenient
            ? (LangNetLayer) LangNetLayer.LENIENT_PARSER.apply(parser, context)
            : (LangNetLayer) LangNetLayer.STRICT_PARSER.apply(parser, context);
    }

    /** Parses a model definition, rejecting unknown fields. */
    public static LangIdentNeuralNetwork fromXContentStrict(XContentParser parser) {
        return STRICT_PARSER.apply(parser, null);
    }

    /** Parses a model definition, ignoring unknown fields. */
    public static LangIdentNeuralNetwork fromXContentLenient(XContentParser parser) {
        return LENIENT_PARSER.apply(parser, null);
    }

    /**
     * @param embeddedVectorFeatureName name of the input field holding the embedding vector
     * @param hiddenLayer               first layer applied to the embedding
     * @param softmaxLayer              second layer, whose output is passed through softmax
     * @throws RuntimeException (via {@code ExceptionsHelper.requireNonNull}) if any argument is null
     */
    public LangIdentNeuralNetwork(String embeddedVectorFeatureName, LangNetLayer hiddenLayer, LangNetLayer softmaxLayer) {
        this.embeddedVectorFeatureName = ExceptionsHelper.requireNonNull(embeddedVectorFeatureName, EMBEDDED_VECTOR_FEATURE_NAME);
        this.hiddenLayer = ExceptionsHelper.requireNonNull(hiddenLayer, HIDDEN_LAYER);
        this.softmaxLayer = ExceptionsHelper.requireNonNull(softmaxLayer, SOFTMAX_LAYER);
    }

    /** Reads a model written by {@link #writeTo(StreamOutput)}. Fields are read in the order they are written. */
    public LangIdentNeuralNetwork(StreamInput in) throws IOException {
        this.embeddedVectorFeatureName = in.readString();
        this.hiddenLayer = new LangNetLayer(in);
        this.softmaxLayer = new LangNetLayer(in);
    }

    /**
     * Predicts the language for the embedding vector stored under {@code embeddedVectorFeatureName}.
     *
     * @param fields input document fields; the embedding must be a {@code double[]} of length 80
     * @param config must be a {@link ClassificationConfig}
     * @return classification results with the top language and the top classes
     * @throws RuntimeException (bad request) if the config is not a classification config, or the
     *         embedding is missing, not a {@code double[]}, or has the wrong length
     */
    @Override // org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel
    public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
        if (config instanceof ClassificationConfig == false) {
            throw ExceptionsHelper.badRequestException("[{}] model only supports classification", NAME.getPreferredName());
        }
        Object vector = fields.get(embeddedVectorFeatureName);
        if (vector instanceof double[] == false) {
            throw ExceptionsHelper.badRequestException("[{}] model could not find non-null numerical array named [{}]", NAME.getPreferredName(), embeddedVectorFeatureName);
        }
        double[] embedding = (double[]) vector;
        if (embedding.length != EMBEDDING_VECTOR_LENGTH) {
            throw ExceptionsHelper.badRequestException("[{}] model is expecting embedding vector of length [{}] but got [{}]", NAME.getPreferredName(), EMBEDDING_VECTOR_LENGTH, embedding.length);
        }
        // NOTE(review): the boolean passed to productPlusBias is defined by LangNetLayer
        // (false for the hidden layer, true for the softmax layer) — semantics not visible here.
        double[] hidden = hiddenLayer.productPlusBias(false, embedding);
        double[] scores = softmaxLayer.productPlusBias(true, hidden);
        List<Double> probabilities = Statistics.softMax(Arrays.stream(scores).boxed().collect(Collectors.toList()));

        ClassificationConfig classificationConfig = (ClassificationConfig) config;
        Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses =
            InferenceHelpers.topClasses(probabilities, LANGUAGE_NAMES, null, classificationConfig.getNumTopClasses());
        int predicted = topClasses.v1();
        assert predicted >= 0 && predicted < LANGUAGE_NAMES.size()
            : "Invalid language predicted. Predicted language index " + topClasses.v1();
        return new ClassificationInferenceResults(predicted, LANGUAGE_NAMES.get(predicted), topClasses.v2(), classificationConfig);
    }

    /** This model always performs classification. */
    @Override // org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel
    public TargetType targetType() {
        return TargetType.CLASSIFICATION;
    }

    /** No additional validation; the constructors already require all components to be present. */
    @Override // org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel
    public void validate() {
    }

    /**
     * Estimates the cost of one inference call. The estimate counts one operation per entry of each
     * layer's bias and weights arrays.
     *
     * <p>The sum is accumulated in {@code long} so that large layers cannot overflow {@code int}
     * arithmetic.
     */
    @Override // org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel
    public long estimatedNumOperations() {
        long numOps = hiddenLayer.getBias().length;
        numOps += hiddenLayer.getWeights().length;
        numOps += softmaxLayer.getBias().length;
        numOps += softmaxLayer.getWeights().length;
        return numOps;
    }

    /** Shallow size of this instance plus the estimated size of both layers. */
    public long ramBytesUsed() {
        return SHALLOW_SIZE + RamUsageEstimator.sizeOf(hiddenLayer) + RamUsageEstimator.sizeOf(softmaxLayer);
    }

    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    /** Writes the feature name and both layers, in the order the {@link StreamInput} constructor reads them. */
    public void writeTo(StreamOutput out) throws IOException {
        out.writeString(embeddedVectorFeatureName);
        hiddenLayer.writeTo(out);
        softmaxLayer.writeTo(out);
    }

    @Override // org.elasticsearch.xpack.core.ml.utils.NamedXContentObject
    public String getName() {
        return NAME.getPreferredName();
    }

    /** Serializes the model as an object containing the feature name and both layers. */
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        builder.field(EMBEDDED_VECTOR_FEATURE_NAME.getPreferredName(), embeddedVectorFeatureName);
        builder.field(HIDDEN_LAYER.getPreferredName(), hiddenLayer);
        builder.field(SOFTMAX_LAYER.getPreferredName(), softmaxLayer);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object other) {
        if (this == other) {
            return true;
        }
        if (other == null || getClass() != other.getClass()) {
            return false;
        }
        LangIdentNeuralNetwork that = (LangIdentNeuralNetwork) other;
        return Objects.equals(embeddedVectorFeatureName, that.embeddedVectorFeatureName)
            && Objects.equals(hiddenLayer, that.hiddenLayer)
            && Objects.equals(softmaxLayer, that.softmaxLayer);
    }

    @Override
    public int hashCode() {
        return Objects.hash(embeddedVectorFeatureName, hiddenLayer, softmaxLayer);
    }
}
