/*
 * Copyright [2017] Wikimedia Foundation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.o19s.es.ltr.logging;

import com.o19s.es.ltr.feature.FeatureSet;
import com.o19s.es.ltr.query.RankerQuery;
import com.o19s.es.ltr.ranker.LogLtrRanker;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.BoostQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.document.DocumentField;
import org.opensearch.search.SearchHit;
import org.opensearch.search.fetch.FetchContext;
import org.opensearch.search.fetch.FetchSubPhase;
import org.opensearch.search.fetch.FetchSubPhaseProcessor;
import org.opensearch.search.rescore.QueryRescorer;
import org.opensearch.search.rescore.RescoreContext;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
 * Fetch sub-phase that records per-feature scores ("feature logging") for
 * [sltr] queries referenced either by name or by rescore index via the
 * {@link LoggingSearchExtBuilder} search extension. The collected logs are
 * attached to each {@link SearchHit} under the {@code _ltrlog} field.
 */
public class LoggingFetchSubPhase implements FetchSubPhase {

    /**
     * Builds the processor that performs the logging for this fetch phase.
     *
     * @param context the fetch context for the current search request
     * @return a {@link FetchSubPhaseProcessor} wired to the [sltr] queries to
     *         log, or {@code null} when no logging extension was requested
     * @throws IOException on failure creating the Lucene weight
     * @throws IllegalArgumentException if a referenced query cannot be found
     *         or is not a [sltr] query
     */
    @Override
    public FetchSubPhaseProcessor getProcessor(FetchContext context) throws IOException {
        LoggingSearchExtBuilder ext = (LoggingSearchExtBuilder) context.getSearchExt(LoggingSearchExtBuilder.NAME);
        if (ext == null) {
            return null;
        }

        // Wrap every ranker query to log in a single conjunctive BooleanQuery:
        // one Weight/Scorer pair then drives all of them together during fetch.
        BooleanQuery.Builder builder = new BooleanQuery.Builder();
        List<HitLogConsumer> loggers = new ArrayList<>();
        Map<String, Query> namedQueries = context.parsedQuery().namedFilters();

        // NOTE(review): rescore-based log specs are also gated on the presence
        // of named queries here — confirm this is intentional, as it means
        // rescore logging is skipped entirely when no query is named.
        if (namedQueries.size() > 0) {
            ext.logSpecsStream().filter((l) -> l.getNamedQuery() != null).forEach((l) -> {
                Tuple<RankerQuery, HitLogConsumer> query = extractQuery(l, namedQueries);
                builder.add(new BooleanClause(query.v1(), BooleanClause.Occur.MUST));
                loggers.add(query.v2());
            });

            ext.logSpecsStream().filter((l) -> l.getRescoreIndex() != null).forEach((l) -> {
                Tuple<RankerQuery, HitLogConsumer> query = extractRescore(l, context.rescore());
                builder.add(new BooleanClause(query.v1(), BooleanClause.Occur.MUST));
                loggers.add(query.v2());
            });
        }

        Weight w = context.searcher().rewrite(builder.build()).createWeight(context.searcher(), ScoreMode.COMPLETE, 1.0F);
        return new LoggingFetchSubPhaseProcessor(w, loggers);
    }

    /**
     * Resolves a log spec that targets a named query and wraps the matching
     * [sltr] query with a logging consumer.
     *
     * @throws IllegalArgumentException if no query with that name exists or
     *         it is not a (possibly boost-wrapped) {@link RankerQuery}
     */
    private Tuple<RankerQuery, HitLogConsumer> extractQuery(LoggingSearchExtBuilder.LogSpec logSpec,
                                                            Map<String, Query> namedQueries) {
        Query q = namedQueries.get(logSpec.getNamedQuery());
        if (q == null) {
            throw new IllegalArgumentException("No query named [" + logSpec.getNamedQuery() + "] found");
        }
        return toLogger(logSpec, inspectQuery(q)
                .orElseThrow(() -> new IllegalArgumentException("Query named [" + logSpec.getNamedQuery()
                        + "] must be a [sltr] query [" + ((q instanceof BoostQuery)
                        ? ((BoostQuery) q).getQuery().getClass().getSimpleName()
                        : q.getClass().getSimpleName()) + "] found")));
    }

    /**
     * Resolves a log spec that targets a rescore phase by index and wraps the
     * rescore's [sltr] query with a logging consumer.
     *
     * @throws IllegalArgumentException if the index is out of bounds, the
     *         rescore context is not query-based, or its query is not a
     *         {@link RankerQuery}
     */
    private Tuple<RankerQuery, HitLogConsumer> extractRescore(LoggingSearchExtBuilder.LogSpec logSpec,
                                                               List<RescoreContext> contexts) {
        if (logSpec.getRescoreIndex() >= contexts.size()) {
            throw new IllegalArgumentException("rescore index [" + logSpec.getRescoreIndex() + "] is out of bounds, only "
                    + "[" + contexts.size() + "] rescore context(s) are available");
        }
        RescoreContext context = contexts.get(logSpec.getRescoreIndex());
        if (!(context instanceof QueryRescorer.QueryRescoreContext)) {
            throw new IllegalArgumentException("Expected a [QueryRescoreContext] but found a "
                    + "[" + context.getClass().getSimpleName() + "] "
                    + "at index [" + logSpec.getRescoreIndex() + "]");
        }
        QueryRescorer.QueryRescoreContext qrescore = (QueryRescorer.QueryRescoreContext) context;
        return toLogger(logSpec, inspectQuery(qrescore.query())
                .orElseThrow(() -> new IllegalArgumentException("Expected a [sltr] query but found a "
                        + "[" + qrescore.query().getClass().getSimpleName() + "] "
                        + "at index [" + logSpec.getRescoreIndex() + "]")));
    }

    /**
     * Returns the {@link RankerQuery} inside {@code q}, unwrapping one level
     * of {@link BoostQuery} if present; empty otherwise.
     */
    private Optional<RankerQuery> inspectQuery(Query q) {
        if (q instanceof RankerQuery) {
            return Optional.of((RankerQuery) q);
        } else if (q instanceof BoostQuery && ((BoostQuery) q).getQuery() instanceof RankerQuery) {
            return Optional.of((RankerQuery) ((BoostQuery) q).getQuery());
        }
        return Optional.empty();
    }

    /**
     * Creates a {@link HitLogConsumer} for the spec and rewires the query so
     * that scoring it feeds feature values into the consumer.
     */
    private Tuple<RankerQuery, HitLogConsumer> toLogger(LoggingSearchExtBuilder.LogSpec logSpec, RankerQuery query) {
        HitLogConsumer consumer = new HitLogConsumer(logSpec.getLoggerName(), query.featureSet(), logSpec.isMissingAsZero());
        query = query.toLoggerQuery(consumer);
        return new Tuple<>(query, consumer);
    }

    /**
     * Per-segment processor: advances a single scorer to each fetched hit and
     * triggers scoring, which in turn drives the attached log consumers.
     */
    static class LoggingFetchSubPhaseProcessor implements FetchSubPhaseProcessor {
        private final Weight weight;
        private final List<HitLogConsumer> loggers;
        private Scorer scorer;

        LoggingFetchSubPhaseProcessor(Weight weight, List<HitLogConsumer> loggers) {
            this.weight = weight;
            this.loggers = loggers;
        }

        @Override
        public void setNextReader(LeafReaderContext readerContext) throws IOException {
            scorer = weight.scorer(readerContext);
        }

        @Override
        public void process(HitContext hitContext) throws IOException {
            // Only log hits the logging query actually matches in this segment.
            if (scorer != null && scorer.iterator().advance(hitContext.docId()) == hitContext.docId()) {
                loggers.forEach((l) -> l.nextDoc(hitContext.hit()));
                // Scoring will trigger log collection
                scorer.score();
            }
        }
    }

    /**
     * Receives per-feature scores for the current hit and stores them in the
     * hit's {@code _ltrlog} document field, keyed by logger name.
     */
    static class HitLogConsumer implements LogLtrRanker.LogConsumer {
        private static final String FIELD_NAME = "_ltrlog";
        private static final String EXTRA_LOGGING_NAME = "extra_logging";
        private final String name;
        private final FeatureSet set;
        // When true, features that never score still get an explicit 0.0 value.
        private final boolean missingAsZero;

        // Log structure written per hit:
        // [
        //      {
        //          "name": "featureName",
        //          "value": 1.33
        //      },
        //      {
        //          "name": "otherFeatureName",
        //      }
        // ]
        private List<Map<String, Object>> currentLog;
        private SearchHit currentHit;
        private Map<String, Object> extraLogging;

        HitLogConsumer(String name, FeatureSet set, boolean missingAsZero) {
            this.name = name;
            this.set = set;
            this.missingAsZero = missingAsZero;
        }

        private void rebuild() {
            // Allocate one Map per feature, plus one placeholder for an extra logging Map
            // that will only be added if used.
            List<Map<String, Object>> ini = new ArrayList<>(set.size() + 1);
            for (int i = 0; i < set.size(); i++) {
                Map<String, Object> defaultKeyVal = new HashMap<>();
                defaultKeyVal.put("name", set.feature(i).name());
                if (missingAsZero) {
                    defaultKeyVal.put("value", 0.0F);
                }
                ini.add(i, defaultKeyVal);
            }
            currentLog = ini;
            extraLogging = null;
        }

        @Override
        public void accept(int featureOrdinal, float score) {
            assert currentLog != null;
            assert currentHit != null;
            currentLog.get(featureOrdinal).put("value", score);
        }

        /**
         * Return Map to store additional logging information returned with the feature values.
         *
         * The Map is created on first access.
         */
        @Override
        public Map<String, Object> getExtraLoggingMap() {
            if (extraLogging == null) {
                extraLogging = new HashMap<>();
                Map<String, Object> logEntry = new HashMap<>();
                logEntry.put("name", EXTRA_LOGGING_NAME);
                logEntry.put("value", extraLogging);
                currentLog.add(logEntry);
            }
            return extraLogging;
        }

        /**
         * Starts a fresh log for {@code hit}, creating the {@code _ltrlog}
         * field on first use and registering this logger's entry in it.
         */
        void nextDoc(SearchHit hit) {
            DocumentField logs = hit.getFields().get(FIELD_NAME);
            if (logs == null) {
                logs = newLogField();
                hit.setDocumentField(FIELD_NAME, logs);
            }
            Map<String, List<Map<String, Object>>> entries = logs.getValue();
            rebuild();
            currentHit = hit;
            entries.put(name, currentLog);
        }

        DocumentField newLogField() {
            List<Object> logList = Collections.singletonList(new HashMap<String, List<Map<String, Object>>>());
            return new DocumentField(FIELD_NAME, logList);
        }
    }
}