# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0
"""Split an input CSV into train/validation/test files and version them with DVC.

The script clones the git repository given by ``DVC_REPO_URL``, writes the
split files under the repository's ``dataset`` directory, tracks them with
``dvc add``, pushes data and metadata on the branch given by ``DVC_BRANCH``,
and records the resulting commit hash and data URLs with the experiment
tracker.
"""

import os
import argparse
import sys
import subprocess
from pathlib import Path

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from smexperiments.tracker import Tracker
import dvc.api
from git.exc import GitCommandError
from git.repo.base import Repo

# Prepare paths
input_data_path = os.path.join("/opt/ml/processing/input", "dataset.csv")
data_path = 'dataset'
# Local checkout directory of the DVC-enabled git repository
repo_dir = "./sagemaker-dvc-sample"
base_dir = f"{repo_dir}/{data_path}"
file_types = ['test', 'train', 'validation']
dvc_repo_url = os.environ.get('DVC_REPO_URL')
dvc_branch = os.environ.get('DVC_BRANCH')
user = os.environ.get('USER', "sagemaker")


def configure_git():
    """Set the global git identity (email and user name) used for commits."""
    # The arguments are passed straight to git without a shell, so they must
    # not carry shell-style quotes: any quote character would become part of
    # the stored value.
    subprocess.check_call(
        ['git', 'config', '--global', 'user.email', 'sagemaker-processing@example.com']
    )
    subprocess.check_call(['git', 'config', '--global', 'user.name', user])


def split_dataframe(df, num=5):
    """Split *df* into consecutive row chunks of ``int(len(df) / num)`` rows.

    When the row count is not a multiple of the chunk size, the leftover rows
    form one extra, shorter trailing chunk, so the result can contain more
    than *num* chunks. A frame with fewer rows than *num* is split into
    single-row chunks, and an empty frame yields an empty list.

    Args:
        df: DataFrame to split.
        num: Target number of chunks; must be positive.

    Returns:
        A list of DataFrame slices covering all rows of *df* in order.

    Raises:
        ValueError: If *num* is not positive.
    """
    if num <= 0:
        raise ValueError(f"num must be positive, got {num}")
    total_rows = df.shape[0]
    # Clamp to 1 so a short or empty frame never produces a zero range step.
    chunk_size = max(1, int(total_rows / num))
    return [
        df.iloc[start:start + chunk_size]
        for start in range(0, total_rows, chunk_size)
    ]


def clone_dvc_git_repo():
    """Clone the repository at ``dvc_repo_url`` into ``repo_dir``.

    Returns:
        The ``Repo`` object for the fresh clone.

    Raises:
        ValueError: If ``DVC_REPO_URL`` was not provided in the environment.
    """
    if not dvc_repo_url:
        raise ValueError("The DVC_REPO_URL environment variable must be set")
    print(f"Cloning repo: {dvc_repo_url}")
    return Repo.clone_from(dvc_repo_url, repo_dir)


def generate_train_validation_files(ratio):
    """Read the input CSV and write train/validation/test CSV files.

    The dataset is first split into train and "other" using *ratio* as the
    test size; "other" is then split again with the same *ratio* into
    validation and test. Train is written in chunks (default chunking),
    validation in chunks targeting 3 files, and test as a single file. All
    files are written without header and without index.

    Args:
        ratio: ``test_size`` value passed to both ``train_test_split`` calls.
    """
    for file_type in file_types:
        Path(base_dir, file_type).mkdir(parents=True, exist_ok=True)

    print("Read dataset")
    dataset = pd.read_csv(input_data_path)
    train, other = train_test_split(dataset, test_size=ratio)
    validation, test = train_test_split(other, test_size=ratio)

    print("create train, validation, test")
    for index, chunk in enumerate(split_dataframe(train), start=1):
        chunk.to_csv(
            f"{base_dir}/train/california_train_{index}.csv",
            header=False,
            index=False,
        )
    for index, chunk in enumerate(split_dataframe(validation, 3), start=1):
        chunk.to_csv(
            f"{base_dir}/validation/california_validation_{index}.csv",
            header=False,
            index=False,
        )
    test.to_csv(f"{base_dir}/test/california_test.csv", header=False, index=False)
    print("data created")


def sync_data_with_dvc(repo):
    """Track the generated files with DVC and push data and metadata.

    Changes the process working directory to ``base_dir``, checks out
    ``dvc_branch`` (creating it when it does not exist yet), runs ``dvc add``
    for every entry of ``file_types``, commits, runs ``dvc push`` and
    force-pushes the branch to the ``origin`` remote. The commit hash and the
    DVC URL of each data directory are then logged to the experiment tracker.

    Args:
        repo: ``Repo`` object of the cloned repository.

    Raises:
        ValueError: If ``DVC_BRANCH`` was not provided in the environment.
    """
    if not dvc_branch:
        raise ValueError("The DVC_BRANCH environment variable must be set")

    # dvc commands below are run relative to the dataset directory.
    os.chdir(base_dir)
    print(f"Create branch {dvc_branch}")
    try:
        repo.git.checkout('-b', dvc_branch)
    except GitCommandError:
        # `checkout -b` is rejected by git when the branch already exists;
        # fall back to a plain checkout. Other error types propagate.
        repo.git.checkout(dvc_branch)
        print(f"Checkout existing branch: {dvc_branch}")
    else:
        print(f"Create a new branch: {dvc_branch}")

    print("Add files to DVC")
    for file_type in file_types:
        subprocess.check_call(['dvc', 'add', f"{file_type}/"])

    repo.git.add(all=True)
    # No shell is involved, so the message must not be wrapped in quotes.
    repo.git.commit('-m', f"add data for {dvc_branch}")

    print("Push data to DVC")
    subprocess.check_call(['dvc', 'push'])

    print("Push dvc metadata to git")
    origin = repo.remote(name='origin')
    repo.git.push('--set-upstream', origin.name, dvc_branch, '--force')

    sha = repo.head.commit.hexsha
    print(f"commit hash: {sha}")

    with Tracker.load() as tracker:
        tracker.log_parameters({"data_commit_hash": sha})
        for file_type in file_types:
            path = dvc.api.get_url(
                f"{data_path}/{file_type}",
                repo=dvc_repo_url,
                rev=dvc_branch
            )
            tracker.log_output(name=f"{file_type}", value=path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--train-test-split-ratio", type=float, default=0.3)
    # Unknown arguments are tolerated rather than rejected.
    args, _ = parser.parse_known_args()

    train_test_split_ratio = args.train_test_split_ratio

    with Tracker.load() as tracker:
        tracker.log_parameters(
            {
                "train_test_split_ratio": train_test_split_ratio
            }
        )

    configure_git()
    repo = clone_dvc_git_repo()
    generate_train_validation_files(train_test_split_ratio)
    sync_data_with_dvc(repo)