### Cross account lineage - query endpoint 

#### Scenario:
* user creates modelpackage-group and modelpackage-version in account 1
* user creates endpoint in account 2 using account 1's modelpackage

`To get the complete lineage response, account 2 must share its lineage resources with account 1 through AWS Resource Access Manager (AWS RAM).`

docs:
* RAM : https://docs.aws.amazon.com/sagemaker/latest/dg/xaccount-lineage-tracking.html
* deploy endpoint cross account: https://docs.aws.amazon.com/sagemaker/latest/dg/model-registry-deploy.html#model-registry-deploy-xaccount

* Please reach out to Yuyao Zhang (ozhang@amazon.com) or Melanie Li (mmelli@amazon.com) with any issues or questions.

In [None]:
# install pyvis if not installed
!pip install pyvis

#### Visualizer class to display the hierarchical lineage response

In [None]:


from pyvis.network import Network
import os


class Visualizer:
    """Render a lineage query response as an interactive pyvis graph.

    Every rendered graph is written as an HTML file inside ``self.directory``.
    """

    def __init__(self):
        """Set the output directory and create it when it does not exist yet."""
        self.directory = "generated"
        # exist_ok replaces the separate exists() check and its check-then-create race.
        os.makedirs(self.directory, exist_ok=True)

    def render(self, query_lineage_response, scenario_name):
        """Draw the vertices and edges of a lineage response and write them to HTML.

        Args:
            query_lineage_response: Mapping with a "Vertices" list (each item has
                "Arn" and "LineageType", and optionally "Type") and an "Edges"
                list (each item has "SourceArn" and "DestinationArn").
            scenario_name: Base name of the HTML file written under
                ``self.directory``.

        Returns:
            The value returned by ``Network.show`` for the generated file.
        """
        net = self.get_network()

        for vertex in query_lineage_response["Vertices"]:
            arn = vertex["Arn"]
            # "Type" is optional on a vertex; None means "no type available".
            label = vertex.get("Type")
            lineage_type = vertex["LineageType"]
            name = self.get_name(arn, label)
            title = self.get_title(arn, label, lineage_type)
            net.add_node(arn, label=name, title=title, shape="box", physics=False)

        for edge in query_lineage_response["Edges"]:
            source = edge["SourceArn"]
            dest = edge["DestinationArn"]
            # Edges are added destination-first, so arrows point from the
            # destination ARN to the source ARN.
            net.add_edge(dest, source)

        return net.show(f"{self.directory}/{scenario_name}.html")

    def get_title(self, arn, label, lineage_type):
        """Return the multi-line hover text for a node."""
        return f"Arn: {arn}\nType: {label}\nLineage Type: {lineage_type}"

    def get_name(self, arn, type):
        """Return the node label: the ARN's resource name plus its type.

        Args:
            arn: Resource ARN; the label uses the segment after the first "/".
            type: Vertex type, or None when the vertex carries no "Type".
                (The parameter name shadows the builtin but is kept so existing
                keyword callers continue to work.)

        Returns:
            "<resource name> <type>", or just the resource name when ``type``
            is None. An ARN without "/" is used whole as the resource name.
        """
        segments = arn.split("/")
        resource = segments[1] if len(segments) > 1 else arn
        # render() passes None for vertices without a "Type"; concatenating
        # None to a string would raise TypeError.
        if type is None:
            return resource
        return f"{resource} {type}"

    def get_network(self):
        """Return a directed pyvis ``Network`` preconfigured for notebook display."""
        net = Network(height="800px", width="1000px", directed=True, notebook=True)
        net.set_options(
            """
            var options = {
                  "nodes": {
                    "borderWidth": 1,
                    "shadow": {
                      "enabled": true
                    },
                    "shapeProperties": {
                      "borderRadius": 0
                    },
                    "size": 40,
                    "shape": "circle"
                  },
                  "edges": {
                    "arrows": {
                      "to": {
                        "enabled": true
                      }
                    },
                    "color": {
                      "inherit": true
                    },
                    "smooth": false
                  },
                  "layout": {
                    "hierarchical": {
                      "enabled": false,
                      "direction": "LR",
                      "sortMethod": "directed"
                    }
                  }
                }
        """
        )
        return net

#### Lineage imports and variables

In [None]:
import boto3
import os
import boto3
import sagemaker
import pprint
from botocore.config import Config
from sagemaker.lineage.context import Context, EndpointContext
from sagemaker.lineage.action import Action
from sagemaker.lineage.association import Association
from sagemaker.lineage.artifact import Artifact, ModelArtifact, DatasetArtifact

from sagemaker.lineage.query import (
    LineageQuery,
    LineageFilter,
    LineageSourceEnum,
    LineageEntityEnum,
    LineageQueryDirectionEnum,
)
# Shared SageMaker session; the low-level client, region and bucket below are all
# taken from it so every cell talks to the same account and region.
sagemaker_session = sagemaker.Session()
sm_client = sagemaker_session.sagemaker_client

region = sagemaker_session.boto_region_name

# Default bucket of the session and the execution role of this notebook.
default_bucket = sagemaker_session.default_bucket()
role = sagemaker.get_execution_role()

#### Querying all artifacts from ModelPackage
* Get the model artifact from the model package ARN.
* Run a LineageQuery filtered to context entities with endpoint and model sources.

In [None]:
# Resolve the lineage artifact that tracks the model package version.
model_package_arn = 'arn:aws:sagemaker:us-east-1:631450739534:model-package/xgboost-abalone2022-05-15-10-42-27/1'

# Only the first matching artifact is needed, so take it lazily instead of
# materialising the whole listing, and fail with an actionable message when
# nothing is visible rather than with a bare IndexError.
model_artifact_summary = next(iter(Artifact.list(source_uri=model_package_arn)), None)
if model_artifact_summary is None:
    raise ValueError(
        f"No lineage artifact found for model package {model_package_arn}. "
        "Check the ARN and that the lineage resources are visible to this account."
    )
model_artifact = ModelArtifact.load(artifact_arn=model_artifact_summary.artifact_arn)

# Keep only context entities whose source is an endpoint or a model.
query_filter = LineageFilter(
    entities=[LineageEntityEnum.CONTEXT],
    sources=[LineageSourceEnum.ENDPOINT, LineageSourceEnum.MODEL],
)

query_result = LineageQuery(sagemaker_session).query(
    start_arns=[model_artifact.artifact_arn],  # Model is the starting artifact
    query_filter=query_filter,
    # Find all the entities that descend from the model, i.e. the endpoint
    direction=LineageQueryDirectionEnum.DESCENDANTS,
    include_edges=True,
)

# Plain-dict view of every returned vertex, for easy inspection.
associations = [vars(vertex) for vertex in query_result.vertices]
print(associations)

# To list the upstream entities instead, rerun the query above with
# direction=LineageQueryDirectionEnum.ASCENDANTS.

#### Visualize hierarchy from query response

In [None]:
# Query the descendants of the model artifact through the low-level client; the
# raw response dict (with "Vertices" and "Edges") is what Visualizer.render reads.
query_response = sm_client.query_lineage(
    StartArns=[model_artifact.artifact_arn], Direction="Descendants", IncludeEdges=True
)

# Writes generated/ModelPackageVersion.html. render() is the last expression of
# the cell, so its return value is what the notebook displays.
viz = Visualizer()
viz.render(query_response, "ModelPackageVersion")

#### Get the endpoint ARN from the endpoint context

In [None]:
# ARN of the endpoint's lineage context (the variable holds an ARN, not a context object).
endpoint_context = "arn:aws:sagemaker:us-east-1:682604156941:context/xgboost-abalone2022-05-15-10-42-27"

In [None]:
# Describe the context; ContextName is given the context ARN held in endpoint_context.
# NOTE(review): the endpoint ARN is presumably under the response's Source.SourceUri — confirm.
sm_client.describe_context(ContextName=endpoint_context)