package org.elasticsearch.xpack.ml.action;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.admin.cluster.node.stats.NodeStats;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsAction;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsRequest;
import org.elasticsearch.action.admin.cluster.node.stats.NodesStatsResponse;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.CheckedConsumer;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.metrics.CounterMetric;
import org.elasticsearch.ingest.IngestMetadata;
import org.elasticsearch.ingest.IngestService;
import org.elasticsearch.ingest.IngestStats;
import org.elasticsearch.ingest.Pipeline;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.inference.ingest.InferenceProcessor;
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;

/* loaded from: input_file:org/elasticsearch/xpack/ml/action/TransportGetTrainedModelsStatsAction.class */
/**
 * Transport action that answers "get trained model stats" requests.
 *
 * <p>The flow in {@link #doExecute} is: expand the requested model id expression through the
 * {@link TrainedModelProvider}, request ingest statistics from the ingest nodes of the cluster,
 * work out which ingest pipelines contain an {@link InferenceProcessor} referencing each expanded
 * model id, and finally merge the per-node ingest statistics of those pipelines per model id.
 */
public class TransportGetTrainedModelsStatsAction extends HandledTransportAction<GetTrainedModelsStatsAction.Request, GetTrainedModelsStatsAction.Response> {
    private final Client client;
    private final ClusterService clusterService;
    private final IngestService ingestService;
    private final TrainedModelProvider trainedModelProvider;

    /**
     * Mutable running sum of the four counters carried by an {@link IngestStats.Stats}, plus an
     * optional processor type label that is only set when accumulating processor statistics.
     */
    public static class IngestStatsAccumulator {
        CounterMetric ingestCount = new CounterMetric();
        CounterMetric ingestTimeInMillis = new CounterMetric();
        CounterMetric ingestCurrent = new CounterMetric();
        CounterMetric ingestFailedCount = new CounterMetric();
        String type;

        IngestStatsAccumulator() {
        }

        IngestStatsAccumulator(String type) {
            this.type = type;
        }

        /**
         * Adds every counter of {@code stats} to this accumulator.
         *
         * @param stats the statistics to add
         * @return this accumulator, so calls can be chained
         */
        public IngestStatsAccumulator inc(IngestStats.Stats stats) {
            this.ingestCount.inc(stats.getIngestCount());
            this.ingestTimeInMillis.inc(stats.getIngestTimeInMillis());
            this.ingestCurrent.inc(stats.getIngestCurrent());
            this.ingestFailedCount.inc(stats.getIngestFailedCount());
            return this;
        }

        /**
         * @return a new {@link IngestStats.Stats} holding the current value of each counter
         */
        IngestStats.Stats build() {
            return new IngestStats.Stats(
                this.ingestCount.count(),
                this.ingestTimeInMillis.count(),
                this.ingestCurrent.count(),
                this.ingestFailedCount.count());
        }
    }

    @Inject
    public TransportGetTrainedModelsStatsAction(TransportService transportService, ActionFilters actionFilters, ClusterService clusterService, IngestService ingestService, TrainedModelProvider trainedModelProvider, Client client) {
        super("cluster:monitor/xpack/ml/inference/stats/get", transportService, actionFilters, GetTrainedModelsStatsAction.Request::new);
        this.client = client;
        this.clusterService = clusterService;
        this.ingestService = ingestService;
        this.trainedModelProvider = trainedModelProvider;
    }

    /**
     * Expands the requested model ids, fetches ingest statistics from the ingest nodes and replies
     * with the statistics grouped by model id. Any failure in either asynchronous step is forwarded
     * to {@code actionListener.onFailure}.
     *
     * <p>No explicit bridge overload taking {@code ActionRequest} is declared: the compiler
     * generates that synthetic method itself for this override.
     *
     * @param task           the task this request runs under (not used by this implementation)
     * @param request        carries the resource id expression, the allow-no-resources flag and paging
     * @param actionListener receives the built response or the failure
     */
    @Override
    protected void doExecute(Task task, GetTrainedModelsStatsAction.Request request, ActionListener<GetTrainedModelsStatsAction.Response> actionListener) {
        GetTrainedModelsStatsAction.Response.Builder responseBuilder = new GetTrainedModelsStatsAction.Response.Builder();

        // Second step: runs once the node stats arrive; the expanded ids were stored on the builder by the first step.
        ActionListener<NodesStatsResponse> nodesStatsListener = ActionListener.wrap(nodesStatsResponse -> {
            Map<String, Set<String>> pipelineIdsByModelId =
                pipelineIdsByModelIds(this.clusterService.state(), this.ingestService, responseBuilder.getExpandedIds());
            Map<String, IngestStats> ingestStatsByModelId = inferenceIngestStatsByPipelineId(nodesStatsResponse, pipelineIdsByModelId);
            actionListener.onResponse(responseBuilder.setIngestStatsByModelId(ingestStatsByModelId).build());
        }, actionListener::onFailure);

        // First step: the listener is passed inline so its element type is inferred from expandIds
        // (a tuple whose v1 is the total model count and whose v2 is the set of expanded ids).
        this.trainedModelProvider.expandIds(
            request.getResourceId(),
            request.isAllowNoResources(),
            request.getPageParams(),
            ActionListener.wrap(countAndIds -> {
                responseBuilder.setExpandedIds(countAndIds.v2()).setTotalModelCount(countAndIds.v1());
                NodesStatsRequest nodesStatsRequest = new NodesStatsRequest(ingestNodes(this.clusterService.state())).clear().ingest(true);
                ClientHelper.executeAsyncWithOrigin(this.client, MachineLearning.NAME, NodesStatsAction.INSTANCE, nodesStatsRequest, nodesStatsListener);
            }, actionListener::onFailure));
    }

    /**
     * Builds, for every model id, the ingest statistics merged across all nodes of the response,
     * restricted to the pipelines associated with that model id.
     *
     * @param nodesStatsResponse   per-node statistics to read ingest stats from
     * @param pipelineIdsByModelId model id mapped to the ids of the pipelines that use it
     * @return a new mutable map from model id to its merged ingest statistics
     */
    static Map<String, IngestStats> inferenceIngestStatsByPipelineId(NodesStatsResponse nodesStatsResponse, Map<String, Set<String>> pipelineIdsByModelId) {
        Map<String, IngestStats> ingestStatsByModelId = new HashMap<>();
        pipelineIdsByModelId.forEach((modelId, pipelineIds) -> {
            List<IngestStats> perNodeStats = nodesStatsResponse.getNodes()
                .stream()
                .map(nodeStats -> ingestStatsForPipelineIds(nodeStats, pipelineIds))
                .collect(Collectors.toList());
            ingestStatsByModelId.put(modelId, mergeStats(perNodeStats));
        });
        return ingestStatsByModelId;
    }

    /**
     * Copies the keys of the cluster state's ingest-node map into an array.
     *
     * @param clusterState the cluster state to read the ingest nodes from
     * @return one entry per key of {@code clusterState.nodes().getIngestNodes()}
     */
    static String[] ingestNodes(ClusterState clusterState) {
        String[] nodeIds = new String[clusterState.nodes().getIngestNodes().size()];
        Iterator<String> keys = clusterState.nodes().getIngestNodes().keysIt();
        int index = 0;
        while (keys.hasNext()) {
            nodeIds[index++] = keys.next();
        }
        return nodeIds;
    }

    /**
     * Finds, for each of the given model ids, the ingest pipelines that contain an
     * {@link InferenceProcessor} referencing it. Every pipeline configuration found in the
     * cluster state's ingest metadata is instantiated to inspect its processors.
     *
     * @param clusterState  cluster state holding the ingest metadata (custom key {@code "ingest"})
     * @param ingestService supplies the processor factories and script service for pipeline creation
     * @param modelIds      the model ids of interest; processors for other models are ignored
     * @return a new mutable map from model id to the insertion-ordered set of pipeline ids using it;
     *         empty when the cluster state has no ingest metadata
     * @throws ElasticsearchException wrapping any exception raised while creating or inspecting a pipeline
     */
    static Map<String, Set<String>> pipelineIdsByModelIds(ClusterState clusterState, IngestService ingestService, Set<String> modelIds) {
        IngestMetadata ingestMetadata = clusterState.metaData().custom("ingest");
        Map<String, Set<String>> pipelineIdsByModelId = new HashMap<>();
        if (ingestMetadata == null) {
            return pipelineIdsByModelId;
        }
        ingestMetadata.getPipelines().forEach((pipelineId, pipelineConfiguration) -> {
            try {
                Pipeline pipeline = Pipeline.create(
                    pipelineId,
                    pipelineConfiguration.getConfigAsMap(),
                    ingestService.getProcessorFactories(),
                    ingestService.getScriptService());
                pipeline.getProcessors().forEach(processor -> {
                    if (processor instanceof InferenceProcessor) {
                        InferenceProcessor inferenceProcessor = (InferenceProcessor) processor;
                        if (modelIds.contains(inferenceProcessor.getModelId())) {
                            pipelineIdsByModelId
                                .computeIfAbsent(inferenceProcessor.getModelId(), modelId -> new LinkedHashSet<>())
                                .add(pipelineId);
                        }
                    }
                });
            } catch (Exception e) {
                throw new ElasticsearchException("unexpected failure gathering pipeline information", e);
            }
        });
        return pipelineIdsByModelId;
    }

    /**
     * Restricts one node's ingest statistics to the given pipelines and recomputes the totals
     * as the sum over the retained pipeline statistics.
     *
     * @param nodeStats   the node statistics to read ingest stats from
     * @param pipelineIds ids of the pipelines to keep
     * @return new ingest statistics containing only the pipeline and processor entries whose
     *         pipeline id is in {@code pipelineIds}
     */
    public static IngestStats ingestStatsForPipelineIds(NodeStats nodeStats, Set<String> pipelineIds) {
        IngestStats nodeIngestStats = nodeStats.getIngestStats();

        // Work on a copy so that filtering the key set does not touch the node's own map.
        Map<String, List<IngestStats.ProcessorStat>> processorStats = new HashMap<>(nodeIngestStats.getProcessorStats());
        processorStats.keySet().retainAll(pipelineIds);

        List<IngestStats.PipelineStat> pipelineStats = nodeIngestStats.getPipelineStats()
            .stream()
            .filter(pipelineStat -> pipelineIds.contains(pipelineStat.getPipelineId()))
            .collect(Collectors.toList());

        IngestStatsAccumulator totals = new IngestStatsAccumulator();
        for (IngestStats.PipelineStat pipelineStat : pipelineStats) {
            totals.inc(pipelineStat.getStats());
        }
        return new IngestStats(totals.build(), pipelineStats, processorStats);
    }

    /**
     * Sums a list of ingest statistics into one: totals are added together, pipeline statistics are
     * added per pipeline id, and processor statistics are added per pipeline id and processor name.
     * First-seen order of pipelines and processors is preserved. The processor type reported for a
     * processor name is the type of the first entry seen with that name.
     */
    private static IngestStats mergeStats(List<IngestStats> ingestStatsList) {
        Map<String, IngestStatsAccumulator> pipelineAccumulators = new LinkedHashMap<>(ingestStatsList.size());
        Map<String, Map<String, IngestStatsAccumulator>> processorAccumulators = new LinkedHashMap<>(ingestStatsList.size());
        IngestStatsAccumulator totals = new IngestStatsAccumulator();

        for (IngestStats ingestStats : ingestStatsList) {
            for (IngestStats.PipelineStat pipelineStat : ingestStats.getPipelineStats()) {
                pipelineAccumulators
                    .computeIfAbsent(pipelineStat.getPipelineId(), pipelineId -> new IngestStatsAccumulator())
                    .inc(pipelineStat.getStats());
            }
            for (Map.Entry<String, List<IngestStats.ProcessorStat>> entry : ingestStats.getProcessorStats().entrySet()) {
                Map<String, IngestStatsAccumulator> byProcessorName =
                    processorAccumulators.computeIfAbsent(entry.getKey(), pipelineId -> new LinkedHashMap<>());
                for (IngestStats.ProcessorStat processorStat : entry.getValue()) {
                    byProcessorName
                        .computeIfAbsent(processorStat.getName(), processorName -> new IngestStatsAccumulator(processorStat.getType()))
                        .inc(processorStat.getStats());
                }
            }
            totals.inc(ingestStats.getTotalStats());
        }

        List<IngestStats.PipelineStat> mergedPipelineStats = new ArrayList<>(pipelineAccumulators.size());
        pipelineAccumulators.forEach((pipelineId, accumulator) ->
            mergedPipelineStats.add(new IngestStats.PipelineStat(pipelineId, accumulator.build())));

        Map<String, List<IngestStats.ProcessorStat>> mergedProcessorStats = new LinkedHashMap<>(processorAccumulators.size());
        processorAccumulators.forEach((pipelineId, byProcessorName) -> {
            List<IngestStats.ProcessorStat> processorStats = new ArrayList<>(byProcessorName.size());
            byProcessorName.forEach((processorName, accumulator) ->
                processorStats.add(new IngestStats.ProcessorStat(processorName, accumulator.type, accumulator.build())));
            mergedProcessorStats.put(pipelineId, processorStats);
        });

        return new IngestStats(totals.build(), mergedPipelineStats, mergedProcessorStats);
    }
}
