{ "cells": [ { "cell_type": "markdown", "id": "21865fd3", "metadata": {}, "source": [ "# Amazon SageMaker Lineage Queries" ] }, { "cell_type": "markdown", "id": "36f40180", "metadata": {}, "source": [ "Amazon SageMaker Lineage tracks events that happen within SageMaker, allowing the relationships between them to be traced through a graph structure. SageMaker Lineage provides an API called LineageQuery that lets you query the lineage graph to discover relationships across your machine learning entities.\n", "\n", "Your machine learning workflows can generate deeply nested relationships, and the lineage APIs let you answer questions about them. For example, you can find all datasets that trained the model deployed to a given endpoint, or find all models trained on a given dataset.\n", "\n", "SageMaker creates the lineage graph automatically, and you can also create or modify your own lineage entities directly.\n", "\n", "In addition to the LineageQuery API, the SageMaker SDK provides wrapper functions that make it easy to run queries that span multiple hops of the entity relationship graph. These APIs and helper functions are described in this notebook." 
] }, { "cell_type": "code", "execution_count": null, "id": "cbb1664a", "metadata": {}, "outputs": [], "source": [ "!pip install \"sagemaker>=2.123.0\"" ] }, { "cell_type": "code", "execution_count": null, "id": "9c6f8537", "metadata": {}, "outputs": [], "source": [ "import os\n", "import boto3\n", "import sagemaker\n", "import pprint\n", "from botocore.config import Config\n", "\n", "sagemaker_session = sagemaker.Session()\n", "pp = pprint.PrettyPrinter()" ] }, { "cell_type": "markdown", "id": "28d8191a", "metadata": {}, "source": [ "## SageMaker Lineage Queries\n", "\n", "We explore SageMaker's lineage capabilities to traverse the relationships between the entities created in this notebook: datasets, model, endpoint, and training job.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "47b5a34e", "metadata": {}, "outputs": [], "source": [ "from sagemaker.lineage.context import Context, EndpointContext\n", "from sagemaker.lineage.action import Action\n", "from sagemaker.lineage.association import Association\n", "from sagemaker.lineage.artifact import Artifact, ModelArtifact, DatasetArtifact\n", "\n", "from sagemaker.lineage.query import (\n", " LineageQuery,\n", " LineageFilter,\n", " LineageSourceEnum,\n", " LineageEntityEnum,\n", " LineageQueryDirectionEnum,\n", ")" ] }, { "cell_type": "markdown", "id": "0962f515", "metadata": {}, "source": [ "## Using the LineageQuery API to find entity associations\n", "\n", "In this section we use two APIs, LineageQuery and LineageFilter, to construct queries that answer questions about the lineage graph and extract entity relationships.\n", "\n", "LineageQuery parameters:\n", "\n", " start_arns: A list of ARNs used as the starting point for the query.\n", " direction: The direction of the query.\n", " include_edges: If true, return edges in addition to vertices.\n", " query_filter: The query filter.\n", "\n", "LineageFilter parameters:\n", "\n", " entities: A list of entity types (Artifact, Association, 
Action) to filter for when returning the results of LineageQuery\n", " sources: A list of source types (Endpoint, Model, Dataset) to filter for when returning the results of LineageQuery\n", "\n", "A Context is automatically created when a SageMaker Endpoint is created, and an Artifact is automatically created when a Model is created in SageMaker.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "dc8ce24b", "metadata": {}, "outputs": [], "source": [ "sm_client = boto3.client('sagemaker')" ] }, { "cell_type": "code", "execution_count": null, "id": "4037ab77", "metadata": {}, "outputs": [], "source": [ "endpoint_arn = sm_client.describe_endpoint(EndpointName='workshop-project-staging')['EndpointArn']" ] }, { "cell_type": "code", "execution_count": null, "id": "5bfd73ec", "metadata": {}, "outputs": [], "source": [ "# Find the endpoint context to use as the starting point for the lineage queries.\n", "\n", "contexts = Context.list(source_uri=endpoint_arn)\n", "context_name = list(contexts)[0].context_name\n", "endpoint_context = EndpointContext.load(context_name=context_name)" ] }, { "cell_type": "markdown", "id": "6d0c00d0", "metadata": {}, "source": [ "### Find all datasets associated with an Endpoint" ] }, { "cell_type": "code", "execution_count": null, "id": "5dfc3d8f", "metadata": {}, "outputs": [], "source": [ "# Define the LineageFilter to look for entities of type `ARTIFACT` and the source of type `DATASET`.\n", "\n", "query_filter = LineageFilter(\n", " entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.DATASET]\n", ")\n", "\n", "# Providing this `LineageFilter` to the `LineageQuery` constructs a query that traverses through the given context `endpoint_context`\n", "# and finds all datasets.\n", "\n", "query_result = LineageQuery(sagemaker_session).query(\n", " start_arns=[endpoint_context.context_arn],\n", " query_filter=query_filter,\n", " direction=LineageQueryDirectionEnum.ASCENDANTS,\n", " include_edges=False,\n", ")\n", 
"\n", "# Parse the query results to get the source URIs of the dataset artifacts\n", "dataset_artifacts = []\n", "for vertex in query_result.vertices:\n", " dataset_artifacts.append(vertex.to_lineage_object().source.source_uri)\n", "\n", "pp.pprint(dataset_artifacts)\n", "\n" ] }, { "cell_type": "markdown", "id": "ff7e9ee5", "metadata": {}, "source": [ "### Find the models associated with an Endpoint" ] }, { "cell_type": "code", "execution_count": null, "id": "bb992fe3", "metadata": {}, "outputs": [], "source": [ "# Define the LineageFilter to look for entities of type `ARTIFACT` and the source of type `MODEL`.\n", "\n", "query_filter = LineageFilter(\n", " entities=[LineageEntityEnum.ARTIFACT], sources=[LineageSourceEnum.MODEL]\n", ")\n", "\n", "# Providing this `LineageFilter` to the `LineageQuery` constructs a query that traverses through the given context `endpoint_context`\n", "# and finds all models.\n", "\n", "query_result = LineageQuery(sagemaker_session).query(\n", " start_arns=[endpoint_context.context_arn],\n", " query_filter=query_filter,\n", " direction=LineageQueryDirectionEnum.ASCENDANTS,\n", " include_edges=False,\n", ")\n", "\n", "# Parse the query results to get the source URIs of the model artifacts\n", "model_artifacts = []\n", "for vertex in query_result.vertices:\n", " model_artifacts.append(vertex.to_lineage_object().source.source_uri)\n", "\n", "# The results of the `LineageQuery` API call return the ARN of the model deployed to the endpoint along with\n", "# the S3 URI to the model.tar.gz file associated with the model\n", "pp.pprint(model_artifacts)\n", "\n" ] }, { "cell_type": "markdown", "id": "3da33a61", "metadata": {}, "source": [ "### Find the trial components associated with an Endpoint" ] }, { "cell_type": "code", "execution_count": null, "id": "3772b853", "metadata": {}, "outputs": [], "source": [ "# Define the LineageFilter to look for entities of type `TRIAL_COMPONENT` and the source of type 
`TRAINING_JOB`.\n", "\n", "query_filter = LineageFilter(\n", " entities=[LineageEntityEnum.TRIAL_COMPONENT],\n", " sources=[LineageSourceEnum.TRAINING_JOB],\n", ")\n", "\n", "# Providing this `LineageFilter` to the `LineageQuery` constructs a query that traverses through the given context `endpoint_context`\n", "# and finds all trial components whose source is a training job.\n", "\n", "query_result = LineageQuery(sagemaker_session).query(\n", " start_arns=[endpoint_context.context_arn],\n", " query_filter=query_filter,\n", " direction=LineageQueryDirectionEnum.ASCENDANTS,\n", " include_edges=False,\n", ")\n", "\n", "# Parse the query results to get the ARNs of the trial components (training jobs) associated with this Endpoint\n", "trial_components = []\n", "for vertex in query_result.vertices:\n", " trial_components.append(vertex.arn)\n", "\n", "pp.pprint(trial_components)\n", "\n" ] }, { "cell_type": "markdown", "id": "d27a243b", "metadata": {}, "source": [ "## Change the focal point of lineage\n", "\n", "The LineageQuery can be given different start_arns, which changes the focal point of lineage. 
In addition, the LineageFilter can take multiple sources and entities to expand the scope of the query.\n", "\n", "Here we use the model as the lineage focal point and find the Endpoints and Datasets associated with it.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "73e51133", "metadata": {}, "outputs": [], "source": [ "model_package_arn = sm_client.list_model_packages(ModelPackageGroupName = 'mlops-workshop-model-group')['ModelPackageSummaryList'][0]['ModelPackageArn']" ] }, { "cell_type": "code", "execution_count": null, "id": "87cd58f2", "metadata": {}, "outputs": [], "source": [ "# Get the ModelArtifact\n", "\n", "model_artifact_summary = list(Artifact.list(source_uri=model_package_arn))[0]\n", "model_artifact = ModelArtifact.load(artifact_arn=model_artifact_summary.artifact_arn)" ] }, { "cell_type": "code", "execution_count": null, "id": "41f1964e", "metadata": {}, "outputs": [], "source": [ "query_filter = LineageFilter(\n", " entities=[LineageEntityEnum.ARTIFACT],\n", " sources=[LineageSourceEnum.ENDPOINT, LineageSourceEnum.DATASET],\n", ")\n", "\n", "query_result = LineageQuery(sagemaker_session).query(\n", " start_arns=[model_artifact.artifact_arn], # Model is the starting artifact\n", " query_filter=query_filter,\n", " # Find all the entities that descend from the model, i.e. the endpoint\n", " direction=LineageQueryDirectionEnum.DESCENDANTS,\n", " include_edges=False,\n", ")\n", "\n", "associations = []\n", "for vertex in query_result.vertices:\n", " associations.append(vertex.to_lineage_object().source.source_uri)\n", "\n", "query_result = LineageQuery(sagemaker_session).query(\n", " start_arns=[model_artifact.artifact_arn], # Model is the starting artifact\n", " query_filter=query_filter,\n", " # Find all the entities that ascend from the model, i.e. 
the datasets\n", " direction=LineageQueryDirectionEnum.ASCENDANTS,\n", " include_edges=False,\n", ")\n", "\n", "for vertex in query_result.vertices:\n", " associations.append(vertex.to_lineage_object().source.source_uri)\n", "\n", "pp.pprint(associations)" ] }, { "cell_type": "markdown", "id": "f40027f0", "metadata": {}, "source": [ "## Use LineageQueryDirectionEnum.BOTH\n", "\n", "When the direction is set to BOTH, the query traverses the graph to find both ascendant and descendant relationships, and the traversal takes place not only from the starting node but from each node that is visited.\n", "\n", "For example, if the training job is run twice and both models generated by the training job are deployed to endpoints, the result of the query with direction set to BOTH shows both endpoints. This is because the same image is used for training and deploying the model. Since the image is common to the model (start_arn) and both endpoints, it appears in the query result.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "ca3457a2", "metadata": {}, "outputs": [], "source": [ "query_filter = LineageFilter(\n", " entities=[LineageEntityEnum.ARTIFACT],\n", " sources=[LineageSourceEnum.ENDPOINT, LineageSourceEnum.DATASET],\n", ")\n", "\n", "query_result = LineageQuery(sagemaker_session).query(\n", " start_arns=[model_artifact.artifact_arn], # Model is the starting artifact\n", " query_filter=query_filter,\n", " # Look for associations in both directions (ascending and descending) from the starting artifact\n", " direction=LineageQueryDirectionEnum.BOTH,\n", " include_edges=False,\n", ")\n", "\n", "associations = []\n", "for vertex in query_result.vertices:\n", " associations.append(vertex.to_lineage_object().source.source_uri)\n", "\n", "pp.pprint(associations)" ] }, { "cell_type": "markdown", "id": "0f403eb0", "metadata": {}, "source": [ "## Directions in LineageQuery: Ascendants vs. 
Descendants\n", "\n", "To understand the direction in the lineage graph, consider the following entity relationship graph: Dataset -> Training Job -> Model -> Endpoint\n", "\n", "The endpoint is a descendant of the model, and the model is a descendant of the dataset. Similarly, the model is an ascendant of the endpoint. The direction parameter can be used to specify whether the query should return entities that are descendants or ascendants of the entity in start_arns. If start_arns contains a model and the direction is DESCENDANTS, the query returns the endpoint. If the direction is ASCENDANTS, the query returns the dataset.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "a60f817f", "metadata": {}, "outputs": [], "source": [ "# In this example, we'll look at the impact of specifying the direction as ASCENDANTS or DESCENDANTS in a `LineageQuery`.\n", "\n", "query_filter = LineageFilter(\n", " entities=[LineageEntityEnum.ARTIFACT],\n", " sources=[\n", " LineageSourceEnum.ENDPOINT,\n", " LineageSourceEnum.MODEL,\n", " LineageSourceEnum.DATASET,\n", " LineageSourceEnum.TRAINING_JOB,\n", " ],\n", ")\n", "\n", "query_result = LineageQuery(sagemaker_session).query(\n", " start_arns=[model_artifact.artifact_arn],\n", " query_filter=query_filter,\n", " direction=LineageQueryDirectionEnum.ASCENDANTS,\n", " include_edges=False,\n", ")\n", "\n", "ascendant_artifacts = []\n", "\n", "# The lineage entity returned for the Training Job is a TrialComponent which can't be converted to a\n", "# lineage object using the method `to_lineage_object()` so we extract the TrialComponent ARN.\n", "for vertex in query_result.vertices:\n", " try:\n", " ascendant_artifacts.append(vertex.to_lineage_object().source.source_uri)\n", " except Exception:\n", " ascendant_artifacts.append(vertex.arn)\n", "\n", "print(\"Ascendant artifacts:\")\n", "pp.pprint(ascendant_artifacts)\n", "\n", "query_result = LineageQuery(sagemaker_session).query(\n", " start_arns=[model_artifact.artifact_arn],\n", " 
query_filter=query_filter,\n", " direction=LineageQueryDirectionEnum.DESCENDANTS,\n", " include_edges=False,\n", ")\n", "\n", "descendant_artifacts = []\n", "for vertex in query_result.vertices:\n", " try:\n", " descendant_artifacts.append(vertex.to_lineage_object().source.source_uri)\n", " except Exception:\n", " # TrialComponents can't be converted to lineage objects, so fall back to the ARN.\n", " descendant_artifacts.append(vertex.arn)\n", "\n", "print(\"Descendant artifacts:\")\n", "pp.pprint(descendant_artifacts)" ] }, { "cell_type": "code", "execution_count": null, "id": "9e432880", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "instance_type": "ml.t3.medium", "kernelspec": { "display_name": "Python 3 (Data Science)", "language": "python", "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/datascience-1.0" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.10" } }, "nbformat": 4, "nbformat_minor": 5 }