# Amazon SageMaker Lineage Queries

Amazon SageMaker Lineage tracks events that happen within SageMaker, allowing the relationships between them to be traced through a graph structure. SageMaker Lineage introduces an API called `LineageQuery` that lets customers query the lineage graph to discover relationships across their machine learning entities.

Machine learning workflows can generate deeply nested relationships, and the lineage APIs let you answer questions about them. For example, you can find all datasets that trained the model deployed to a given endpoint, or find all models trained by a given dataset.

SageMaker creates the lineage graph automatically, and you can also create or modify lineage entities directly.

In addition to the `LineageQuery` API, the SageMaker Python SDK provides wrapper functions that make it easy to run queries spanning multiple hops of the entity relationship graph. This notebook describes these APIs and helper functions.

In [None]:
!pip install "sagemaker>=2.123.0"

In [None]:
import os
import boto3
import sagemaker
import pprint
from botocore.config import Config

sagemaker_session = sagemaker.Session()
pp = pprint.PrettyPrinter()

## SageMaker Lineage Queries

We explore SageMaker's lineage capabilities to traverse the relationships between the entities created in this notebook: datasets, model, endpoint, and training job.


In [None]:
from sagemaker.lineage.context import Context, EndpointContext
from sagemaker.lineage.action import Action
from sagemaker.lineage.association import Association
from sagemaker.lineage.artifact import Artifact, ModelArtifact, DatasetArtifact

from sagemaker.lineage.query import (
 LineageQuery,
 LineageFilter,
 LineageSourceEnum,
 LineageEntityEnum,
 LineageQueryDirectionEnum,
)

## Using the LineageQuery API to find entity associations

In this section we use two APIs, `LineageQuery` and `LineageFilter`, to construct queries that answer questions about the lineage graph and extract entity relationships.

`LineageQuery` parameters:

- `start_arns`: A list of ARNs used as the starting points of the query.
- `direction`: The direction of the query: `ASCENDANTS`, `DESCENDANTS`, or `BOTH`.
- `include_edges`: If true, return edges in addition to vertices.
- `query_filter`: The `LineageFilter` applied to the results.

`LineageFilter` parameters:

- `entities`: A list of entity types (for example Artifact, Action, Context, or TrialComponent) to filter for when returning the results of `LineageQuery`.
- `sources`: A list of source types (for example Endpoint, Model, Dataset, or TrainingJob) to filter for when returning the results of `LineageQuery`.

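To make the filter semantics concrete, here is a minimal pure-Python sketch that applies an `entities`/`sources` filter to a hypothetical in-memory list of vertices. It assumes that a vertex must match one of the listed entity types *and* one of the listed source types, and that an omitted list means "no constraint". `ToyVertex`, `apply_filter`, and the type labels are illustrative names, not part of the SageMaker SDK; SageMaker evaluates the real filter server-side.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence


@dataclass(frozen=True)
class ToyVertex:
    """Hypothetical stand-in for a lineage vertex (not the SageMaker SDK class)."""

    arn: str
    entity: str  # illustrative entity type label, e.g. "Artifact"
    source: str  # illustrative source type label, e.g. "Dataset"


def apply_filter(
    vertices: Sequence[ToyVertex],
    entities: Optional[Sequence[str]] = None,
    sources: Optional[Sequence[str]] = None,
) -> List[ToyVertex]:
    """Keep vertices whose entity AND source both match (assumed semantics).

    Within each list the match is OR; an empty or missing list applies no constraint.
    """
    kept = []
    for vertex in vertices:
        if entities and vertex.entity not in entities:
            continue
        if sources and vertex.source not in sources:
            continue
        kept.append(vertex)
    return kept


# Hypothetical vertices reachable from some start ARN.
vertices = [
    ToyVertex("arn:example:dataset/train", "Artifact", "Dataset"),
    ToyVertex("arn:example:model/model-1", "Artifact", "Model"),
    ToyVertex("arn:example:trial-component/job-1", "TrialComponent", "TrainingJob"),
]

datasets = apply_filter(vertices, entities=["Artifact"], sources=["Dataset"])
datasets_and_models = apply_filter(vertices, entities=["Artifact"], sources=["Dataset", "Model"])
training_jobs = apply_filter(vertices, entities=["TrialComponent"], sources=["TrainingJob"])

print([v.arn for v in datasets])
print([v.arn for v in datasets_and_models])
print([v.arn for v in training_jobs])
```

Under this assumed semantics, passing several sources widens the result, while pairing `entities=["Artifact"]` with `sources=["TrainingJob"]` returns nothing, because the toy training job vertex is a TrialComponent rather than an Artifact.
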
SageMaker automatically creates a Context when an Endpoint is created, and an Artifact when a Model is created.


In [None]:
sm_client = boto3.client('sagemaker')

In [None]:
endpoint_arn = sm_client.describe_endpoint(EndpointName='workshop-project-staging')['EndpointArn']

In [None]:
# Find the endpoint context and model artifact that should be used for the lineage queries.

contexts = Context.list(source_uri=endpoint_arn)
context_name = list(contexts)[0].context_name
endpoint_context = EndpointContext.load(context_name=context_name)

### Find all datasets associated with an Endpoint

In [None]:
# Define the LineageFilter to look for entities of type `ARTIFACT` and the source of type `DATASET`.

query_filter = LineageFilter(
 entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.DATASET]
)

# Providing this `LineageFilter` to the `LineageQuery` constructs a query that traverses the lineage graph
# starting from the context `endpoint_context` and finds all datasets.

query_result = LineageQuery(sagemaker_session).query(
 start_arns=[endpoint_context.context_arn],
 query_filter=query_filter,
 direction=LineageQueryDirectionEnum.ASCENDANTS,
 include_edges=False,
)

# Parse through the query results to get the lineage objects corresponding to the datasets
dataset_artifacts = []
for vertex in query_result.vertices:
 dataset_artifacts.append(vertex.to_lineage_object().source.source_uri)

pp.pprint(dataset_artifacts)

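Each query cell in this notebook repeats a loop that turns vertices into source URIs. A small helper can factor that out. The sketch below uses the same calls as the notebook (`vertex.to_lineage_object().source.source_uri`, falling back to `vertex.arn`), but the helper name `collect_source_uris` and the choice of which exceptions trigger the fallback are assumptions. The stub classes are hypothetical and exist only so the example runs without AWS credentials.

```python
from types import SimpleNamespace


def collect_source_uris(vertices):
    """Return each vertex's source URI, or its ARN when no source URI is available.

    Hypothetical helper; catching AttributeError/ValueError for the fallback is an
    assumption about how unsupported vertex types fail.
    """
    uris = []
    for vertex in vertices:
        try:
            uris.append(vertex.to_lineage_object().source.source_uri)
        except (AttributeError, ValueError):
            uris.append(vertex.arn)
    return uris


class StubArtifactVertex:
    """Offline stand-in for a vertex whose lineage object has a source URI."""

    def __init__(self, arn, source_uri):
        self.arn = arn
        self._source_uri = source_uri

    def to_lineage_object(self):
        return SimpleNamespace(source=SimpleNamespace(source_uri=self._source_uri))


class StubTrialComponentVertex:
    """Offline stand-in for a vertex that cannot be converted to a lineage object."""

    def __init__(self, arn):
        self.arn = arn

    def to_lineage_object(self):
        raise ValueError("unsupported vertex type (stub)")


stub_vertices = [
    StubArtifactVertex("arn:example:artifact/data", "s3://example-bucket/train/train.csv"),
    StubTrialComponentVertex("arn:example:experiment-trial-component/job-1"),
]
print(collect_source_uris(stub_vertices))
```

In the notebook, a loop such as the one above could then become `dataset_artifacts = collect_source_uris(query_result.vertices)`. Catching only specific exception types keeps genuine failures, such as a failed AWS call, from being silently replaced by ARNs.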


### Find the models associated with an Endpoint

In [None]:
# Define the LineageFilter to look for entities of type `ARTIFACT` and the source of type `MODEL`.

query_filter = LineageFilter(
 entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.MODEL]
)

# Providing this `LineageFilter` to the `LineageQuery` constructs a query that traverses the lineage graph
# starting from the context `endpoint_context` and finds all models.

query_result = LineageQuery(sagemaker_session).query(
 start_arns=[endpoint_context.context_arn],
 query_filter=query_filter,
 direction=LineageQueryDirectionEnum.ASCENDANTS,
 include_edges=False,
)

# Parse through the query results to get the lineage objects corresponding to the model
model_artifacts = []
for vertex in query_result.vertices:
 model_artifacts.append(vertex.to_lineage_object().source.source_uri)

# The `LineageQuery` results contain the ARN of the model deployed to the endpoint along with
# the S3 URI of the model.tar.gz file associated with the model.
pp.pprint(model_artifacts)

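Because this query mixes two kinds of values, the model ARN and the S3 URI of the `model.tar.gz` file, it can be convenient to separate them. The sketch below splits a result list by prefix. `split_model_results` is a hypothetical helper, and the example values are placeholders rather than real output.

```python
def split_model_results(values):
    """Split lineage query results into ARNs, S3 URIs, and anything else."""
    arns, s3_uris, other = [], [], []
    for value in values:
        if value.startswith("arn:"):
            arns.append(value)
        elif value.startswith("s3://"):
            s3_uris.append(value)
        else:
            other.append(value)  # keep unexpected values instead of dropping them
    return arns, s3_uris, other


# Placeholder values shaped like the results described above.
example_results = [
    "arn:aws:sagemaker:us-east-1:111122223333:model/example-model",
    "s3://example-bucket/output/model.tar.gz",
]
model_arns, model_data_uris, unrecognized = split_model_results(example_results)
print(model_arns)
print(model_data_uris)
```

In the notebook the same call would take `model_artifacts` as its input.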


### Find the trial components associated with an Endpoint

In [None]:
# Define the LineageFilter to look for entities of type `TRIAL_COMPONENT` and the source of type `TRAINING_JOB`.

query_filter = LineageFilter(
 entities=[LineageEntityEnum.TRIAL_COMPONENT],
 sources=[LineageSourceEnum.TRAINING_JOB],
)

# Providing this `LineageFilter` to the `LineageQuery` constructs a query that traverses the lineage graph
# starting from the context `endpoint_context` and finds all trial components created by training jobs.

query_result = LineageQuery(sagemaker_session).query(
 start_arns=[endpoint_context.context_arn],
 query_filter=query_filter,
 direction=LineageQueryDirectionEnum.ASCENDANTS,
 include_edges=False,
)

# Parse through the query results to get the ARNs of the trial components for the training jobs associated with this Endpoint
trial_components = []
for vertex in query_result.vertices:
 trial_components.append(vertex.arn)

pp.pprint(trial_components)

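The trial component ARNs collected above can be mapped back to training job names. This sketch assumes the standard ARN layout `arn:partition:service:region:account-id:resource-type/resource-name` and the SageMaker convention of naming a training job's trial component `<training-job-name>-aws-training-job`. Both helpers are hypothetical, and `training_job_name` returns `None` when the suffix is absent rather than guessing.

```python
from typing import Optional

TRAINING_JOB_SUFFIX = "-aws-training-job"  # assumed naming convention


def resource_name_from_arn(arn: str) -> str:
    """Return the resource name part of an ARN (the text after 'resource-type/')."""
    resource = arn.split(":", 5)[5]  # 'resource-type/resource-name'
    _, sep, name = resource.partition("/")
    return name if sep else resource


def training_job_name(trial_component_arn: str) -> Optional[str]:
    """Strip the assumed trial component suffix to recover a training job name."""
    name = resource_name_from_arn(trial_component_arn)
    if name.endswith(TRAINING_JOB_SUFFIX):
        return name[: -len(TRAINING_JOB_SUFFIX)]
    return None


# Placeholder ARN, not real output.
example_arn = (
    "arn:aws:sagemaker:us-east-1:111122223333:"
    "experiment-trial-component/example-job-aws-training-job"
)
print(resource_name_from_arn(example_arn))
print(training_job_name(example_arn))
```

With a recovered name, `sm_client.describe_training_job(TrainingJobName=name)` returns the training job's details.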


## Change the focal point of lineage

Changing the `start_arns` of a `LineageQuery` changes the focal point of the lineage. In addition, a `LineageFilter` can take multiple sources and entities to expand the scope of the query.

Here we use the model as the lineage focal point and find the Endpoints and Datasets associated with it.


In [None]:
model_package_arn = sm_client.list_model_packages(
    ModelPackageGroupName='mlops-workshop-model-group'
)['ModelPackageSummaryList'][0]['ModelPackageArn']

In [None]:
# Get the ModelArtifact

model_artifact_summary = list(Artifact.list(source_uri=model_package_arn))[0]
model_artifact = ModelArtifact.load(artifact_arn=model_artifact_summary.artifact_arn)

In [None]:
query_filter = LineageFilter(
 entities=[LineageEntityEnum.ARTIFACT],
 sources=[LineageSourceEnum.ENDPOINT, LineageSourceEnum.DATASET],
)

query_result = LineageQuery(sagemaker_session).query(
 start_arns=[model_artifact.artifact_arn], # Model is the starting artifact
 query_filter=query_filter,
 # Find all the entities that descend from the model, i.e. the endpoint
 direction=LineageQueryDirectionEnum.DESCENDANTS,
 include_edges=False,
)

associations = []
for vertex in query_result.vertices:
 associations.append(vertex.to_lineage_object().source.source_uri)

query_result = LineageQuery(sagemaker_session).query(
 start_arns=[model_artifact.artifact_arn], # Model is the starting artifact
 query_filter=query_filter,
 # Find all the entities that the model descends from, i.e. the datasets
 direction=LineageQueryDirectionEnum.ASCENDANTS,
 include_edges=False,
)

for vertex in query_result.vertices:
 associations.append(vertex.to_lineage_object().source.source_uri)

pp.pprint(associations)

## Use LineageQueryDirectionEnum.BOTH

When the direction is set to `BOTH`, the query looks for both ascendant and descendant relationships, and the traversal proceeds not only from the starting node but also from each node it visits.

For example, if the training job is run twice and both resulting models are deployed to endpoints, a query on one of the models with direction set to `BOTH` returns both endpoints. This is because the same image is used to train and deploy both models: the image is common to the model in `start_arns` and to both endpoints, so the traversal reaches the second endpoint through it.

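The following toy model reproduces the scenario just described using plain Python sets instead of SageMaker. It treats `BOTH` as the paragraph above states (expanding in both directions from every visited node) and, for contrast, `DESCENDANTS` as following edges away from the start only. The graph, node names, and `traverse` function are hypothetical; SageMaker's actual traversal runs server-side and may differ in details such as depth limits.

```python
from collections import defaultdict, deque

# Hypothetical graph: two training runs share one image, each model has its own endpoint.
EDGES = [
    ("dataset-1", "training-job-1"),
    ("dataset-2", "training-job-2"),
    ("image", "training-job-1"),
    ("image", "training-job-2"),
    ("training-job-1", "model-1"),
    ("training-job-2", "model-2"),
    ("image", "model-1"),
    ("image", "model-2"),
    ("model-1", "endpoint-1"),
    ("model-2", "endpoint-2"),
]


def traverse(start, edges, direction):
    """Breadth-first traversal returning every node reached from `start` (excluding it)."""
    children, parents = defaultdict(set), defaultdict(set)
    for src, dst in edges:
        children[src].add(dst)
        parents[dst].add(src)

    seen = {start}
    queue = deque([start])
    while queue:
        node = queue.popleft()
        if direction == "ASCENDANTS":
            neighbours = parents[node]
        elif direction == "DESCENDANTS":
            neighbours = children[node]
        elif direction == "BOTH":
            # Expand up and down from every visited node, not just from the start.
            neighbours = parents[node] | children[node]
        else:
            raise ValueError(f"unknown direction: {direction}")
        for neighbour in neighbours:
            if neighbour not in seen:
                seen.add(neighbour)
                queue.append(neighbour)
    seen.discard(start)
    return seen


def endpoints(nodes):
    """Keep only endpoint nodes, sorted for stable output."""
    return sorted(n for n in nodes if n.startswith("endpoint"))


print(endpoints(traverse("model-1", EDGES, "DESCENDANTS")))  # ['endpoint-1']
print(endpoints(traverse("model-1", EDGES, "BOTH")))  # ['endpoint-1', 'endpoint-2']
```

Under this model the `BOTH` result also reaches `dataset-2`, since the shared image connects the two training runs.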

In [None]:
query_filter = LineageFilter(
 entities=[LineageEntityEnum.ARTIFACT],
 sources=[LineageSourceEnum.ENDPOINT, LineageSourceEnum.DATASET],
)

query_result = LineageQuery(sagemaker_session).query(
 start_arns=[model_artifact.artifact_arn], # Model is the starting artifact
 query_filter=query_filter,
 # Look for associations in both the ascending and descending directions from the start ARN
 direction=LineageQueryDirectionEnum.BOTH,
 include_edges=False,
)

associations = []
for vertex in query_result.vertices:
 associations.append(vertex.to_lineage_object().source.source_uri)

pp.pprint(associations)

## Directions in LineageQuery: Ascendants vs. Descendants

To understand the direction in the lineage graph, consider the following entity relationship graph: Dataset -> Training Job -> Model -> Endpoint.

The endpoint is a descendant of the model, and the model is a descendant of the dataset. Similarly, the model is an ascendant of the endpoint. The `direction` parameter specifies whether the query should return entities that are descendants or ascendants of the entities in `start_arns`. If `start_arns` contains a model and the direction is `DESCENDANTS`, the query returns the endpoint. If the direction is `ASCENDANTS`, the query returns the dataset.

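The chain above can be modeled with a small adjacency dictionary to check which nodes each direction returns. This is a hypothetical, unfiltered sketch: it returns every node in the given direction, whereas `LineageQuery` additionally applies its `query_filter`.

```python
# Hypothetical adjacency list for: Dataset -> Training Job -> Model -> Endpoint
GRAPH = {
    "Dataset": ["Training Job"],
    "Training Job": ["Model"],
    "Model": ["Endpoint"],
    "Endpoint": [],
}


def descendants(graph, node):
    """Return every node reachable by following edges forward from `node`."""
    found = []
    stack = list(graph.get(node, []))
    while stack:
        current = stack.pop()
        if current not in found:
            found.append(current)
            stack.extend(graph.get(current, []))
    return found


def ascendants(graph, node):
    """Return every node reachable by following edges backward from `node`."""
    reversed_graph = {name: [] for name in graph}
    for src, dsts in graph.items():
        for dst in dsts:
            reversed_graph.setdefault(dst, []).append(src)
    return descendants(reversed_graph, node)


print("Descendants of Model:", descendants(GRAPH, "Model"))  # ['Endpoint']
print("Ascendants of Model:", ascendants(GRAPH, "Model"))  # ['Training Job', 'Dataset']
```

Ascendants are computed by reversing every edge and reusing the descendant walk, which mirrors the idea that the two directions are the same traversal over opposite edges.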

In [None]:
# In this example, we'll look at the impact of specifying the direction as ASCENDANTS or DESCENDANTS in a `LineageQuery`.

query_filter = LineageFilter(
 entities=[LineageEntityEnum.ARTIFACT],
 sources=[
 LineageSourceEnum.ENDPOINT,
 LineageSourceEnum.MODEL,
 LineageSourceEnum.DATASET,
 LineageSourceEnum.TRAINING_JOB,
 ],
)

query_result = LineageQuery(sagemaker_session).query(
 start_arns=[model_artifact.artifact_arn],
 query_filter=query_filter,
 direction=LineageQueryDirectionEnum.ASCENDANTS,
 include_edges=False,
)

ascendant_artifacts = []

# The lineage entity returned for the Training Job is a TrialComponent, which can't be converted with
# `to_lineage_object()` into an object exposing `source.source_uri`, so we extract the TrialComponent ARN instead.
for vertex in query_result.vertices:
    try:
        ascendant_artifacts.append(vertex.to_lineage_object().source.source_uri)
    except Exception:
        ascendant_artifacts.append(vertex.arn)

print("Ascendant artifacts:")
pp.pprint(ascendant_artifacts)

query_result = LineageQuery(sagemaker_session).query(
 start_arns=[model_artifact.artifact_arn],
 query_filter=query_filter,
 direction=LineageQueryDirectionEnum.DESCENDANTS,
 include_edges=False,
)

descendant_artifacts = []
for vertex in query_result.vertices:
    try:
        descendant_artifacts.append(vertex.to_lineage_object().source.source_uri)
    except Exception:
        # Handle TrialComponents by recording their ARN instead.
        descendant_artifacts.append(vertex.arn)

print("Descendant artifacts:")
pp.pprint(descendant_artifacts)