/*
 * Decompiled with CFR 0.152.
 */
package com.o19s.es.ltr.query;

import com.o19s.es.ltr.LtrQueryContext;
import com.o19s.es.ltr.feature.Feature;
import com.o19s.es.ltr.feature.FeatureSet;
import com.o19s.es.ltr.feature.LtrModel;
import com.o19s.es.ltr.feature.PrebuiltLtrModel;
import com.o19s.es.ltr.query.LtrRewritableQuery;
import com.o19s.es.ltr.query.LtrRewriteContext;
import com.o19s.es.ltr.query.NoopScorer;
import com.o19s.es.ltr.ranker.LogLtrRanker;
import com.o19s.es.ltr.ranker.LtrRanker;
import com.o19s.es.ltr.ranker.NullRanker;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.RandomAccess;
import java.util.Set;
import java.util.stream.Stream;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.ConstantScoreScorer;
import org.apache.lucene.search.ConstantScoreWeight;
import org.apache.lucene.search.DisiPriorityQueue;
import org.apache.lucene.search.DisiWrapper;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.ScorerSupplier;
import org.apache.lucene.search.Weight;
import org.opensearch.ltr.settings.LTRSettings;
import org.opensearch.ltr.stats.LTRStats;
import org.opensearch.ltr.stats.StatName;

public class RankerQuery
extends Query {
    /**
     * Feature vector currently being filled on the calling thread. It is set by
     * {@code FVLtrRankerWrapper.newFeatureVector}, removed by {@code FVLtrRankerWrapper.score},
     * and exposed to rewritable feature queries as a supplier through {@link LtrRewriteContext}.
     */
    private static final ThreadLocal<LtrRanker.FeatureVector> CURRENT_VECTOR = new ThreadLocal<>();
    /** Stats registry; used to bump the request-error counter when weight creation fails. */
    private final LTRStats ltrStats;
    /** Feature queries; the position in this list is used as the feature ordinal. */
    private final List<Query> queries;
    /** Feature set the queries were built from. */
    private final FeatureSet features;
    /** Ranker that turns a filled feature vector into the final score. */
    private final LtrRanker ranker;
    /**
     * Optional cache of per-document feature scores keyed by {@code docBase + docID};
     * {@code null} when feature-score caching is disabled.
     */
    private final Map<Integer, float[]> featureScoreCache;

    /**
     * Assembles a ranker query from already-built parts.
     *
     * @param queries           feature queries; must not be {@code null}
     * @param features          feature set the queries belong to; must not be {@code null}
     * @param ranker            ranker used to compute the final score; must not be {@code null}
     * @param featureScoreCache per-document feature score cache, or {@code null} to disable caching
     * @param ltrStats          stats registry (not null-checked here)
     */
    private RankerQuery(
            List<Query> queries,
            FeatureSet features,
            LtrRanker ranker,
            Map<Integer, float[]> featureScoreCache,
            LTRStats ltrStats) {
        // Null checks run in the same order as the parameter list.
        this.queries = Objects.requireNonNull(queries);
        this.features = Objects.requireNonNull(features);
        this.ranker = Objects.requireNonNull(ranker);
        this.featureScoreCache = featureScoreCache;
        this.ltrStats = ltrStats;
    }

    /**
     * Builds a ranker query from a prebuilt model, using a query context created with a
     * {@code null} first argument and an empty set, no parameters, and feature-score
     * caching turned off.
     *
     * @param model    prebuilt model supplying the ranker and the feature set
     * @param ltrStats stats registry handed to the resulting query
     * @return the new ranker query
     */
    public static RankerQuery build(PrebuiltLtrModel model, LTRStats ltrStats) {
        LtrRanker modelRanker = model.ranker();
        FeatureSet modelFeatures = model.featureSet();
        LtrQueryContext emptyContext = new LtrQueryContext(null, Collections.emptySet());
        Map<String, Object> noParams = Collections.emptyMap();
        return build(modelRanker, modelFeatures, emptyContext, noParams, Boolean.FALSE, ltrStats);
    }

    /**
     * Builds a ranker query from a model and an explicit query context.
     *
     * @param model                 model supplying the ranker and the feature set
     * @param context               context passed to the feature set when building feature queries
     * @param params                parameters passed to the feature set when building feature queries
     * @param featureScoreCacheFlag {@code true} to enable the feature-score cache; {@code null} or
     *                              {@code false} leaves it disabled
     * @param ltrStats              stats registry handed to the resulting query
     * @return the new ranker query
     */
    public static RankerQuery build(LtrModel model, LtrQueryContext context, Map<String, Object> params, Boolean featureScoreCacheFlag, LTRStats ltrStats) {
        LtrRanker modelRanker = model.ranker();
        FeatureSet modelFeatures = model.featureSet();
        return build(modelRanker, modelFeatures, context, params, featureScoreCacheFlag, ltrStats);
    }

    /**
     * Shared factory: converts the feature set into queries and, when requested, allocates an
     * empty feature-score cache.
     *
     * @param ranker                ranker used for the final score
     * @param features              feature set to turn into queries
     * @param context               context passed to {@code features.toQueries}
     * @param params                parameters passed to {@code features.toQueries}
     * @param featureScoreCacheFlag cache is created only when this is non-null and {@code true}
     * @param ltrStats              stats registry handed to the resulting query
     * @return the new ranker query
     */
    private static RankerQuery build(LtrRanker ranker, FeatureSet features, LtrQueryContext context, Map<String, Object> params, Boolean featureScoreCacheFlag, LTRStats ltrStats) {
        List<Query> featureQueries = features.toQueries(context, params);
        // Boolean.TRUE.equals(...) is false for null, so a missing flag leaves the cache off.
        Map<Integer, float[]> cache = Boolean.TRUE.equals(featureScoreCacheFlag)
                ? new HashMap<Integer, float[]>()
                : null;
        return new RankerQuery(featureQueries, features, ranker, cache, ltrStats);
    }

    /**
     * Builds a ranker query whose ranker is a {@link LogLtrRanker} created from the consumer and
     * the feature count. No feature-score cache is attached.
     *
     * @param consumer receiver given to the logging ranker
     * @param features feature set to turn into queries
     * @param context  context passed to {@code features.toQueries}
     * @param params   parameters passed to {@code features.toQueries}
     * @param ltrStats stats registry handed to the resulting query
     * @return the new logging ranker query
     */
    public static RankerQuery buildLogQuery(LogLtrRanker.LogConsumer consumer, FeatureSet features, LtrQueryContext context, Map<String, Object> params, LTRStats ltrStats) {
        List<Query> featureQueries = features.toQueries(context, params);
        LogLtrRanker loggingRanker = new LogLtrRanker(consumer, features.size());
        return new RankerQuery(featureQueries, features, loggingRanker, null, ltrStats);
    }

    /**
     * Creates a copy of this query that reuses its queries, feature set, cache and stats, but
     * replaces the ranker with a {@link LogLtrRanker} wrapping a {@link NullRanker} sized to the
     * feature set.
     *
     * @param consumer receiver given to the logging ranker
     * @return the logging variant of this query
     */
    public RankerQuery toLoggerQuery(LogLtrRanker.LogConsumer consumer) {
        LogLtrRanker loggingRanker = new LogLtrRanker(new NullRanker(features.size()), consumer);
        return new RankerQuery(queries, features, loggingRanker, featureScoreCache, ltrStats);
    }

    /**
     * Rewrites every feature query. Returns {@code this} when none of them changed, otherwise a
     * new ranker query holding the rewritten queries and the same features, ranker, cache and
     * stats.
     *
     * @param reader searcher passed to each feature query's {@code rewrite}
     * @return this query or its rewritten copy
     * @throws IOException if a feature query fails to rewrite
     */
    public Query rewrite(IndexSearcher reader) throws IOException {
        List<Query> rewrittenQueries = new ArrayList<>(queries.size());
        boolean changed = false;
        for (Query original : queries) {
            Query candidate = original.rewrite(reader);
            // Identity comparison: a query that returns itself counts as "not rewritten".
            if (candidate != original) {
                changed = true;
            }
            rewrittenQueries.add(candidate);
        }
        if (!changed) {
            return this;
        }
        return new RankerQuery(rewrittenQueries, features, ranker, featureScoreCache, ltrStats);
    }

    /**
     * Two ranker queries are equal when they have the same class and equal queries, features and
     * ranker. The feature-score cache and the stats registry do not take part in the comparison.
     */
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (sameClassAs(obj)) {
            RankerQuery other = (RankerQuery) obj;
            if (!Objects.deepEquals(queries, other.queries)) {
                return false;
            }
            if (!Objects.deepEquals(features, other.features)) {
                return false;
            }
            return Objects.equals(ranker, other.ranker);
        }
        return false;
    }

    /** Returns a stream over the feature queries, in list order. */
    Stream<Query> stream() {
        return queries.stream();
    }

    /** Hash built from the class hash and the features, queries and ranker (same fields as {@code equals}). */
    public int hashCode() {
        int classPart = 31 * classHash();
        int fieldPart = Objects.hash(features, queries, ranker);
        return classPart + fieldPart;
    }

    /**
     * Returns {@code "rankerquery:"} followed by the given field name.
     *
     * @param field field name appended to the prefix (a {@code null} is rendered as "null")
     */
    public String toString(String field) {
        StringBuilder text = new StringBuilder("rankerquery:");
        text.append(field);
        return text.toString();
    }

    /**
     * Looks up a feature in the feature set.
     *
     * @param ordinal feature ordinal passed to the feature set
     * @return the feature the feature set returns for that ordinal
     */
    Feature getFeature(int ordinal) {
        return features.feature(ordinal);
    }

    /** Returns the ranker this query scores with. */
    LtrRanker ranker() {
        return ranker;
    }

    /** Returns the feature set this query was built from. */
    public FeatureSet featureSet() {
        return features;
    }

    /**
     * Creates the weight for this query. Fails fast when the LTR plugin is disabled; any
     * exception raised while building the weight bumps the request-error counter and is then
     * rethrown unchanged.
     *
     * @param searcher  searcher used to create the feature weights
     * @param scoreMode requested score mode
     * @param boost     boost forwarded to the feature weights
     * @return the weight for this query
     * @throws IOException           if weight creation fails with an I/O error
     * @throws IllegalStateException if the LTR plugin is disabled
     */
    public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
        boolean pluginEnabled = LTRSettings.isLTRPluginEnabled();
        if (pluginEnabled) {
            try {
                return createWeightInternal(searcher, scoreMode, boost);
            } catch (Exception failure) {
                // Count the failure, then let the original exception propagate.
                String errorStat = StatName.LTR_REQUEST_ERROR_COUNT.getName();
                ltrStats.getStat(errorStat).increment();
                throw failure;
            }
        }
        throw new IllegalStateException("LTR plugin is disabled. To enable, update ltr.plugin.enabled to true");
    }

    /**
     * Builds the actual weight.
     *
     * <p>When scores are not needed, a constant-score weight matching every document of each
     * segment is returned. Otherwise one weight per feature query is created (always with
     * {@link ScoreMode#COMPLETE}) and wrapped in a {@link RankerWeight}. Feature queries that
     * implement {@link LtrRewritableQuery} are first rewritten with a context that exposes the
     * ranker and the calling thread's current feature vector.
     *
     * @param searcher  searcher used to create the feature weights
     * @param scoreMode requested score mode
     * @param boost     boost forwarded to the constant-score weight or the feature weights
     * @return a constant-score weight or a {@link RankerWeight}
     * @throws IOException if a feature weight cannot be created
     */
    private Weight createWeightInternal(IndexSearcher searcher, final ScoreMode scoreMode, float boost) throws IOException {
        if (!scoreMode.needsScores()) {
            return new ConstantScoreWeight(this, boost) {

                public ScorerSupplier scorerSupplier(final LeafReaderContext context) throws IOException {
                    // Read the constant score from this weight here: inside the nested
                    // ScorerSupplier, "this" would refer to the supplier instead.
                    final float constantScore = score();
                    return new ScorerSupplier() {

                        public Scorer get(long leadCost) throws IOException {
                            return new ConstantScoreScorer(constantScore, scoreMode, DocIdSetIterator.all(context.reader().maxDoc()));
                        }

                        public long cost() {
                            return context.reader().maxDoc();
                        }
                    };
                }

                public boolean isCacheable(LeafReaderContext ctx) {
                    return false;
                }
            };
        }
        List<Weight> weights = new ArrayList<>(this.queries.size());
        FVLtrRankerWrapper ltrRankerWrapper = new FVLtrRankerWrapper(this.ranker);
        LtrRewriteContext context = new LtrRewriteContext(this.ranker, CURRENT_VECTOR::get);
        for (Query q : this.queries) {
            if (q instanceof LtrRewritableQuery) {
                q = ((LtrRewritableQuery) q).ltrRewrite(context);
            }
            weights.add(searcher.createWeight(q, ScoreMode.COMPLETE, boost));
        }
        return new RankerWeight(this, weights, ltrRankerWrapper, this.features, this.featureScoreCache);
    }

    /**
     * Visits every feature query through a sub-visitor obtained for a {@code SHOULD} clause of
     * this query.
     *
     * @param visitor visitor to derive the sub-visitor from
     */
    public void visit(QueryVisitor visitor) {
        final QueryVisitor childVisitor = visitor.getSubVisitor(BooleanClause.Occur.SHOULD, this);
        queries.forEach(featureQuery -> featureQuery.visit(childVisitor));
    }

    /**
     * Ranker decorator that publishes the feature vector being filled in the
     * {@code CURRENT_VECTOR} thread-local: the vector is stored when it is created and removed
     * once it has been scored. Everything else is delegated to the wrapped ranker.
     */
    static class FVLtrRankerWrapper
    implements LtrRanker {
        private final LtrRanker wrapped;

        /**
         * @param wrapped ranker to delegate to; must not be {@code null}
         */
        FVLtrRankerWrapper(LtrRanker wrapped) {
            this.wrapped = Objects.requireNonNull(wrapped);
        }

        /** Returns the wrapped ranker's name. */
        @Override
        public String name() {
            return this.wrapped.name();
        }

        /**
         * Obtains a feature vector from the wrapped ranker and records it as the calling
         * thread's current vector.
         */
        @Override
        public LtrRanker.FeatureVector newFeatureVector(LtrRanker.FeatureVector reuse) {
            LtrRanker.FeatureVector fv = this.wrapped.newFeatureVector(reuse);
            CURRENT_VECTOR.set(fv);
            return fv;
        }

        /**
         * Scores the vector with the wrapped ranker and clears the calling thread's current
         * vector.
         */
        @Override
        public float score(LtrRanker.FeatureVector point) {
            try {
                return this.wrapped.score(point);
            } finally {
                // Clear even when scoring throws, so a reused thread never sees a stale vector.
                CURRENT_VECTOR.remove();
            }
        }

        /** Two wrappers are equal when they have the same class and equal wrapped rankers. */
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            FVLtrRankerWrapper that = (FVLtrRankerWrapper)o;
            return Objects.equals(this.wrapped, that.wrapped);
        }

        public int hashCode() {
            return Objects.hash(this.wrapped);
        }
    }

    public static class RankerWeight
    extends Weight {
        /** Feature weights; the position in this list is used as the feature ordinal. */
        private final List<Weight> weights;
        /** Ranker wrapper that creates feature vectors and computes the final score. */
        private final FVLtrRankerWrapper ranker;
        /** Feature set, used to look up feature names when explaining. */
        private final FeatureSet features;
        /** Optional per-document feature score cache; {@code null} when caching is disabled. */
        private final Map<Integer, float[]> featureScoreCache;

        /**
         * @param query             parent query passed to the {@code Weight} constructor
         * @param weights           feature weights; asserted to be a {@link RandomAccess} list
         * @param ranker            ranker wrapper; must not be {@code null}
         * @param features          feature set; must not be {@code null}
         * @param featureScoreCache feature score cache, or {@code null} to disable caching
         */
        RankerWeight(
                RankerQuery query,
                List<Weight> weights,
                FVLtrRankerWrapper ranker,
                FeatureSet features,
                Map<Integer, float[]> featureScoreCache) {
            super(query);
            assert weights instanceof RandomAccess;
            this.weights = weights;
            this.ranker = Objects.requireNonNull(ranker);
            this.features = Objects.requireNonNull(features);
            this.featureScoreCache = featureScoreCache;
        }

        /**
         * Always reports this weight as not cacheable, for every segment.
         *
         * @param ctx segment context (ignored)
         * @return {@code false}
         */
        public boolean isCacheable(LeafReaderContext ctx) {
            return false;
        }

        /**
         * Collects the terms used by the feature queries into the given set.
         *
         * @param terms set that receives the collected terms
         */
        public void extractTerms(Set<Term> terms) {
            // A term collector does nothing until a query is visited with it, so hand it to
            // each feature weight's query.
            QueryVisitor collector = QueryVisitor.termCollector(terms);
            for (Weight w : this.weights) {
                w.getQuery().visit(collector);
            }
        }

        /**
         * Explains the model score for one document. Each feature weight is explained in turn;
         * matching features contribute their value to a fresh feature vector, non-matching ones
         * are reported as using the default value. The ranker's score for that vector becomes
         * the value of the returned explanation, with the per-feature explanations as details.
         *
         * @param context segment context
         * @param doc     document id passed to each feature weight's {@code explain}
         * @return the model explanation
         * @throws IOException if a feature weight fails to explain
         */
        public Explanation explain(LeafReaderContext context, int doc) throws IOException {
            final int featureCount = weights.size();
            List<Explanation> details = new ArrayList<>(featureCount);
            LtrRanker.FeatureVector vector = ranker.newFeatureVector(null);
            for (int ordinal = 0; ordinal < featureCount; ordinal++) {
                Explanation featureExplanation = weights.get(ordinal).explain(context, doc);

                // Label looks like "Feature 3(name):", the name part only when the feature has one.
                StringBuilder label = new StringBuilder("Feature ").append(Integer.toString(ordinal));
                if (features.feature(ordinal).name() != null) {
                    label.append("(").append(features.feature(ordinal).name()).append(")");
                }
                label.append(":");

                if (featureExplanation.isMatch()) {
                    details.add(Explanation.match(featureExplanation.getValue(), label.toString(), featureExplanation));
                    vector.setFeatureScore(ordinal, featureExplanation.getValue().floatValue());
                } else {
                    details.add(Explanation.noMatch(label + " [no match, default value 0.0 used]"));
                }
            }
            float modelScore = ranker.score(vector);
            String summary = " LtrModel: " + ranker.name() + " using features:";
            return Explanation.match(Float.valueOf(modelScore), summary, details);
        }

        /**
         * Builds the scorer for one segment. Every feature weight contributes a scorer (a
         * {@link NoopScorer} over an empty iterator when the weight returns {@code null}); the
         * scorers are kept in feature order and also registered in a priority queue that drives
         * the sub-iterators. The resulting scorer iterates over all documents of the segment.
         *
         * @param context segment context
         * @return the ranker scorer for the segment
         * @throws IOException if a feature weight fails to create its scorer
         */
        public RankerScorer getScorer(LeafReaderContext context) throws IOException {
            final int featureCount = weights.size();
            List<Scorer> featureScorers = new ArrayList<>(featureCount);
            DisiPriorityQueue subIterators = DisiPriorityQueue.ofMaxSize(featureCount);
            for (Weight featureWeight : weights) {
                Scorer leafScorer = featureWeight.scorer(context);
                Scorer effective = leafScorer != null
                        ? leafScorer
                        : new NoopScorer((Weight) this, DocIdSetIterator.empty());
                featureScorers.add(effective);
                subIterators.add(new DisiWrapper(effective, false));
            }
            DocIdSetIterator allDocs = DocIdSetIterator.all(context.reader().maxDoc());
            DisjunctionDISI rankerIterator = new DisjunctionDISI(allDocs, subIterators, context.docBase, featureScoreCache);
            return new RankerScorer(this, featureScorers, rankerIterator, ranker, context.docBase, featureScoreCache);
        }

        /**
         * Returns a supplier that builds the segment scorer on demand via {@code getScorer} and
         * reports the segment's {@code maxDoc} as its cost.
         *
         * @param context segment context
         * @return the scorer supplier for the segment
         * @throws IOException declared by the overridden contract; not thrown by this method itself
         */
        public ScorerSupplier scorerSupplier(final LeafReaderContext context) throws IOException {
            return new ScorerSupplier(){

                public Scorer get(long leadCost) throws IOException {
                    // Unqualified call: resolves to the enclosing weight's getScorer, since the
                    // supplier itself declares no such method.
                    return getScorer(context);
                }

                public long cost() {
                    return context.reader().maxDoc();
                }
            };
        }

        /**
         * Scorer that fills a feature vector from the per-feature scorers (or from the
         * feature-score cache) and asks the ranker for the final score.
         */
        class RankerScorer
        extends Scorer {
            /** Per-feature scorers; the list position is the feature ordinal. */
            private final List<Scorer> scorers;
            private final DisjunctionDISI iterator;
            private final FVLtrRankerWrapper ranker;
            /** Feature vector from the previous {@code score()} call, offered to the ranker for reuse. */
            private LtrRanker.FeatureVector fv;
            private final int docBase;
            /** Cache keyed by {@code docBase + docID}; {@code null} when caching is disabled. */
            private final Map<Integer, float[]> featureScoreCache;

            /**
             * @param this$0            unused; the caller in this file passes the enclosing weight
             * @param scorers           per-feature scorers in feature order
             * @param iterator          iterator driving this scorer
             * @param ranker            ranker wrapper used for vectors and the final score
             * @param docBase           value added to the document id to form the cache key
             * @param featureScoreCache feature score cache, or {@code null} to disable caching
             */
            RankerScorer(RankerWeight this$0, List<Scorer> scorers, DisjunctionDISI iterator, FVLtrRankerWrapper ranker, int docBase, Map<Integer, float[]> featureScoreCache) {
                this.scorers = scorers;
                this.iterator = iterator;
                this.ranker = ranker;
                this.docBase = docBase;
                this.featureScoreCache = featureScoreCache;
            }

            public int docID() {
                return iterator.docID();
            }

            /**
             * Scores the current document. Without a cache, feature values are read from the
             * scorers positioned on the current document. With a cache, values are replayed
             * from the cached array when present; otherwise they are read from the scorers and
             * the array is stored for later calls.
             */
            public float score() throws IOException {
                fv = ranker.newFeatureVector(fv);
                if (featureScoreCache == null) {
                    fillFromScorers(null);
                    return ranker.score(fv);
                }
                int cacheKey = docBase + docID();
                if (featureScoreCache.containsKey(cacheKey)) {
                    fillFromCache(featureScoreCache.get(cacheKey));
                } else {
                    float[] snapshot = new float[scorers.size()];
                    fillFromScorers(snapshot);
                    featureScoreCache.put(cacheKey, snapshot);
                }
                return ranker.score(fv);
            }

            /**
             * Copies the score of every scorer positioned on the current document into the
             * feature vector. When {@code snapshot} is non-null it also records each feature's
             * value there, using NaN for features that did not match.
             */
            private void fillFromScorers(float[] snapshot) throws IOException {
                for (int ordinal = 0; ordinal < scorers.size(); ordinal++) {
                    Scorer featureScorer = scorers.get(ordinal);
                    float value = Float.NaN;
                    if (featureScorer.docID() == docID()) {
                        value = featureScorer.score();
                        fv.setFeatureScore(ordinal, value);
                    }
                    if (snapshot != null) {
                        snapshot[ordinal] = value;
                    }
                }
            }

            /** Replays cached values into the feature vector, skipping NaN (no-match) entries. */
            private void fillFromCache(float[] cachedScores) {
                for (int ordinal = 0; ordinal < cachedScores.length; ordinal++) {
                    float cached = cachedScores[ordinal];
                    if (!Float.isNaN(cached)) {
                        fv.setFeatureScore(ordinal, cached);
                    }
                }
            }

            public DocIdSetIterator iterator() {
                return iterator;
            }

            /** No upper bound is computed; positive infinity is always returned. */
            public float getMaxScore(int upTo) throws IOException {
                return Float.POSITIVE_INFINITY;
            }
        }
    }

    /**
     * Iterator that follows a main iterator and keeps a priority queue of sub-iterators advanced
     * to at least the main iterator's current document. When a feature-score cache holds an entry
     * for the document reached by {@code advance}, the sub-iterators are left where they are.
     */
    static class DisjunctionDISI
    extends DocIdSetIterator {
        private final DocIdSetIterator main;
        private final DisiPriorityQueue subIteratorsPriorityQueue;
        /** Value added to a document id to form the cache key. */
        private final int docBase;
        /** Feature score cache; {@code null} when caching is disabled. */
        private final Map<Integer, float[]> featureScoreCache;

        /**
         * @param main                      iterator that decides which documents are visited
         * @param subIteratorsPriorityQueue queue of sub-iterators to keep in step with {@code main}
         * @param docBase                   value added to a document id to form the cache key
         * @param featureScoreCache         feature score cache, or {@code null} to disable caching
         */
        DisjunctionDISI(DocIdSetIterator main, DisiPriorityQueue subIteratorsPriorityQueue, int docBase, Map<Integer, float[]> featureScoreCache) {
            this.main = main;
            this.subIteratorsPriorityQueue = subIteratorsPriorityQueue;
            this.docBase = docBase;
            this.featureScoreCache = featureScoreCache;
        }

        public int docID() {
            return this.main.docID();
        }

        /** Moves the main iterator to its next document and brings the sub-iterators along. */
        public int nextDoc() throws IOException {
            int doc = this.main.nextDoc();
            this.advanceSubIterators(doc);
            return doc;
        }

        /**
         * Advances the main iterator to {@code target}. The sub-iterators are only advanced when
         * the document actually reached has no cached feature scores.
         */
        public int advance(int target) throws IOException {
            int docId = this.main.advance(target);
            // Look up the document we landed on, not the requested target, and skip the lookup
            // once exhausted (Integer.MAX_VALUE is also what advanceSubIterators treats as "done").
            if (docId != Integer.MAX_VALUE
                    && this.featureScoreCache != null
                    && this.featureScoreCache.containsKey(this.docBase + docId)) {
                return docId;
            }
            this.advanceSubIterators(docId);
            return docId;
        }

        /**
         * Advances every sub-iterator that is behind {@code target}. Does nothing when
         * {@code target} is {@code Integer.MAX_VALUE}.
         */
        private void advanceSubIterators(int target) throws IOException {
            if (target == Integer.MAX_VALUE) {
                return;
            }
            DisiWrapper top = this.subIteratorsPriorityQueue.top();
            // top is null when the queue is empty (no feature scorers were registered).
            while (top != null && top.doc < target) {
                top.doc = top.iterator.advance(target);
                top = this.subIteratorsPriorityQueue.updateTop();
            }
        }

        public long cost() {
            return this.main.cost();
        }
    }
}

