package org.elasticsearch.xpack.ml.action;

import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionListenerResponseHandler;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.CheckedConsumer;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ml.action.ExplainDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.action.PutDataFrameAnalyticsAction;
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
import org.elasticsearch.xpack.core.ml.dataframe.explain.FieldSelection;
import org.elasticsearch.xpack.core.ml.dataframe.explain.MemoryEstimation;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.dataframe.extractor.DataFrameDataExtractorFactory;
import org.elasticsearch.xpack.ml.dataframe.extractor.ExtractedFieldsDetector;
import org.elasticsearch.xpack.ml.dataframe.extractor.ExtractedFieldsDetectorFactory;
import org.elasticsearch.xpack.ml.dataframe.process.MemoryUsageEstimationProcessManager;
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;

/**
 * Transport action registered as {@code cluster:admin/xpack/ml/data_frame/analytics/explain}.
 *
 * <p>Given the data frame analytics config carried by a {@link PutDataFrameAnalyticsAction.Request},
 * it responds with the field selection produced by an {@link ExtractedFieldsDetector} together with a
 * {@link MemoryEstimation}. The work is performed locally when the local node is an ML node; otherwise
 * the request is forwarded to the first ML node found in the cluster state.
 */
public class TransportExplainDataFrameAnalyticsAction
        extends HandledTransportAction<PutDataFrameAnalyticsAction.Request, ExplainDataFrameAnalyticsAction.Response> {

    private final XPackLicenseState licenseState;
    private final TransportService transportService;
    private final ClusterService clusterService;
    private final NodeClient client;
    private final MemoryUsageEstimationProcessManager processManager;

    @Inject
    public TransportExplainDataFrameAnalyticsAction(TransportService transportService,
                                                    ActionFilters actionFilters,
                                                    ClusterService clusterService,
                                                    NodeClient nodeClient,
                                                    XPackLicenseState xPackLicenseState,
                                                    MemoryUsageEstimationProcessManager memoryUsageEstimationProcessManager) {
        super("cluster:admin/xpack/ml/data_frame/analytics/explain", transportService, actionFilters,
            PutDataFrameAnalyticsAction.Request::new);
        // Fail fast on missing dependencies, consistently for every injected collaborator.
        this.transportService = Objects.requireNonNull(transportService);
        this.clusterService = Objects.requireNonNull(clusterService);
        this.client = Objects.requireNonNull(nodeClient);
        this.licenseState = Objects.requireNonNull(xPackLicenseState);
        this.processManager = Objects.requireNonNull(memoryUsageEstimationProcessManager);
    }

    /**
     * Entry point: rejects the request if machine learning is not allowed by the license, otherwise runs
     * the explanation locally on an ML node or forwards it to one.
     *
     * @param task           the task for this request; its id is used to name the memory estimation job
     * @param request        the request carrying the analytics config to explain
     * @param actionListener notified with the explain response or a failure
     */
    @Override
    protected void doExecute(Task task,
                             PutDataFrameAnalyticsAction.Request request,
                             ActionListener<ExplainDataFrameAnalyticsAction.Response> actionListener) {
        if (licenseState.isMachineLearningAllowed() == false) {
            actionListener.onFailure(LicenseUtils.newComplianceException(MachineLearning.NAME));
            return;
        }

        if (MachineLearning.isMlNode(clusterService.localNode())) {
            explain(task, request, actionListener);
        } else {
            // The estimation must run on an ML node (see estimateMemoryUsage), so hand the request over.
            redirectToMlNode(request, actionListener);
        }
    }

    /**
     * Asynchronously builds an {@link ExtractedFieldsDetector} for the request's source config, then
     * continues with field detection and memory estimation.
     */
    private void explain(Task task,
                         PutDataFrameAnalyticsAction.Request request,
                         ActionListener<ExplainDataFrameAnalyticsAction.Response> actionListener) {
        ExtractedFieldsDetectorFactory extractedFieldsDetectorFactory = new ExtractedFieldsDetectorFactory(client);
        extractedFieldsDetectorFactory.createFromSource(request.getConfig(), ActionListener.wrap(
            extractedFieldsDetector -> explain(task, request, extractedFieldsDetector, actionListener),
            actionListener::onFailure
        ));
    }

    /**
     * Runs field detection, then estimates memory usage for the detected fields and responds with both
     * the field selection and the estimate.
     *
     * <p>NOTE(review): an exception thrown by {@code detect()} propagates to the caller; when invoked from
     * the {@code ActionListener.wrap} callback above it is presumably routed to {@code onFailure}.
     */
    private void explain(Task task,
                         PutDataFrameAnalyticsAction.Request request,
                         ExtractedFieldsDetector extractedFieldsDetector,
                         ActionListener<ExplainDataFrameAnalyticsAction.Response> actionListener) {
        Tuple<ExtractedFields, List<FieldSelection>> fieldExtraction = extractedFieldsDetector.detect();

        ActionListener<MemoryEstimation> memoryEstimationListener = ActionListener.wrap(
            memoryEstimation -> actionListener.onResponse(
                new ExplainDataFrameAnalyticsAction.Response(fieldExtraction.v2(), memoryEstimation)),
            actionListener::onFailure
        );

        estimateMemoryUsage(task, request, fieldExtraction.v1(), memoryEstimationListener);
    }

    /**
     * Starts an asynchronous memory usage estimation job over the source indices for the given fields and
     * converts its result into a {@link MemoryEstimation} (expected memory without and with disk).
     *
     * <p>Only reached from {@code doExecute} when the local node is an ML node.
     */
    private void estimateMemoryUsage(Task task,
                                     PutDataFrameAnalyticsAction.Request request,
                                     ExtractedFields extractedFields,
                                     ActionListener<MemoryEstimation> actionListener) {
        // The same id names both the extractor and the estimation job, derived from this task's id.
        final String estimateMemoryTaskId = "memory_usage_estimation_" + task.getId();
        DataFrameAnalyticsConfig config = request.getConfig();
        DataFrameDataExtractorFactory extractorFactory =
            DataFrameDataExtractorFactory.createForSourceIndices(client, estimateMemoryTaskId, config, extractedFields);

        processManager.runJobAsync(estimateMemoryTaskId, config, extractorFactory, ActionListener.wrap(
            result -> actionListener.onResponse(
                new MemoryEstimation(result.getExpectedMemoryWithoutDisk(), result.getExpectedMemoryWithDisk())),
            actionListener::onFailure
        ));
    }

    /**
     * Forwards the request, under this action's name, to the first ML node in the current cluster state.
     * Fails with a bad-request error if the cluster has no ML node.
     */
    private void redirectToMlNode(PutDataFrameAnalyticsAction.Request request,
                                  ActionListener<ExplainDataFrameAnalyticsAction.Response> actionListener) {
        Optional<DiscoveryNode> mlNode = findMlNode(clusterService.state());
        if (mlNode.isPresent()) {
            transportService.sendRequest(mlNode.get(), actionName, request,
                new ActionListenerResponseHandler<>(actionListener, ExplainDataFrameAnalyticsAction.Response::new));
        } else {
            actionListener.onFailure(ExceptionsHelper.badRequestException("No ML node to run on"));
        }
    }

    /**
     * Returns the first node, in cluster-state iteration order, for which {@link MachineLearning#isMlNode}
     * holds, or an empty {@link Optional} if there is none.
     */
    private static Optional<DiscoveryNode> findMlNode(ClusterState clusterState) {
        for (DiscoveryNode node : clusterState.getNodes()) {
            if (MachineLearning.isMlNode(node)) {
                return Optional.of(node);
            }
        }
        return Optional.empty();
    }
}
