/* * SPDX-License-Identifier: Apache-2.0 * * The OpenSearch Contributors require contributions made to * this file be licensed under the Apache-2.0 license or a * compatible open source license. */ /* * Licensed to Elasticsearch under one or more contributor * license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright * ownership. Elasticsearch licenses this file to you under * the Apache License, Version 2.0 (the "License"); you may * not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /* * Modifications Copyright OpenSearch Contributors. See * GitHub history for details. */ package org.opensearch.index.rankeval; import org.opensearch.action.OriginalIndices; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParseException; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.common.xcontent.XContentType; import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.core.index.shard.ShardId; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchShardTarget; import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.List; import static org.opensearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode; import static org.opensearch.test.XContentTestUtils.insertRandomFields; import static org.hamcrest.CoreMatchers.containsString; public class ExpectedReciprocalRankTests extends OpenSearchTestCase { private static final double DELTA = 10E-14; public void testProbabilityOfRelevance() { ExpectedReciprocalRank err = new ExpectedReciprocalRank(5); assertEquals(0.0, err.probabilityOfRelevance(0), 0.0); assertEquals(1d / 32d, err.probabilityOfRelevance(1), 0.0); assertEquals(3d / 32d, err.probabilityOfRelevance(2), 0.0); assertEquals(7d / 32d, err.probabilityOfRelevance(3), 0.0); assertEquals(15d / 32d, err.probabilityOfRelevance(4), 0.0); assertEquals(31d / 32d, err.probabilityOfRelevance(5), 0.0); } /** * Assuming the result ranking is * *
{@code
     * rank | relevance | probR / r | p        | p * probR / r
     * -------------------------------------------------------
     * 1    | 3         | 0.875     | 1        | 0.875       |
     * 2    | 2         | 0.1875    | 0.125    | 0.0234375   |
     * 3    | 0         | 0         | 0.078125 | 0           |
     * 4    | 1         | 0.03125   | 0.078125 | 0.00244140625 |
     * }
* * err = sum of last column */ public void testERRAt() { List rated = new ArrayList<>(); Integer[] relevanceRatings = new Integer[] { 3, 2, 0, 1 }; SearchHit[] hits = createSearchHits(rated, relevanceRatings); ExpectedReciprocalRank err = new ExpectedReciprocalRank(3, 0, 3); assertEquals(0.8984375, err.evaluate("id", hits, rated).metricScore(), DELTA); // take 4th rank into window err = new ExpectedReciprocalRank(3, 0, 4); assertEquals(0.8984375 + 0.00244140625, err.evaluate("id", hits, rated).metricScore(), DELTA); } /** * Assuming the result ranking is * *
{@code
     * rank | relevance | probR / r | p        | p * probR / r
     * -------------------------------------------------------
     * 1    | 3         | 0.875     | 1        | 0.875       |
     * 2    | n/a       | n/a       | 0.125    | n/a   |
     * 3    | 0         | 0         | 0.125    | 0           |
     * 4    | 1         | 0.03125   | 0.125    | 0.00390625 |
     * }
* * err = sum of last column */ public void testERRMissingRatings() { List rated = new ArrayList<>(); Integer[] relevanceRatings = new Integer[] { 3, null, 0, 1 }; SearchHit[] hits = createSearchHits(rated, relevanceRatings); ExpectedReciprocalRank err = new ExpectedReciprocalRank(3, null, 4); EvalQueryQuality evaluation = err.evaluate("id", hits, rated); assertEquals(0.875 + 0.00390625, evaluation.metricScore(), DELTA); assertEquals(1, ((ExpectedReciprocalRank.Detail) evaluation.getMetricDetails()).getUnratedDocs()); // if we supply e.g. 2 as unknown docs rating, it should be the same as in the other test above err = new ExpectedReciprocalRank(3, 2, 4); assertEquals(0.8984375 + 0.00244140625, err.evaluate("id", hits, rated).metricScore(), DELTA); } private SearchHit[] createSearchHits(List rated, Integer[] relevanceRatings) { SearchHit[] hits = new SearchHit[relevanceRatings.length]; for (int i = 0; i < relevanceRatings.length; i++) { if (relevanceRatings[i] != null) { rated.add(new RatedDocument("index", Integer.toString(i), relevanceRatings[i])); } hits[i] = new SearchHit(i, Integer.toString(i), Collections.emptyMap(), Collections.emptyMap()); hits[i].shard(new SearchShardTarget("testnode", new ShardId("index", "uuid", 0), null, OriginalIndices.NONE)); } return hits; } /** * test that metric returns 0.0 when there are no search results */ public void testNoResults() throws Exception { ExpectedReciprocalRank err = new ExpectedReciprocalRank(5, 0, 10); assertEquals(0.0, err.evaluate("id", new SearchHit[0], Collections.emptyList()).metricScore(), DELTA); } public void testParseFromXContent() throws IOException { assertParsedCorrect("{ \"unknown_doc_rating\": 2, \"maximum_relevance\": 5, \"k\" : 15 }", 2, 5, 15); assertParsedCorrect("{ \"unknown_doc_rating\": 2, \"maximum_relevance\": 4 }", 2, 4, 10); assertParsedCorrect("{ \"maximum_relevance\": 4, \"k\": 23 }", null, 4, 23); } private void assertParsedCorrect(String xContent, Integer expectedUnknownDocRating, int expectedMaxRelevance, int expectedK) throws IOException { try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) { ExpectedReciprocalRank errAt = ExpectedReciprocalRank.fromXContent(parser); assertEquals(expectedUnknownDocRating, errAt.getUnknownDocRating()); assertEquals(expectedK, errAt.getK()); assertEquals(expectedMaxRelevance, errAt.getMaxRelevance()); } } public static ExpectedReciprocalRank createTestItem() { Integer unknownDocRating = frequently() ? Integer.valueOf(randomIntBetween(0, 10)) : null; int maxRelevance = randomIntBetween(1, 10); return new ExpectedReciprocalRank(maxRelevance, unknownDocRating, randomIntBetween(1, 10)); } public void testXContentRoundtrip() throws IOException { ExpectedReciprocalRank testItem = createTestItem(); XContentBuilder builder = MediaTypeRegistry.contentBuilder(randomFrom(XContentType.values())); XContentBuilder shuffled = shuffleXContent(testItem.toXContent(builder, ToXContent.EMPTY_PARAMS)); try (XContentParser itemParser = createParser(shuffled)) { itemParser.nextToken(); itemParser.nextToken(); ExpectedReciprocalRank parsedItem = ExpectedReciprocalRank.fromXContent(itemParser); assertNotSame(testItem, parsedItem); assertEquals(testItem, parsedItem); assertEquals(testItem.hashCode(), parsedItem.hashCode()); } } public void testXContentParsingIsNotLenient() throws IOException { ExpectedReciprocalRank testItem = createTestItem(); XContentType xContentType = randomFrom(XContentType.values()); BytesReference originalBytes = toShuffledXContent(testItem, xContentType, ToXContent.EMPTY_PARAMS, randomBoolean()); BytesReference withRandomFields = insertRandomFields(xContentType, originalBytes, null, random()); try (XContentParser parser = createParser(xContentType.xContent(), withRandomFields)) { parser.nextToken(); parser.nextToken(); XContentParseException exception = expectThrows( XContentParseException.class, () -> DiscountedCumulativeGain.fromXContent(parser) ); assertThat(exception.getMessage(), containsString("[dcg] unknown field")); } } public void testMetricDetails() { int unratedDocs = randomIntBetween(0, 100); ExpectedReciprocalRank.Detail detail = new ExpectedReciprocalRank.Detail(unratedDocs); assertEquals(unratedDocs, detail.getUnratedDocs()); } public void testSerialization() throws IOException { ExpectedReciprocalRank original = createTestItem(); ExpectedReciprocalRank deserialized = OpenSearchTestCase.copyWriteable( original, new NamedWriteableRegistry(Collections.emptyList()), ExpectedReciprocalRank::new ); assertEquals(deserialized, original); assertEquals(deserialized.hashCode(), original.hashCode()); assertNotSame(deserialized, original); } public void testEqualsAndHash() throws IOException { checkEqualsAndHashCode(createTestItem(), original -> { return new ExpectedReciprocalRank(original.getMaxRelevance(), original.getUnknownDocRating(), original.getK()); }, ExpectedReciprocalRankTests::mutateTestItem); } private static ExpectedReciprocalRank mutateTestItem(ExpectedReciprocalRank original) { switch (randomIntBetween(0, 2)) { case 0: return new ExpectedReciprocalRank(original.getMaxRelevance() + 1, original.getUnknownDocRating(), original.getK()); case 1: return new ExpectedReciprocalRank( original.getMaxRelevance(), randomValueOtherThan(original.getUnknownDocRating(), () -> randomIntBetween(0, 10)), original.getK() ); case 2: return new ExpectedReciprocalRank( original.getMaxRelevance(), original.getUnknownDocRating(), randomValueOtherThan(original.getK(), () -> randomIntBetween(1, 10)) ); default: throw new IllegalArgumentException("mutation variant not allowed"); } } }