/* * Copyright [2016] Doug Turnbull * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ package com.o19s.es.ltr.query; import ciir.umass.edu.learning.DataPoint; import ciir.umass.edu.learning.RANKER_TYPE; import ciir.umass.edu.learning.RankList; import ciir.umass.edu.learning.Ranker; import ciir.umass.edu.learning.RankerFactory; import ciir.umass.edu.learning.RankerTrainer; import ciir.umass.edu.metric.NDCGScorer; import ciir.umass.edu.utilities.MyThreadPool; import com.o19s.es.ltr.feature.FeatureSet; import com.o19s.es.ltr.feature.PrebuiltFeature; import com.o19s.es.ltr.feature.PrebuiltFeatureSet; import com.o19s.es.ltr.feature.PrebuiltLtrModel; import com.o19s.es.ltr.ranker.LogLtrRanker; import com.o19s.es.ltr.ranker.LtrRanker; import com.o19s.es.ltr.ranker.normalizer.FeatureNormalizingRanker; import com.o19s.es.ltr.ranker.normalizer.Normalizer; import com.o19s.es.ltr.ranker.normalizer.StandardFeatureNormalizer; import com.o19s.es.ltr.ranker.ranklib.DenseProgramaticDataPoint; import com.o19s.es.ltr.ranker.ranklib.RanklibRanker; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.document.Field.Store; import org.apache.lucene.document.FieldType; import org.apache.lucene.index.IndexOptions; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.index.Term; import 
org.apache.lucene.misc.SweetSpotSimilarity; import org.apache.lucene.queries.BlendedTermQuery; import org.apache.lucene.search.Explanation; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.PhraseQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.Scorable; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.SimpleCollector; import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.similarities.AfterEffectB; import org.apache.lucene.search.similarities.AxiomaticF3LOG; import org.apache.lucene.search.similarities.BM25Similarity; import org.apache.lucene.search.similarities.BasicModelG; import org.apache.lucene.search.similarities.BooleanSimilarity; import org.apache.lucene.search.similarities.ClassicSimilarity; import org.apache.lucene.search.similarities.DFISimilarity; import org.apache.lucene.search.similarities.DFRSimilarity; import org.apache.lucene.search.similarities.DistributionLL; import org.apache.lucene.search.similarities.IBSimilarity; import org.apache.lucene.search.similarities.IndependenceChiSquared; import org.apache.lucene.search.similarities.LMDirichletSimilarity; import org.apache.lucene.search.similarities.LMJelinekMercerSimilarity; import org.apache.lucene.search.similarities.LambdaDF; import org.apache.lucene.search.similarities.NormalizationH1; import org.apache.lucene.search.similarities.NormalizationH3; import org.apache.lucene.search.similarities.Similarity; import org.apache.lucene.search.similarities.TFIDFSimilarity; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.util.LuceneTestCase; import org.opensearch.common.lucene.search.function.FunctionScoreQuery; import org.opensearch.common.lucene.search.function.WeightFactorFunction; import org.junit.After; import org.junit.AfterClass; import 
org.junit.Before;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.stream.Collectors;

/**
 * Trains small RankLib models over a tiny in-memory index and checks that the
 * scores produced by {@link RankerQuery} agree with the scores RankLib itself
 * computes for the same feature vectors.
 */
@LuceneTestCase.SuppressSysoutChecks(bugUrl = "RankURL does this when training models... ")
public class LtrQueryTests extends LuceneTestCase {
    // Number of ULPs allowed when checking scores equality
    private static final int SCORE_NB_ULP_PREC = 1;

    // NOTE(review): everything from the body of range() down to the first statement of
    // the @Before method was lost when this file was extracted (text between '<' and '>'
    // was stripped). The members below are reconstructed from how the rest of the class
    // uses them; each reconstruction is marked individually. Verify against the original.

    /**
     * Builds the array {@code [start, start + 1, ..., stop - 1]}.
     *
     * @param start first value (inclusive)
     * @param stop  end value (exclusive)
     * @return the consecutive integers from {@code start} up to but excluding {@code stop}
     */
    private int[] range(int start, int stop) {
        int[] result = new int[stop - start];
        // NOTE(review): loop body reconstructed; the only caller needs the feature ids 1..n.
        for (int i = 0; i < result.length; i++) {
            result[i] = start + i;
        }
        return result;
    }

    /**
     * Creates a tokenized text field that indexes positions.
     *
     * NOTE(review): reconstructed helper. It is called as
     * {@code newField("field", docs[i], Store.YES)}, and the phrase queries used as
     * features need positions to be indexed.
     */
    private Field newField(String name, String value, Store stored) {
        FieldType fieldType = new FieldType();
        fieldType.setStored(stored == Store.YES);
        fieldType.setIndexOptions(IndexOptions.DOCS_AND_FREQS_AND_POSITIONS);
        return new Field(name, value, fieldType);
    }

    // NOTE(review): field declarations reconstructed from their usages below.
    private IndexSearcher searcherUnderTest;
    private RandomIndexWriter indexWriterUnderTest;
    private IndexReader indexReaderUnderTest;
    private Directory dirUnderTest;
    private Similarity similarity;

    // The array index of each entry is stored as the document's "id" field.
    // NOTE(review): the literal texts were lost in extraction. Four documents are
    // required (four relevance grades are supplied in checkModelWithFeatures) and the
    // tests query the terms "brown" and "cow" -- confirm the exact strings.
    private final String[] docs = new String[] {
        "how now brown cow",
        "brown is the color of cows",
        "brown cow",
        "banana cows are yummy"
    };

    /**
     * Indexes {@link #docs} with a randomly chosen similarity and opens a searcher
     * configured with that same similarity.
     *
     * NOTE(review): the method name, the {@code @Before} annotation and the
     * {@code newDirectory()} call are reconstructed.
     */
    @Before
    public void setupIndex() throws IOException {
        dirUnderTest = newDirectory();
        List<Similarity> sims = Arrays.asList(
            new ClassicSimilarity(),
            new SweetSpotSimilarity(), // extends Classic
            new BM25Similarity(),
            new LMDirichletSimilarity(),
            new BooleanSimilarity(),
            new LMJelinekMercerSimilarity(0.2F),
            new AxiomaticF3LOG(0.5F, 10),
            new DFISimilarity(new IndependenceChiSquared()),
            new DFRSimilarity(new BasicModelG(), new AfterEffectB(), new NormalizationH1()),
            new IBSimilarity(new DistributionLL(), new LambdaDF(), new NormalizationH3())
        );
        similarity = sims.get(random().nextInt(sims.size()));

        indexWriterUnderTest = new RandomIndexWriter(random(), dirUnderTest,
            newIndexWriterConfig().setSimilarity(similarity));
        for (int i = 0; i < docs.length; i++) {
            Document doc = new Document();
            doc.add(newStringField("id", "" + i, Field.Store.YES));
            doc.add(newField("field", docs[i], Store.YES));
            indexWriterUnderTest.addDocument(doc);
        }
        indexWriterUnderTest.commit();
        indexWriterUnderTest.forceMerge(1);
        indexWriterUnderTest.flush();

        indexReaderUnderTest = indexWriterUnderTest.getReader();
        searcherUnderTest = newSearcher(indexReaderUnderTest);
        searcherUnderTest.setSimilarity(similarity);
    }

    /**
     * Runs a logging {@link RankerQuery} over the index and records the per-feature
     * scores for every collected document.
     *
     * @param features     the features to evaluate
     * @param missingScore value recorded for a feature the logger did not report for a document
     * @return stored "id" value of each document mapped to (feature ordinal -> score)
     * @throws IOException if searching or loading a stored document fails
     */
    public Map<String, Map<Integer, Float>> getFeatureScores(List<PrebuiltFeature> features,
                                                             final float missingScore) throws IOException {
        Map<String, Map<Integer, Float>> featuresPerDoc = new HashMap<>();
        FeatureSet set = new PrebuiltFeatureSet("test", features);
        Map<Integer, Float> collectedScores = new HashMap<>();

        LogLtrRanker.LogConsumer logger = new LogLtrRanker.LogConsumer() {
            @Override
            public void accept(int featureOrdinal, float score) {
                collectedScores.put(featureOrdinal, score);
            }

            @Override
            public void reset() {
                collectedScores.clear();
                for (int i = 0; i < features.size(); i++) {
                    collectedScores.put(i, missingScore);
                }
            }
        };
        // Start with every feature set to the missing score, exactly as reset() leaves it.
        logger.reset();

        RankerQuery query = RankerQuery.buildLogQuery(logger, set, null, Collections.emptyMap());
        searcherUnderTest.search(query, new SimpleCollector() {
            private LeafReaderContext context;
            private Scorable scorer;

            /**
             * Indicates what features are required from the scorer.
             */
            @Override
            public ScoreMode scoreMode() {
                return ScoreMode.COMPLETE;
            }

            @Override
            public void setScorer(Scorable scorer) throws IOException {
                this.scorer = scorer;
            }

            @Override
            protected void doSetNextReader(LeafReaderContext context) throws IOException {
                this.context = context;
            }

            @Override
            public void collect(int doc) throws IOException {
                // The returned score is not needed; scoring is invoked so that the
                // log consumer receives this document's feature values.
                scorer.score();
                Document d = context.reader().document(doc);
                // Snapshot the scores: collectedScores is reused for the next document.
                featuresPerDoc.put(d.get("id"), new HashMap<>(collectedScores));
            }
        });
        return featuresPerDoc;
    }

    /**
     * Converts logged feature scores into RankLib data points for a single query.
     *
     * @param qid                   query id assigned to every data point
     * @param featuresPerDoc        stored "id" -> (feature ordinal -> raw score)
     * @param modelSize             number of features each data point holds
     * @param relevanceGradesPerDoc label for each document, indexed by its stored "id"
     * @param ftrNorms              normalizers keyed by feature ordinal; features without an
     *                              entry keep their raw score
     * @return the data points ordered by ascending document id
     */
    public List<DataPoint> makeQueryJudgements(int qid,
                                               Map<String, Map<Integer, Float>> featuresPerDoc,
                                               int modelSize,
                                               Float[] relevanceGradesPerDoc,
                                               Map<Integer, Normalizer> ftrNorms) {
        assertEquals(docs.length, featuresPerDoc.size());
        assertEquals(docs.length, relevanceGradesPerDoc.length);

        // Sorted so that position i of the returned list belongs to document id i.
        SortedMap<Integer, DataPoint> points = new TreeMap<>();
        featuresPerDoc.forEach((doc, vector) -> {
            DenseProgramaticDataPoint dp = new DenseProgramaticDataPoint(modelSize);
            int docId = Integer.decode(doc);
            dp.setLabel(relevanceGradesPerDoc[docId]);
            dp.setID(String.valueOf(qid));
            vector.forEach((ftrOrd, rawScore) -> {
                Normalizer ftrNorm = ftrNorms.get(ftrOrd);
                float score = ftrNorm != null ? ftrNorm.normalize(rawScore) : rawScore;
                dp.setFeatureScore(ftrOrd, score);
            });
            points.put(docId, dp);
        });
        return new ArrayList<>(points.values());
    }

    /**
     * Asserts that each detail of {@code expl} starts with the expected feature label:
     * {@code "Feature <idx>:"} for unnamed features and {@code "Feature <idx>(<name>):"}
     * for named ones.
     */
    public void checkFeatureNames(Explanation expl, List<PrebuiltFeature> features) {
        Explanation[] expls = expl.getDetails();
        int ftrIdx = 0;
        for (Explanation ftrExpl : expls) {
            String ftrName = features.get(ftrIdx).name();
            String expectedFtrName;
            if (ftrName == null) {
                expectedFtrName = "Feature " + ftrIdx + ":";
            } else {
                expectedFtrName = "Feature " + ftrIdx + "(" + ftrName + "):";
            }
            String ftrExplainStart = ftrExpl.getDescription().substring(0, expectedFtrName.length());
            assertEquals(expectedFtrName, ftrExplainStart);
            ftrIdx++;
        }
    }

    /**
     * Trains a RankNet model on the logged feature scores and verifies that a
     * {@link RankerQuery} built from it yields the same scores as RankLib, both for the
     * freshly trained ranker and for a ranker reloaded from its serialized form.
     *
     * @param features      features to log and feed to the model
     * @param modelFeatures feature ids the trainer may use; {@code null} means
     *                      {@code 1..features.size()}
     * @param ftrNorms      optional normalizers keyed by feature ordinal, may be {@code null}
     * @throws IOException if searching the index fails
     */
    public void checkModelWithFeatures(List<PrebuiltFeature> features, int[] modelFeatures,
                                       Map<Integer, Normalizer> ftrNorms) throws IOException {
        Map<Integer, Normalizer> norms = ftrNorms;
        if (norms == null) {
            norms = new HashMap<>();
        }

        // Each RankList needed for training corresponds to one query,
        // or that appears to be how RankLib wants the data.
        List<RankList> samples = new ArrayList<>();
        Map<String, Map<Integer, Float>> rawFeaturesPerDoc = getFeatureScores(features, 0.0f);

        // Normalize prior to training: this rank list holds normalized feature values.
        RankList rl = new RankList(makeQueryJudgements(0, rawFeaturesPerDoc, features.size(),
            new Float[] {3.0f, 2.0f, 4.0f, 0.0f}, norms));
        samples.add(rl);

        int[] featuresToUse = modelFeatures;
        if (featuresToUse == null) {
            featuresToUse = range(1, features.size() + 1);
        }

        RankerTrainer trainer = new RankerTrainer();
        Ranker ranker = trainer.train(
            /* what type of model to train */ RANKER_TYPE.RANKNET,
            /* the training data */ samples,
            /* which features to use */ featuresToUse,
            /* how to score ranking */ new NDCGScorer());
        float[] scores = evalAll(ranker, rl);

        // Ok now lets rerun that as a Lucene Query
        RankerQuery ltrQuery = toRankerQuery(features, ranker, norms);
        TopDocs topDocs = searcherUnderTest.search(ltrQuery, 10);
        ScoreDoc[] scoreDocs = topDocs.scoreDocs;
        assertEquals(docs.length, scoreDocs.length);
        // Swap two hits; each hit is verified independently, so the order is irrelevant.
        ScoreDoc sc = scoreDocs[0];
        scoreDocs[0] = scoreDocs[2];
        scoreDocs[2] = sc;
        for (ScoreDoc scoreDoc : scoreDocs) {
            assertScoresMatch(features, scores, ltrQuery, scoreDoc);
        }

        // Try again with a model serialized. Both the expected scores and the query must
        // come from the reloaded ranker, otherwise the round trip is never exercised.
        String modelAsStr = ranker.model();
        RankerFactory rankerFactory = new RankerFactory();
        Ranker rankerAgain = rankerFactory.loadRankerFromString(modelAsStr);
        float[] scoresAgain = evalAll(rankerAgain, rl);

        RankerQuery ltrQueryAgain = toRankerQuery(features, rankerAgain, norms);
        topDocs = searcherUnderTest.search(ltrQueryAgain, 10);
        scoreDocs = topDocs.scoreDocs;
        assertEquals(docs.length, scoreDocs.length);
        for (ScoreDoc scoreDoc : scoreDocs) {
            assertScoresMatch(features, scoresAgain, ltrQueryAgain, scoreDoc);
        }
    }

    /** Evaluates {@code ranker} on every data point of {@code rl}, in list order. */
    private static float[] evalAll(Ranker ranker, RankList rl) {
        float[] scores = new float[rl.size()];
        for (int i = 0; i < scores.length; i++) {
            scores[i] = (float) ranker.eval(rl.get(i));
        }
        return scores;
    }

    /**
     * Asserts that the score of one hit equals the model score for that document and,
     * for non-TFIDF similarities, that the explanation agrees and labels its features.
     *
     * @param features features used by the query, for checking explanation labels
     * @param scores   expected model scores indexed by the document's stored "id"
     * @param ltrQuery the query that produced {@code scoreDoc}
     * @param scoreDoc the hit to verify
     * @throws IOException if loading the document or explaining the query fails
     */
    private void assertScoresMatch(List<PrebuiltFeature> features, float[] scores,
                                   RankerQuery ltrQuery, ScoreDoc scoreDoc) throws IOException {
        Document d = searcherUnderTest.doc(scoreDoc.doc);
        String idVal = d.get("id");
        int docId = Integer.decode(idVal);
        float modelScore = scores[docId];
        float queryScore = scoreDoc.score;

        assertEquals("Scores match with similarity " + similarity.getClass(), modelScore,
            queryScore, SCORE_NB_ULP_PREC * Math.ulp(modelScore));

        if (!(similarity instanceof TFIDFSimilarity)) {
            // There are precision issues with these similarities when using explain
            // It produces 0.56103003 for feat:0 in doc1 using score() but 0.5610301 using explain
            // explain() expects the Lucene doc id of the hit, not the stored "id" value.
            Explanation expl = searcherUnderTest.explain(ltrQuery, scoreDoc.doc);

            assertEquals("Explain scores match with similarity " + similarity.getClass(),
                expl.getValue().floatValue(), queryScore, 5 * Math.ulp(modelScore));
            checkFeatureNames(expl, features);
        }
    }

    /**
     * Wraps a RankLib ranker (plus optional feature normalizers) into a {@link RankerQuery}.
     */
    private RankerQuery toRankerQuery(List<PrebuiltFeature> features, Ranker ranker,
                                      Map<Integer, Normalizer> ftrNorms) {
        LtrRanker ltrRanker = new RanklibRanker(ranker, features.size());
        if (ftrNorms.size() > 0) {
            ltrRanker = new FeatureNormalizingRanker(ltrRanker, ftrNorms);
        }
        PrebuiltLtrModel model = new PrebuiltLtrModel(ltrRanker.name(), ltrRanker,
            new PrebuiltFeatureSet(null, features));
        return RankerQuery.build(model);
    }

    public void testTrainModel() throws IOException {
        String userQuery = "brown cow";
        List<Query> features = Arrays.asList(
            new TermQuery(new Term("field", userQuery.split(" ")[0])),
            new PhraseQuery("field", userQuery.split(" ")));
        checkModelWithFeatures(toPrebuildFeatureWithNoName(features), null, null);
    }

    public void testSubsetFeaturesFuncScore() throws IOException {
        String userQuery = "brown cow";
        Query baseQuery = new MatchAllDocsQuery();
        List<Query> features = Arrays.asList(
            new TermQuery(new Term("field", userQuery.split(" ")[0])),
            new PhraseQuery("field", userQuery.split(" ")),
            new FunctionScoreQuery(baseQuery, new WeightFactorFunction(1.0f))
        );
        checkModelWithFeatures(toPrebuildFeatureWithNoName(features), new int[] {1}, null);
    }

    public void testSubsetFeaturesTermQ() throws IOException {
        String userQuery = "brown cow";
        List<Query> features = Arrays.asList(
            new TermQuery(new Term("field", userQuery.split(" ")[0])),
            new PhraseQuery("field", userQuery.split(" ")),
            new PhraseQuery(1, "field", userQuery.split(" ")));
        checkModelWithFeatures(toPrebuildFeatureWithNoName(features), new int[] {1}, null);
    }

    public void testExplainWithNames() throws IOException {
        String userQuery = "brown cow";
        List<PrebuiltFeature> features = Arrays.asList(
            new PrebuiltFeature("funky_term_q", new TermQuery(new Term("field", userQuery.split(" ")[0]))),
            new PrebuiltFeature("funky_phrase_q", new PhraseQuery("field", userQuery.split(" "))));
        checkModelWithFeatures(features, null, null);
    }

    public void testOnRewrittenQueries() throws IOException {
        String userQuery = "brown cow";
        Term[] termsToBlend = new Term[]{new Term("field", userQuery.split(" ")[0])};
        Query blended = BlendedTermQuery.dismaxBlendedQuery(termsToBlend, 1f);
        List<Query> features = Arrays.asList(
            new TermQuery(new Term("field", userQuery.split(" ")[0])),
            blended);
        checkModelWithFeatures(toPrebuildFeatureWithNoName(features), null, null);
    }

    /** Wraps each query in a {@link PrebuiltFeature} whose name is {@code null}. */
    private List<PrebuiltFeature> toPrebuildFeatureWithNoName(List<Query> features) {
        return features.stream()
            .map(x -> new PrebuiltFeature(null, x))
            .collect(Collectors.toList());
    }

    public void testNoMatchQueries() throws IOException {
        String userQuery = "brown cow";
        Term[] termsToBlend = new Term[]{new Term("field", userQuery.split(" ")[0])};
        Query blended = BlendedTermQuery.dismaxBlendedQuery(termsToBlend, 1f);
        List<PrebuiltFeature> features = Arrays.asList(
            new PrebuiltFeature(null, new TermQuery(new Term("field", "missingterm"))),
            new PrebuiltFeature(null, blended));
        checkModelWithFeatures(features, null, null);
    }

    public void testMatchingNormalizedQueries() throws IOException {
        List<PrebuiltFeature> features = Arrays.asList(
            new PrebuiltFeature(null, new TermQuery(new Term("field", "brown"))),
            new PrebuiltFeature(null, new TermQuery(new Term("field", "cow"))));
        Map<Integer, Normalizer> ftrNorms = new HashMap<>();
        ftrNorms.put(0, new StandardFeatureNormalizer(1, 0.5f));
        ftrNorms.put(1, new StandardFeatureNormalizer(1, 0.5f));
        checkModelWithFeatures(features, null, ftrNorms);
    }

    public void testNoMatchNormalizedQueries() throws IOException {
        List<PrebuiltFeature> features = Arrays.asList(
            new PrebuiltFeature(null, new TermQuery(new Term("field", "missingterm"))),
            new PrebuiltFeature(null, new TermQuery(new Term("field", "othermissingterm"))));
        Map<Integer, Normalizer> ftrNorms = new HashMap<>();
        ftrNorms.put(0, new StandardFeatureNormalizer(0.5f, 1));
        ftrNorms.put(1, new StandardFeatureNormalizer(0.7f, 0.2f));
        checkModelWithFeatures(features, null, ftrNorms);
    }

    /** Closes the reader, writer and directory opened for the test. */
    @After
    public void closeStuff() throws IOException {
        indexReaderUnderTest.close();
        indexWriterUnderTest.close();
        dirUnderTest.close();
    }

    /** Shuts down RankLib's thread pool once all tests in the class have run. */
    @AfterClass
    public static void closeOtherStuff() {
        // Ranklib's singleton instance
        MyThreadPool.getInstance().shutdown();
    }
}