/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/
package org.opensearch.index.rankeval;
import org.opensearch.action.OriginalIndices;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParseException;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.test.OpenSearchTestCase;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import static org.opensearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode;
import static org.opensearch.test.XContentTestUtils.insertRandomFields;
import static org.hamcrest.CoreMatchers.containsString;
/**
 * Unit tests for {@link ExpectedReciprocalRank}: metric computation, handling of unrated documents,
 * XContent parsing/round-tripping, wire serialization and equals/hashCode.
 */
public class ExpectedReciprocalRankTests extends OpenSearchTestCase {

    /** Tolerance used for floating point comparisons of metric scores ({@code 10E-14} is {@code 1e-13}). */
    private static final double DELTA = 10E-14;

    /**
     * With a maximum relevance of 5, the probability of relevance for grade {@code g}
     * is expected to be {@code (2^g - 1) / 2^5}.
     */
    public void testProbabilityOfRelevance() {
        ExpectedReciprocalRank err = new ExpectedReciprocalRank(5);
        assertEquals(0.0, err.probabilityOfRelevance(0), 0.0);
        assertEquals(1d / 32d, err.probabilityOfRelevance(1), 0.0);
        assertEquals(3d / 32d, err.probabilityOfRelevance(2), 0.0);
        assertEquals(7d / 32d, err.probabilityOfRelevance(3), 0.0);
        assertEquals(15d / 32d, err.probabilityOfRelevance(4), 0.0);
        assertEquals(31d / 32d, err.probabilityOfRelevance(5), 0.0);
    }

    /**
     * Assuming the result ranking is
     *
     * <pre>{@code
     * rank | relevance | probR / r | p        | p * probR / r
     * -------------------------------------------------------
     * 1    | 3         | 0.875     | 1        | 0.875         |
     * 2    | 2         | 0.1875    | 0.125    | 0.0234375     |
     * 3    | 0         | 0         | 0.078125 | 0             |
     * 4    | 1         | 0.03125   | 0.078125 | 0.00244140625 |
     * }</pre>
     *
     * err = sum of last column
     */
    public void testERRAt() {
        List<RatedDocument> rated = new ArrayList<>();
        Integer[] relevanceRatings = new Integer[] { 3, 2, 0, 1 };
        SearchHit[] hits = createSearchHits(rated, relevanceRatings);
        ExpectedReciprocalRank err = new ExpectedReciprocalRank(3, 0, 3);
        assertEquals(0.8984375, err.evaluate("id", hits, rated).metricScore(), DELTA);
        // take 4th rank into window
        err = new ExpectedReciprocalRank(3, 0, 4);
        assertEquals(0.8984375 + 0.00244140625, err.evaluate("id", hits, rated).metricScore(), DELTA);
    }

    /**
     * Assuming the result ranking is
     *
     * <pre>{@code
     * rank | relevance | probR / r | p     | p * probR / r
     * -------------------------------------------------------
     * 1    | 3         | 0.875     | 1     | 0.875      |
     * 2    | n/a       | n/a       | 0.125 | n/a        |
     * 3    | 0         | 0         | 0.125 | 0          |
     * 4    | 1         | 0.03125   | 0.125 | 0.00390625 |
     * }</pre>
     *
     * err = sum of last column
     */
    public void testERRMissingRatings() {
        List<RatedDocument> rated = new ArrayList<>();
        Integer[] relevanceRatings = new Integer[] { 3, null, 0, 1 };
        SearchHit[] hits = createSearchHits(rated, relevanceRatings);
        ExpectedReciprocalRank err = new ExpectedReciprocalRank(3, null, 4);
        EvalQueryQuality evaluation = err.evaluate("id", hits, rated);
        assertEquals(0.875 + 0.00390625, evaluation.metricScore(), DELTA);
        assertEquals(1, ((ExpectedReciprocalRank.Detail) evaluation.getMetricDetails()).getUnratedDocs());
        // if we supply e.g. 2 as unknown docs rating, it should be the same as in the other test above
        err = new ExpectedReciprocalRank(3, 2, 4);
        assertEquals(0.8984375 + 0.00244140625, err.evaluate("id", hits, rated).metricScore(), DELTA);
    }

    /**
     * Builds one search hit per entry of {@code relevanceRatings}, using the array position as document id.
     * For every non-null rating a matching {@link RatedDocument} in index "index" is appended to {@code rated};
     * a {@code null} rating leaves that hit unrated.
     *
     * @param rated            receives the rated documents (mutated by this method)
     * @param relevanceRatings rating per hit position, {@code null} meaning "not rated"
     * @return the search hits, each assigned to shard 0 of index "index"
     */
    private SearchHit[] createSearchHits(List<RatedDocument> rated, Integer[] relevanceRatings) {
        SearchHit[] hits = new SearchHit[relevanceRatings.length];
        for (int i = 0; i < relevanceRatings.length; i++) {
            if (relevanceRatings[i] != null) {
                rated.add(new RatedDocument("index", Integer.toString(i), relevanceRatings[i]));
            }
            hits[i] = new SearchHit(i, Integer.toString(i), Collections.emptyMap(), Collections.emptyMap());
            hits[i].shard(new SearchShardTarget("testnode", new ShardId("index", "uuid", 0), null, OriginalIndices.NONE));
        }
        return hits;
    }

    /**
     * test that metric returns 0.0 when there are no search results
     */
    public void testNoResults() throws Exception {
        ExpectedReciprocalRank err = new ExpectedReciprocalRank(5, 0, 10);
        assertEquals(0.0, err.evaluate("id", new SearchHit[0], Collections.emptyList()).metricScore(), DELTA);
    }

    /**
     * Parses explicit JSON definitions, including ones that omit the optional
     * {@code k} (expected to default to 10) and {@code unknown_doc_rating} (expected to be {@code null}).
     */
    public void testParseFromXContent() throws IOException {
        assertParsedCorrect("{ \"unknown_doc_rating\": 2, \"maximum_relevance\": 5, \"k\" : 15 }", 2, 5, 15);
        assertParsedCorrect("{ \"unknown_doc_rating\": 2, \"maximum_relevance\": 4 }", 2, 4, 10);
        assertParsedCorrect("{ \"maximum_relevance\": 4, \"k\": 23 }", null, 4, 23);
    }

    /**
     * Parses {@code xContent} as JSON and checks the resulting metric's settings.
     */
    private void assertParsedCorrect(String xContent, Integer expectedUnknownDocRating, int expectedMaxRelevance, int expectedK)
        throws IOException {
        try (XContentParser parser = createParser(JsonXContent.jsonXContent, xContent)) {
            ExpectedReciprocalRank errAt = ExpectedReciprocalRank.fromXContent(parser);
            assertEquals(expectedUnknownDocRating, errAt.getUnknownDocRating());
            assertEquals(expectedK, errAt.getK());
            assertEquals(expectedMaxRelevance, errAt.getMaxRelevance());
        }
    }

    /**
     * Creates a random metric: max relevance and k in [1, 10], unknown doc rating in [0, 10]
     * most of the time and {@code null} otherwise.
     */
    public static ExpectedReciprocalRank createTestItem() {
        Integer unknownDocRating = frequently() ? Integer.valueOf(randomIntBetween(0, 10)) : null;
        int maxRelevance = randomIntBetween(1, 10);
        return new ExpectedReciprocalRank(maxRelevance, unknownDocRating, randomIntBetween(1, 10));
    }

    /**
     * Renders a random metric to a random XContent type (with shuffled keys) and parses it back.
     */
    public void testXContentRoundtrip() throws IOException {
        ExpectedReciprocalRank testItem = createTestItem();
        XContentBuilder builder = MediaTypeRegistry.contentBuilder(randomFrom(XContentType.values()));
        XContentBuilder shuffled = shuffleXContent(testItem.toXContent(builder, ToXContent.EMPTY_PARAMS));
        try (XContentParser itemParser = createParser(shuffled)) {
            // skip the outer START_OBJECT and the metric name field
            itemParser.nextToken();
            itemParser.nextToken();
            ExpectedReciprocalRank parsedItem = ExpectedReciprocalRank.fromXContent(itemParser);
            assertNotSame(testItem, parsedItem);
            assertEquals(testItem, parsedItem);
            assertEquals(testItem.hashCode(), parsedItem.hashCode());
        }
    }

    /**
     * Inserting random unknown fields into a serialized metric must make the
     * {@link ExpectedReciprocalRank} parser fail rather than silently ignore them.
     */
    public void testXContentParsingIsNotLenient() throws IOException {
        ExpectedReciprocalRank testItem = createTestItem();
        XContentType xContentType = randomFrom(XContentType.values());
        BytesReference originalBytes = toShuffledXContent(testItem, xContentType, ToXContent.EMPTY_PARAMS, randomBoolean());
        BytesReference withRandomFields = insertRandomFields(xContentType, originalBytes, null, random());
        try (XContentParser parser = createParser(xContentType.xContent(), withRandomFields)) {
            // skip the outer START_OBJECT and the metric name field
            parser.nextToken();
            parser.nextToken();
            // must use this metric's own parser; the DCG parser would reject ERR's fields regardless
            XContentParseException exception = expectThrows(
                XContentParseException.class,
                () -> ExpectedReciprocalRank.fromXContent(parser)
            );
            // match only the parser-name-independent part of the message
            assertThat(exception.getMessage(), containsString("unknown field"));
        }
    }

    /**
     * The detail object reports back the number of unrated docs it was created with.
     */
    public void testMetricDetails() {
        int unratedDocs = randomIntBetween(0, 100);
        ExpectedReciprocalRank.Detail detail = new ExpectedReciprocalRank.Detail(unratedDocs);
        assertEquals(unratedDocs, detail.getUnratedDocs());
    }

    /**
     * A random metric survives a wire round-trip as an equal but distinct instance.
     */
    public void testSerialization() throws IOException {
        ExpectedReciprocalRank original = createTestItem();
        ExpectedReciprocalRank deserialized = OpenSearchTestCase.copyWriteable(
            original,
            new NamedWriteableRegistry(Collections.emptyList()),
            ExpectedReciprocalRank::new
        );
        assertEquals(original, deserialized);
        assertEquals(original.hashCode(), deserialized.hashCode());
        assertNotSame(original, deserialized);
    }

    /**
     * Checks the equals/hashCode contract using a field-by-field copy and single-field mutations.
     */
    public void testEqualsAndHash() throws IOException {
        checkEqualsAndHashCode(
            createTestItem(),
            original -> new ExpectedReciprocalRank(original.getMaxRelevance(), original.getUnknownDocRating(), original.getK()),
            ExpectedReciprocalRankTests::mutateTestItem
        );
    }

    /**
     * Returns a copy of {@code original} that differs in exactly one randomly chosen field.
     */
    private static ExpectedReciprocalRank mutateTestItem(ExpectedReciprocalRank original) {
        switch (randomIntBetween(0, 2)) {
            case 0:
                return new ExpectedReciprocalRank(original.getMaxRelevance() + 1, original.getUnknownDocRating(), original.getK());
            case 1:
                return new ExpectedReciprocalRank(
                    original.getMaxRelevance(),
                    randomValueOtherThan(original.getUnknownDocRating(), () -> randomIntBetween(0, 10)),
                    original.getK()
                );
            case 2:
                return new ExpectedReciprocalRank(
                    original.getMaxRelevance(),
                    original.getUnknownDocRating(),
                    randomValueOtherThan(original.getK(), () -> randomIntBetween(1, 10))
                );
            default:
                throw new IllegalArgumentException("mutation variant not allowed");
        }
    }
}