{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "local-training-intro",
   "metadata": {},
   "source": [
    "# Local-mode TensorFlow training with SageMaker\n",
    "\n",
    "Runs `../src/training/train.py` through the SageMaker `TensorFlow` estimator using a `LocalSession` and `instance_type='local_gpu'`. Fill in the input/output locations in the configuration cell before running."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be437ab0-42b8-44e1-8cc7-c92b765c037a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sagemaker\n",
    "import boto3\n",
    "from sagemaker.estimator import Estimator\n",
    "from sagemaker.tensorflow import TensorFlow\n",
    "from sagemaker.local import LocalSession\n",
    "from sagemaker import get_execution_role"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7aab0a9b-39d5-4d4f-b985-46952e84a6d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "role = get_execution_role()\n",
    "session = LocalSession()\n",
    "# local_code=True: use source_dir from the local file system rather than staging it in S3\n",
    "session.config = {'local': {'local_code': True}}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3299bc66-853d-4c8f-b7ef-dbeaf7690259",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input/output locations for the job. For local testing these can point at directories\n",
    "# on the local file system (given as file:// URIs); otherwise use s3:// URIs.\n",
    "sm_training_s3_input_location = \"\"   # location of the training data\n",
    "sm_training_s3_output_location = \"\"  # location where the job output is written\n",
    "epochs = 1  # change the number of epochs for the job"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36c80f4c-0c00-4e7b-bf3c-56582716bdc3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regexes applied to the training log output to extract metrics.\n",
    "# NOTE(review): assumes progress lines shaped like \"loss: 0.1234 - acc: 0.9876\";\n",
    "# confirm against what train.py actually prints.\n",
    "metric_definitions = [\n",
    "    # Must stay on one line: a stray line break inside this pattern stops it from matching.\n",
    "    {'Name': 'train:error', 'Regex': 'loss: (.*?) -'},\n",
    "    {'Name': 'validation:acc', 'Regex': 'acc: (.*)'},\n",
    "]\n",
    "\n",
    "# Hyperparameters forwarded to the entry point script.\n",
    "hyperparams = {\n",
    "    'save_weights': '/opt/ml/model/sm_model',\n",
    "    'dataDir': '/opt/ml/input/data/training',\n",
    "    'epochs': epochs,\n",
    "    'tol': 50,  # meaning is defined by train.py (not visible here)\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6d1390a-ba95-43d8-abfa-f20518fd994f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TensorFlow estimator bound to the local session; runs train.py from source_dir.\n",
    "estimator = TensorFlow(\n",
    "    framework_version=\"2.5.1\",\n",
    "    role=role,\n",
    "    instance_count=1,\n",
    "    instance_type='local_gpu',\n",
    "    volume_size=225,\n",
    "    output_path=sm_training_s3_output_location,\n",
    "    sagemaker_session=session,\n",
    "    hyperparameters=hyperparams,\n",
    "    metric_definitions=metric_definitions,\n",
    "    source_dir=\"../src/training\",\n",
    "    entry_point='train.py',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d82cd696-3afe-47c4-9df2-441f9a45a39f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The 'training' channel name corresponds to the /opt/ml/input/data/training path passed as dataDir.\n",
    "estimator.fit({'training': sm_training_s3_input_location})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e5f9d8d-129f-4b61-9909-a0ebb6f2057b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}