/* * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package com.o19s.es.ltr.feature.store; import com.o19s.es.ltr.LtrQueryContext; import com.o19s.es.ltr.feature.Feature; import com.o19s.es.ltr.feature.FeatureSet; import com.o19s.es.ltr.query.LtrRewritableQuery; import com.o19s.es.ltr.query.LtrRewriteContext; import com.o19s.es.ltr.ranker.LogLtrRanker; import com.o19s.es.termstat.TermStatSupplier; import org.apache.lucene.analysis.Analyzer; import org.apache.lucene.analysis.TokenStream; import org.apache.lucene.analysis.tokenattributes.TermToBytesRefAttribute; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.Term; import org.apache.lucene.index.TermStates; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.Query; import org.apache.lucene.search.QueryVisitor; import org.apache.lucene.search.Weight; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.Explanation; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.MatchAllDocsQuery; import org.opensearch.common.lucene.search.function.LeafScoreFunction; import org.opensearch.common.lucene.search.function.ScriptScoreFunction; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.NamedXContentRegistry; import org.opensearch.common.xcontent.XContentParser; import 
org.opensearch.common.xcontent.XContentType;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.script.ScoreScript;
import org.opensearch.script.Script;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * A {@link Feature} whose value is produced by a {@link ScoreScript}.
 *
 * <p>The stored script params are split at construction time. The {@value #EXTRA_SCRIPT_PARAMS} entry, if present,
 * maps query param names to the names under which their values are handed to the script. Every other entry is
 * passed to the script unchanged. When the params contain a {@code term_stat} spec, the requested terms are
 * analyzed and exposed to the script through {@value #TERM_STAT}, {@value #MATCH_COUNT} and {@value #UNIQUE_TERMS}.
 */
public class ScriptFeature implements Feature {
    public static final String TEMPLATE_LANGUAGE = "script_feature";
    public static final String FEATURE_VECTOR = "feature_vector";
    public static final String TERM_STAT = "termStats";
    public static final String MATCH_COUNT = "matchCount";
    public static final String UNIQUE_TERMS = "uniqueTerms";
    public static final String EXTRA_LOGGING = "extra_logging";
    public static final String EXTRA_SCRIPT_PARAMS = "extra_script_params";

    private final String name;
    private final Script script;
    private final Collection<String> queryParams;
    /** Script params minus the {@value #EXTRA_SCRIPT_PARAMS} entry. */
    private final Map<String, Object> baseScriptParams;
    /** Query param name -> name under which its value is passed to the script. */
    private final Map<String, String> extraScriptParams;

    /**
     * @param name        feature name (required)
     * @param script      script computing the feature value (required)
     * @param queryParams names of the query-time params this feature requires
     */
    @SuppressWarnings("unchecked")
    public ScriptFeature(String name, Script script, Collection<String> queryParams) {
        this.name = Objects.requireNonNull(name);
        this.script = Objects.requireNonNull(script);
        this.queryParams = queryParams;
        Map<String, Object> ltrScriptParams = new HashMap<>();
        Map<String, String> ltrExtraScriptParams = new HashMap<>();
        for (Map.Entry<String, Object> entry : script.getParams().entrySet()) {
            if (!entry.getKey().equals(EXTRA_SCRIPT_PARAMS)) {
                ltrScriptParams.put(String.valueOf(entry.getKey()), entry.getValue());
            } else {
                ltrExtraScriptParams = (Map<String, String>) entry.getValue();
            }
        }
        this.baseScriptParams = ltrScriptParams;
        this.extraScriptParams = ltrExtraScriptParams;
    }

    /**
     * Parses the JSON script template of a stored feature into a {@link ScriptFeature}.
     *
     * @throws UncheckedIOException if the template cannot be parsed
     */
    public static ScriptFeature compile(StoredFeature feature) {
        // The parser is Closeable; try-with-resources releases it on success and on failure.
        try (XContentParser xContentParser = XContentType.JSON.xContent().createParser(NamedXContentRegistry.EMPTY,
                LoggingDeprecationHandler.INSTANCE, feature.template())) {
            return new ScriptFeature(feature.name(), Script.parse(xContentParser, "native"), feature.queryParams());
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /**
     * The feature name
     */
    @Override
    public String name() {
        return name;
    }

    /**
     * Transform this feature into a lucene query.
     *
     * <p>All required query params must be present in {@code params}. They are forwarded to the script, renamed
     * when listed in {@value #EXTRA_SCRIPT_PARAMS}.
     *
     * @throws IllegalArgumentException if a required param is missing, or if the {@code term_stat} spec is invalid
     */
    @Override
    @SuppressWarnings("unchecked")
    public Query doToQuery(LtrQueryContext context, FeatureSet featureSet, Map<String, Object> params) {
        List<String> missingParams = queryParams.stream()
                .filter((x) -> !params.containsKey(x))
                .collect(Collectors.toList());
        if (!missingParams.isEmpty()) {
            String names = String.join(",", missingParams);
            throw new IllegalArgumentException("Missing required param(s): [" + names + "]");
        }

        Map<String, Object> queryTimeParams = new HashMap<>();
        Map<String, Object> extraQueryTimeParams = new HashMap<>();
        for (String x : queryParams) {
            if (params.containsKey(x)) {
                // Extra script params are exposed to the script under their configured name;
                // all other query params are passed through as name:value.
                if (extraScriptParams.containsKey(x)) {
                    extraQueryTimeParams.put(extraScriptParams.get(x), params.get(x));
                } else {
                    queryTimeParams.put(x, params.get(x));
                }
            }
        }

        FeatureSupplier supplier = new FeatureSupplier(featureSet);
        ExtraLoggingSupplier extraLoggingSupplier = new ExtraLoggingSupplier();
        TermStatSupplier termstatSupplier = new TermStatSupplier();
        Map<String, Object> nparams = new HashMap<>();

        // Parse terms if set
        Set<Term> terms = new HashSet<>();
        if (baseScriptParams.containsKey("term_stat")) {
            Map<String, Object> termspec = (Map<String, Object>) baseScriptParams.get("term_stat");
            collectTermStatTerms(context, termspec, params, terms);
            nparams.put(TERM_STAT, termstatSupplier);
            nparams.put(MATCH_COUNT, termstatSupplier.getMatchedTermCountSupplier());
            nparams.put(UNIQUE_TERMS, terms.size());
        }

        // Later puts win: stored params < query-time params < renamed extra params < suppliers.
        nparams.putAll(baseScriptParams);
        nparams.putAll(queryTimeParams);
        nparams.putAll(extraQueryTimeParams);
        nparams.put(FEATURE_VECTOR, supplier);
        nparams.put(EXTRA_LOGGING, extraLoggingSupplier);
        Script boundScript = new Script(this.script.getType(), this.script.getLang(), this.script.getIdOrCode(),
                this.script.getOptions(), nparams);
        ScoreScript.Factory factoryFactory = context.getQueryShardContext().compile(boundScript, ScoreScript.CONTEXT);
        ScoreScript.LeafFactory leafFactory = factoryFactory.newFactory(nparams, context.getQueryShardContext().lookup());
        ScriptScoreFunction function = new ScriptScoreFunction(boundScript, leafFactory,
                context.getQueryShardContext().index().getName(),
                context.getQueryShardContext().getShardId(),
                context.getQueryShardContext().indexVersionCreated(),
                null //TODO: this is different from ES LTR
        );
        return new LtrScript(function, supplier, extraLoggingSupplier, termstatSupplier, terms);
    }

    /**
     * Resolves the {@code term_stat} spec and adds the analyzed terms of every (field, term) pair to {@code terms}.
     *
     * <p>{@code analyzer} is either the name of a query param that holds the analyzer name, or a literal analyzer
     * name prefixed with {@code "!"}. When it is absent, each field's search analyzer is used. {@code fields} and
     * {@code terms} are either lists, or the name of a query param that holds the list.
     */
    @SuppressWarnings("unchecked")
    private static void collectTermStatTerms(LtrQueryContext context, Map<String, Object> termspec,
                                             Map<String, Object> params, Set<Term> terms) {
        String analyzerName = null;
        ArrayList<String> fields = null;
        ArrayList<String> termList = null;

        final Object analyzerNameObj = termspec.get("analyzer");
        final Object fieldsObj = termspec.get("fields");
        final Object termListObj = termspec.get("terms");

        // Support lookup via params or direct assignment (direct assignment is prefixed with a bang)
        if (analyzerNameObj instanceof String) {
            String analyzerSpec = (String) analyzerNameObj;
            if (analyzerSpec.startsWith("!")) {
                analyzerName = analyzerSpec.substring(1);
            } else {
                analyzerName = (String) params.get(analyzerSpec);
            }
        }

        if (fieldsObj instanceof String) {
            fields = (ArrayList<String>) params.get(fieldsObj);
        } else if (fieldsObj instanceof ArrayList) {
            fields = (ArrayList<String>) fieldsObj;
        }

        if (termListObj instanceof String) {
            termList = (ArrayList<String>) params.get(termListObj);
        } else if (termListObj instanceof ArrayList) {
            termList = (ArrayList<String>) termListObj;
        }

        if (fields == null || termList == null) {
            throw new IllegalArgumentException("Term Stats injection requires fields and terms");
        }

        for (String field : fields) {
            final Analyzer analyzer;
            if (analyzerName == null) {
                final MappedFieldType fieldType = context.getQueryShardContext().getFieldType(field);
                if (fieldType == null) {
                    // Previously an NPE; an unmapped field has no search analyzer to fall back to.
                    throw new IllegalArgumentException("No mapping found for field [" + field + "]");
                }
                analyzer = fieldType.getTextSearchInfo().getSearchAnalyzer();
            } else {
                analyzer = context.getQueryShardContext().getIndexAnalyzers().get(analyzerName);
            }
            if (analyzer == null) {
                throw new IllegalArgumentException("No analyzer found for [" + analyzerName + "]");
            }
            for (String termString : termList) {
                analyzeInto(analyzer, field, termString, terms);
            }
        }
    }

    /** Runs {@code text} through {@code analyzer} for {@code field} and adds each produced token to {@code terms}. */
    private static void analyzeInto(Analyzer analyzer, String field, String text, Set<Term> terms) {
        // Lucene's TokenStream contract is reset -> incrementToken* -> end -> close. Leaving a stream
        // unclosed makes the analyzer's next tokenStream() call fail, so close it even on failure.
        try (TokenStream ts = analyzer.tokenStream(field, text)) {
            final TermToBytesRefAttribute termAtt = ts.getAttribute(TermToBytesRefAttribute.class);
            ts.reset();
            while (ts.incrementToken()) {
                terms.add(new Term(field, termAtt.getBytesRef()));
            }
            ts.end();
        } catch (IOException ignored) {
            // Best effort, as before: the input is an in-memory String, so an I/O failure is unexpected.
            // Terms gathered before the failure are kept.
        }
    }

    /**
     * Query wrapping the compiled script function. Its suppliers are wired to the enclosing LTR query
     * in {@link #ltrRewrite}.
     */
    static class LtrScript extends Query implements LtrRewritableQuery {
        private final ScriptScoreFunction function;
        private final FeatureSupplier supplier;
        private final ExtraLoggingSupplier extraLoggingSupplier;
        private final TermStatSupplier termStatSupplier;
        private final Set<Term> terms;

        LtrScript(ScriptScoreFunction function,
                  FeatureSupplier supplier,
                  ExtraLoggingSupplier extraLoggingSupplier,
                  TermStatSupplier termStatSupplier,
                  Set<Term> terms) {
            this.function = function;
            this.supplier = supplier;
            this.extraLoggingSupplier = extraLoggingSupplier;
            this.termStatSupplier = termStatSupplier;
            this.terms = terms;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            // Check the class before casting; otherwise comparing with another Query type throws
            // ClassCastException.
            if (!sameClassAs(o)) {
                return false;
            }
            LtrScript other = (LtrScript) o;
            return Objects.equals(function, other.function);
        }

        @Override
        public int hashCode() {
            return Objects.hash(classHash(), function);
        }

        @Override
        public String toString(String field) {
            return "LtrScript:" + field;
        }

        /** Builds the scoring weight. When scores are not needed, it simply matches all docs. */
        @Override
        public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
            if (!scoreMode.needsScores()) {
                return new MatchAllDocsQuery().createWeight(searcher, scoreMode, 1F);
            }
            return new LtrScriptWeight(this, this.function, termStatSupplier, terms, searcher, scoreMode);
        }

        /** Binds the feature-vector and extra-logging suppliers from the rewrite context. Returns {@code this}. */
        @Override
        public Query ltrRewrite(LtrRewriteContext context) throws IOException {
            supplier.set(context.getFeatureVectorSupplier());

            LogLtrRanker.LogConsumer consumer = context.getLogConsumer();
            if (consumer != null) {
                extraLoggingSupplier.setSupplier(consumer::getExtraLoggingMap);
            } else {
                extraLoggingSupplier.setSupplier(() -> null);
            }
            return this;
        }

        /** Reports the term-stat terms to the visitor, unless it rejects one of their fields. */
        @Override
        public void visit(QueryVisitor visitor) {
            Set<String> fields = terms.stream().map(Term::field).collect(Collectors.toUnmodifiableSet());
            for (String field : fields) {
                if (visitor.acceptField(field) == false) {
                    return;
                }
            }
            visitor.getSubVisitor(BooleanClause.Occur.SHOULD, this).consumeTerms(this, terms.toArray(new Term[0]));
        }
    }

    /**
     * Weight that scores every document of a segment with the script function. When term stats were
     * requested, it refreshes the term-stat supplier before each score.
     */
    static class LtrScriptWeight extends Weight {
        private final IndexSearcher searcher;
        private final ScoreMode scoreMode;
        private final ScriptScoreFunction function;
        private final TermStatSupplier termStatSupplier;
        private final Set<Term> terms;
        private final HashMap<Term, TermStates> termContexts;

        LtrScriptWeight(Query query, ScriptScoreFunction function,
                        TermStatSupplier termStatSupplier,
                        Set<Term> terms,
                        IndexSearcher searcher,
                        ScoreMode scoreMode) throws IOException {
            super(query);
            this.function = function;
            this.termStatSupplier = termStatSupplier;
            this.terms = terms;
            this.searcher = searcher;
            this.scoreMode = scoreMode;
            this.termContexts = new HashMap<>();

            if (scoreMode.needsScores()) {
                for (Term t : terms) {
                    TermStates ctx = TermStates.build(searcher.getTopReaderContext(), t, true);
                    if (ctx != null && ctx.docFreq() > 0) {
                        // NOTE(review): the returned statistics are discarded here; kept as in the original.
                        searcher.collectionStatistics(t.field());
                        searcher.termStatistics(t, ctx.docFreq(), ctx.totalTermFreq());
                    }
                    termContexts.put(t, ctx);
                }
            }
        }

        @Override
        public Explanation explain(LeafReaderContext context, int doc) throws IOException {
            return function.getLeafScoreFunction(context).explainScore(doc, Explanation.noMatch("none"));
        }

        /** Returns a scorer that iterates over every doc id in the segment. */
        @Override
        public Scorer scorer(LeafReaderContext context) throws IOException {
            LeafScoreFunction leafScoreFunction = function.getLeafScoreFunction(context);
            DocIdSetIterator iterator = DocIdSetIterator.all(context.reader().maxDoc());
            return new Scorer(this) {
                @Override
                public int docID() {
                    return iterator.docID();
                }

                @Override
                public float score() throws IOException {
                    // Do the terms magic if the user asked for it
                    if (terms.size() > 0) {
                        termStatSupplier.bump(searcher, context, docID(), terms, scoreMode, termContexts);
                    }

                    return (float) leafScoreFunction.score(iterator.docID(), 0F);
                }

                @Override
                public DocIdSetIterator iterator() {
                    return iterator;
                }

                /**
                 * Return the maximum score that documents between the last {@code target}
                 * that this iterator was {@link #advanceShallow(int) shallow-advanced} to
                 * included and {@code upTo} included.
                 *
                 * <p>No bound is known for script scores, so this reports positive infinity.
                 */
                @Override
                public float getMaxScore(int upTo) throws IOException {
                    //TODO??
                    return Float.POSITIVE_INFINITY;
                }
            };
        }

        /** Intentionally a no-op: no terms are contributed here. */
        public void extractTerms(Set<Term> terms) {
        }

        @Override
        public boolean isCacheable(LeafReaderContext ctx) {
            // Never ever cache this query, its parent query is mutable
            return false;
        }
    }
}