# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You # may not use this file except in compliance with the License. A copy of # the License is located at # # http://aws.amazon.com/apache2.0/ # # or in the "license" file accompanying this file. This file is # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from __future__ import absolute_import from urllib import request import json from packaging.version import Version from enum import Enum class Tasks(str, Enum): """The ML task name as referenced in the infix of the model ID.""" IC = "ic" OD = "od" OD1 = "od1" SEMSEG = "semseg" IS = "is" TC = "tc" SPC = "spc" EQA = "eqa" TEXT_GENERATION = "textgeneration" IC_EMBEDDING = "icembedding" TC_EMBEDDING = "tcembedding" NER = "ner" SUMMARIZATION = "summarization" TRANSLATION = "translation" TABULAR_REGRESSION = "regression" TABULAR_CLASSIFICATION = "classification" class ProblemTypes(str, Enum): """Possible problem types for JumpStart models.""" IMAGE_CLASSIFICATION = "Image Classification" IMAGE_EMBEDDING = "Image Embedding" OBJECT_DETECTION = "Object Detection" SEMANTIC_SEGMENTATION = "Semantic Segmentation" INSTANCE_SEGMENTATION = "Instance Segmentation" TEXT_CLASSIFICATION = "Text Classification" TEXT_EMBEDDING = "Text Embedding" QUESTION_ANSWERING = "Question Answering" SENTENCE_PAIR_CLASSIFICATION = "Sentence Pair Classification" TEXT_GENERATION = "Text Generation" TEXT_SUMMARIZATION = "Text Summarization" MACHINE_TRANSLATION = "Machine Translation" NAMED_ENTITY_RECOGNITION = "Named Entity Recognition" TABULAR_REGRESSION = "Regression" TABULAR_CLASSIFICATION = "Classification" class Frameworks(str, Enum): """Possible frameworks for JumpStart models""" TENSORFLOW = "Tensorflow Hub" PYTORCH = "Pytorch Hub" HUGGINGFACE = "HuggingFace" CATBOOST = "Catboost" GLUONCV = "GluonCV" LIGHTGBM = "LightGBM" XGBOOST = "XGBoost" SCIKIT_LEARN = "ScikitLearn" SOURCE = "Source" JUMPSTART_REGION = "eu-west-2" SDK_MANIFEST_FILE = "models_manifest.json" JUMPSTART_BUCKET_BASE_URL = "https://jumpstart-cache-prod-{}.s3.{}.amazonaws.com".format( JUMPSTART_REGION, JUMPSTART_REGION ) TASK_MAP = { Tasks.IC: ProblemTypes.IMAGE_CLASSIFICATION, Tasks.IC_EMBEDDING: ProblemTypes.IMAGE_EMBEDDING, Tasks.OD: ProblemTypes.OBJECT_DETECTION, Tasks.OD1: ProblemTypes.OBJECT_DETECTION, Tasks.SEMSEG: ProblemTypes.SEMANTIC_SEGMENTATION, Tasks.IS: ProblemTypes.INSTANCE_SEGMENTATION, Tasks.TC: ProblemTypes.TEXT_CLASSIFICATION, Tasks.TC_EMBEDDING: ProblemTypes.TEXT_EMBEDDING, Tasks.EQA: ProblemTypes.QUESTION_ANSWERING, Tasks.SPC: ProblemTypes.SENTENCE_PAIR_CLASSIFICATION, Tasks.TEXT_GENERATION: ProblemTypes.TEXT_GENERATION, Tasks.SUMMARIZATION: ProblemTypes.TEXT_SUMMARIZATION, Tasks.TRANSLATION: ProblemTypes.MACHINE_TRANSLATION, Tasks.NER: ProblemTypes.NAMED_ENTITY_RECOGNITION, Tasks.TABULAR_REGRESSION: ProblemTypes.TABULAR_REGRESSION, Tasks.TABULAR_CLASSIFICATION: ProblemTypes.TABULAR_CLASSIFICATION, } TO_FRAMEWORK = { "Tensorflow Hub": Frameworks.TENSORFLOW, "Pytorch Hub": Frameworks.PYTORCH, "HuggingFace": Frameworks.HUGGINGFACE, "Catboost": Frameworks.CATBOOST, "GluonCV": Frameworks.GLUONCV, "LightGBM": Frameworks.LIGHTGBM, "XGBoost": Frameworks.XGBOOST, "ScikitLearn": Frameworks.SCIKIT_LEARN, "Source": Frameworks.SOURCE, } MODALITY_MAP = { (Tasks.IC, Frameworks.PYTORCH): "algorithms/vision/image_classification_pytorch.rst", (Tasks.IC, Frameworks.TENSORFLOW): "algorithms/vision/image_classification_tensorflow.rst", (Tasks.IC_EMBEDDING, Frameworks.TENSORFLOW): "algorithms/vision/image_embedding_tensorflow.rst", (Tasks.IS, Frameworks.GLUONCV): "algorithms/vision/instance_segmentation_mxnet.rst", (Tasks.OD, Frameworks.GLUONCV): "algorithms/vision/object_detection_mxnet.rst", (Tasks.OD, Frameworks.PYTORCH): "algorithms/vision/object_detection_pytorch.rst", (Tasks.OD, Frameworks.TENSORFLOW): "algorithms/vision/object_detection_tensorflow.rst", (Tasks.SEMSEG, Frameworks.GLUONCV): "algorithms/vision/semantic_segmentation_mxnet.rst", ( Tasks.TRANSLATION, Frameworks.HUGGINGFACE, ): "algorithms/text/machine_translation_hugging_face.rst", (Tasks.NER, Frameworks.GLUONCV): "algorithms/text/named_entity_recognition_hugging_face.rst", (Tasks.EQA, Frameworks.PYTORCH): "algorithms/text/question_answering_pytorch.rst", ( Tasks.SPC, Frameworks.HUGGINGFACE, ): "algorithms/text/sentence_pair_classification_hugging_face.rst", ( Tasks.SPC, Frameworks.TENSORFLOW, ): "algorithms/text/sentence_pair_classification_tensorflow.rst", (Tasks.TC, Frameworks.TENSORFLOW): "algorithms/text/text_classification_tensorflow.rst", ( Tasks.TC_EMBEDDING, Frameworks.GLUONCV, ): "algorithms/vision/text_embedding_tensorflow_mxnet.rst", ( Tasks.TC_EMBEDDING, Frameworks.TENSORFLOW, ): "algorithms/vision/text_embedding_tensorflow_mxnet.rst", ( Tasks.TEXT_GENERATION, Frameworks.HUGGINGFACE, ): "algorithms/text/text_generation_hugging_face.rst", ( Tasks.SUMMARIZATION, Frameworks.HUGGINGFACE, ): "algorithms/text/text_summarization_hugging_face.rst", } def get_jumpstart_sdk_manifest(): url = "{}/{}".format(JUMPSTART_BUCKET_BASE_URL, SDK_MANIFEST_FILE) with request.urlopen(url) as f: models_manifest = f.read().decode("utf-8") return json.loads(models_manifest) def get_jumpstart_sdk_spec(key): url = "{}/{}".format(JUMPSTART_BUCKET_BASE_URL, key) with request.urlopen(url) as f: model_spec = f.read().decode("utf-8") return json.loads(model_spec) def get_model_task(id): task_short = id.split("-")[1] return TASK_MAP[task_short] if task_short in TASK_MAP else "Source" def get_string_model_task(id): return id.split("-")[1] def get_model_source(url): if "tfhub" in url: return "Tensorflow Hub" if "pytorch" in url: return "Pytorch Hub" if "huggingface" in url: return "HuggingFace" if "catboost" in url: return "Catboost" if "gluon" in url: return "GluonCV" if "lightgbm" in url: return "LightGBM" if "xgboost" in url: return "XGBoost" if "scikit" in url: return "ScikitLearn" else: return "Source" def create_jumpstart_model_table(): sdk_manifest = get_jumpstart_sdk_manifest() sdk_manifest_top_versions_for_models = {} for model in sdk_manifest: if model["model_id"] not in sdk_manifest_top_versions_for_models: sdk_manifest_top_versions_for_models[model["model_id"]] = model else: if Version( sdk_manifest_top_versions_for_models[model["model_id"]]["version"] ) < Version(model["version"]): sdk_manifest_top_versions_for_models[model["model_id"]] = model file_content_intro = [] file_content_intro.append(".. _all-pretrained-models:\n\n") file_content_intro.append(".. |external-link| raw:: html\n\n") file_content_intro.append(' \n\n') file_content_intro.append("================================================\n") file_content_intro.append("Built-in Algorithms with pre-trained Model Table\n") file_content_intro.append("================================================\n") file_content_intro.append( """ The SageMaker Python SDK uses model IDs and model versions to access the necessary utilities for pre-trained models. This table serves to provide the core material plus some extra information that can be useful in selecting the correct model ID and corresponding parameters.\n""" ) file_content_intro.append( """ If you want to automatically use the latest version of the model, use "*" for the `model_version` attribute. We highly suggest pinning an exact model version however.\n""" ) file_content_intro.append( """ These models are also available through the `JumpStart UI in SageMaker Studio `__\n""" ) file_content_intro.append("\n") file_content_intro.append(".. list-table:: Available Models\n") file_content_intro.append(" :widths: 50 20 20 20 30 20\n") file_content_intro.append(" :header-rows: 1\n") file_content_intro.append(" :class: datatable\n") file_content_intro.append("\n") file_content_intro.append(" * - Model ID\n") file_content_intro.append(" - Fine Tunable?\n") file_content_intro.append(" - Latest Version\n") file_content_intro.append(" - Min SDK Version\n") file_content_intro.append(" - Problem Type\n") file_content_intro.append(" - Source\n") dynamic_table_files = [] file_content_entries = [] for model in sdk_manifest_top_versions_for_models.values(): model_spec = get_jumpstart_sdk_spec(model["spec_key"]) model_task = get_model_task(model_spec["model_id"]) string_model_task = get_string_model_task(model_spec["model_id"]) model_source = get_model_source(model_spec["url"]) file_content_entries.append(" * - {}\n".format(model_spec["model_id"])) file_content_entries.append(" - {}\n".format(model_spec["training_supported"])) file_content_entries.append(" - {}\n".format(model["version"])) file_content_entries.append(" - {}\n".format(model["min_version"])) file_content_entries.append(" - {}\n".format(model_task)) file_content_entries.append( " - `{} <{}>`__ |external-link|\n".format(model_source, model_spec["url"]) ) if (string_model_task, TO_FRAMEWORK[model_source]) in MODALITY_MAP: file_content_single_entry = [] if ( MODALITY_MAP[(string_model_task, TO_FRAMEWORK[model_source])] not in dynamic_table_files ): file_content_single_entry.append("\n") file_content_single_entry.append(".. list-table:: Available Models\n") file_content_single_entry.append(" :widths: 50 20 20 20 20\n") file_content_single_entry.append(" :header-rows: 1\n") file_content_single_entry.append(" :class: datatable\n") file_content_single_entry.append("\n") file_content_single_entry.append(" * - Model ID\n") file_content_single_entry.append(" - Fine Tunable?\n") file_content_single_entry.append(" - Latest Version\n") file_content_single_entry.append(" - Min SDK Version\n") file_content_single_entry.append(" - Source\n") dynamic_table_files.append( MODALITY_MAP[(string_model_task, TO_FRAMEWORK[model_source])] ) file_content_single_entry.append(" * - {}\n".format(model_spec["model_id"])) file_content_single_entry.append(" - {}\n".format(model_spec["training_supported"])) file_content_single_entry.append(" - {}\n".format(model["version"])) file_content_single_entry.append(" - {}\n".format(model["min_version"])) file_content_single_entry.append( " - `{} <{}>`__\n".format(model_source, model_spec["url"]) ) f = open(MODALITY_MAP[(string_model_task, TO_FRAMEWORK[model_source])], "a") f.writelines(file_content_single_entry) f.close() f = open("doc_utils/pretrainedmodels.rst", "a") f.writelines(file_content_intro) f.writelines(file_content_entries) f.close()