# Test your FSx for Lustre and SageMaker training connection

### 1. Point to relevant configs

In [None]:
import sagemaker

# Create a SageMaker session object.
sess = sagemaker.Session()

# Look up the default bucket associated with that session.
bucket = sess.default_bucket()

In [None]:
from sagemaker.inputs import FileSystemInput

# --- FSx for Lustre connection details (fill these in before running) ---

# ID of the FSx for Lustre file system.
file_system_id = ""

# Security group and subnet used by the file system. They are also passed to
# the SageMaker Estimator further down so the training job uses the same ones.
fsx_security_group_id = ""
fsx_subnet = ""

# File system type handed to FileSystemInput.
file_system_type = "FSxLustre"

# Normalized, absolute directory path of the input data on the file system.
# The mount name is either chosen by you when the file system is created or
# generated automatically; you can find it on the FSx page in the console.
# Example of an FSx-generated mount name: "3x5lhbmv"
base_path = ""

# Channel definition pointing at `base_path`, requested in "rw" access mode.
train = FileSystemInput(
    file_system_type=file_system_type,
    file_system_id=file_system_id,
    file_system_access_mode="rw",
    directory_path=base_path,
)

# Channel name -> input mapping; passed to estimator.fit() later on.
data_channels = dict(train=train)

In [None]:
# Extra Estimator arguments: reuse the security group and subnet that the FSx
# file system was created with.
kwargs = {
    "security_group_ids": [fsx_security_group_id],
    "subnets": [fsx_subnet],
}

### 2. Write a basic script

In [None]:
# -p keeps this cell safe to re-run: no error if the directory already exists.
!mkdir -p fsx_scripts

In [None]:
%%writefile fsx_scripts/test.py

import argparse
import os

def parse_args():
    """Parse the command-line arguments of the test script.

    ``--train_folder`` defaults to the value of the ``SM_CHANNEL_TRAIN``
    environment variable. If that variable is not set (for example when the
    script is run outside a training job), the flag becomes mandatory instead
    of the script failing with a ``KeyError`` while building the parser.

    Returns:
        argparse.Namespace: namespace with a ``train_folder`` attribute.

    Raises:
        SystemExit: raised by argparse when ``--train_folder`` is neither
            passed nor available through ``SM_CHANNEL_TRAIN``.
    """
    # The suffix of this variable name needs to exactly match the channel name
    # you defined earlier ("train" -> SM_CHANNEL_TRAIN).
    channel_dir = os.environ.get("SM_CHANNEL_TRAIN")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--train_folder",
        type=str,
        default=channel_dir,
        # Only demand the flag when there is no environment fallback.
        required=channel_dir is None,
    )

    return parser.parse_args()

if __name__ == "__main__":
    # Imported here so this entry-point block carries its own dependency.
    import subprocess

    print('running your job!')

    args = parse_args()

    # Flush so this message is emitted before the child process writes its
    # listing (stdout may be block-buffered when it is not a terminal).
    print('train path looks like {}, now we will try an ls'.format(args.train_folder), flush=True)

    # Pass the command as an argument list (no shell) so a path containing
    # spaces or shell metacharacters is listed literally. A failing `ls` is
    # reported on stderr but, as before, does not abort the script.
    subprocess.run(["ls", args.train_folder], check=False)
 

### 3. Run on SageMaker training

In [None]:
import sagemaker
from sagemaker.pytorch import PyTorch

# Session, execution role and default bucket for the training job.
sess = sagemaker.Session()
role = sagemaker.get_execution_role()

bucket = sess.default_bucket()

estimator = PyTorch(
    # Script to run and the local folder that contains it.
    source_dir="fsx_scripts",
    entry_point="test.py",
    base_job_name="lustre-test",
    role=role,
    sagemaker_session=sess,
    # SageMaker training resources; increase these as you need.
    instance_type="ml.m5.large",
    instance_count=1,
    # Framework and Python versions for the training job.
    framework_version="1.10",
    py_version="py38",
    debugger_hook_config=False,
    # Warm pool: keep the instance alive for 60 minutes, useful for debugging.
    keep_alive_period_in_seconds=60 * 60,
    # Security group and subnet of the FSx file system.
    **kwargs,
)

In [None]:
# Launch the job on the FSx-backed channel(s); wait=False is passed so the
# call is not expected to block the notebook until the job finishes.
estimator.fit(wait=False, inputs=data_channels)