{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# BYOC training for PaddleOCR\n",
    "\n",
    "This notebook launches a bring-your-own-container (BYOC) SageMaker training job for PaddleOCR. It uploads the local training data to S3, builds the URI of the training image in Amazon ECR, optionally creates a SageMaker Experiments trial, and starts the training job without waiting for it to finish."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import time\n",
    "from time import gmtime, strftime\n",
    "\n",
    "import sagemaker as sage\n",
    "from sagemaker import get_execution_role"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration\n",
    "\n",
    "Values you are expected to review before running the rest of the notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Local directory holding the training data that step 1 uploads to S3\n",
    "WORK_DIRECTORY = \"./input/data\"\n",
    "\n",
    "# S3 key prefix passed to the upload in step 1\n",
    "prefix = \"DEMO-paddle-byo\"\n",
    "\n",
    "# You need to replace project id with your own ID\n",
    "# (step 2 looks for the ECR repository `{PROJECT_ID}-training-imagebuild`)\n",
    "PROJECT_ID = \"sagemaker-p-5an0os9jqfdi\"\n",
    "\n",
    "# Set to False to launch the training job without SageMaker Experiments tracking (step 3)\n",
    "TRACK_EXPERIMENT = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1: Upload data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One session is created here and reused by every later step\n",
    "sess = sage.Session()\n",
    "role = get_execution_role()\n",
    "\n",
    "data_location = sess.upload_data(WORK_DIRECTORY, key_prefix=prefix)\n",
    "print(data_location)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: Get the training image container in Amazon ECR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Account ID and region of the current session determine the ECR registry host\n",
    "account = sess.boto_session.client(\"sts\").get_caller_identity()[\"Account\"]\n",
    "region = sess.boto_session.region_name\n",
    "\n",
    "image = f'{account}.dkr.ecr.{region}.amazonaws.com/{PROJECT_ID}-training-imagebuild:latest'\n",
    "print('Training image location: ',image)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3: Configure SageMaker Experiments for experiment tracking (optional)\n",
    "\n",
    "To skip experiment tracking, set `TRACK_EXPERIMENT = False` in the configuration cell. Step 4 then starts the training job without an experiment configuration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install the SageMaker Experiments Python SDK into the kernel's environment\n",
    "%pip install --quiet \"sagemaker-experiments>=0.1,<0.2\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if TRACK_EXPERIMENT:\n",
    "    # Imported here rather than in the top import cell because the package\n",
    "    # is only installed by the previous cell.\n",
    "    from smexperiments.experiment import Experiment\n",
    "    from smexperiments.trial import Trial\n",
    "    from smexperiments.trial_component import TrialComponent\n",
    "    from smexperiments.tracker import Tracker\n",
    "\n",
    "    # Timestamp suffix shared by the experiment name and the trial name\n",
    "    create_date = strftime(\"%Y-%m-%d-%H-%M-%S\")\n",
    "\n",
    "    demo_experiment = Experiment.create(\n",
    "        experiment_name=\"PaddleOCR-{}\".format(create_date),\n",
    "        description=\"OCR experiment\",\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if TRACK_EXPERIMENT:\n",
    "    demo_trial = Trial.create(\n",
    "        trial_name=\"trial-{}\".format(create_date),\n",
    "        experiment_name=demo_experiment.experiment_name,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4: Create the training job"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hyperparameters = {\n",
    "    \"epoch_num\": 10,\n",
    "    \"print_batch_step\": 5,\n",
    "    \"save_epoch_step\": 3,\n",
    "    \"pretrained_model\": \"/opt/program/pretrain/ch_ppocr_mobile_v2.0_rec_train/best_accuracy\",\n",
    "}\n",
    "\n",
    "# Reuses the session, role and image URI from steps 1 and 2\n",
    "train = sage.estimator.Estimator(\n",
    "    image,\n",
    "    role,\n",
    "    instance_count=1,\n",
    "    sagemaker_session=sess,\n",
    "    instance_type=\"ml.p3.2xlarge\",\n",
    "    hyperparameters=hyperparameters,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach the job to the trial from step 3 only when tracking is enabled;\n",
    "# None starts the job without an experiment configuration.\n",
    "experiment_config = None\n",
    "if TRACK_EXPERIMENT:\n",
    "    experiment_config = {\n",
    "        \"TrialName\": demo_trial.trial_name,\n",
    "        \"TrialComponentDisplayName\": \"TrainingJob\",\n",
    "    }\n",
    "\n",
    "# wait=False: do not block the notebook until the training job finishes\n",
    "train.fit(data_location, wait=False, experiment_config=experiment_config)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 5: Check the training job status in SageMaker Studio or the AWS console\n",
    "\n",
    "The cell below prints the job name to look for, together with its current status. Re-run it to refresh the status."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "job_name = train.latest_training_job.name\n",
    "job_status = sess.describe_training_job(job_name)[\"TrainingJobStatus\"]\n",
    "print(f\"{job_name}: {job_status}\")"
   ]
  }
 ],
 "metadata": {
  "instance_type": "ml.t3.medium",
  "kernelspec": {
   "display_name": "Python 3 (Data Science)",
   "language": "python",
   "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:eu-west-1:470317259841:image/datascience-1.0"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}