# Evaluation 
* Container: conda_pytorch_py39

## -1. 가상의 GT 작업 수행

In [None]:
import pathlib
import pickle

# Placeholder ground truth: the same transcript repeated twelve times, one entry
# per captured request that main() will score.
gt_list = ['rubout n she yn'] * 12

# Persist the list where main() expects to load it from.
pathlib.Path("opt/ml/processing/input/manifest").mkdir(parents=True, exist_ok=True)
with open('opt/ml/processing/input/manifest/gt_manifest.pkl', 'wb') as f:
    pickle.dump(gt_list, f, protocol=pickle.HIGHEST_PROTOCOL)



## 1. 설정값 

In [None]:
# Date partition (YYYY/MM/DD): last segment of the data-capture S3 prefix and
# part of the local output paths built in main().
select_date = '2023/03/23'
# Endpoint variant name; one segment of the data-capture S3 prefix.
variant_name = 'AllTraffic'
# Error-rate threshold: main() starts the retraining pipeline when the metric exceeds it.
tolerance = 0.5

## 2. 패키지 설치

In [None]:
import boto3
import sys

In [None]:
%%bash
# Make /tmp world-writable.
# NOTE(review): presumably required by apt/pip in this container - confirm it is still needed.
chmod 777 /tmp
# Install SoX and its audio-format plugins from the system package manager.
apt -y update && apt -y install sox libsox-fmt-all

# Upgrade pip, then install the Python packages used by the evaluation code.
pip install --no-cache-dir --upgrade pip
pip install --no-cache-dir -U omegaconf hydra-core librosa sentencepiece youtokentome inflect sox
pip install --no-cache-dir -U braceexpand webdataset editdistance jiwer jsonlines
# pytorch-lightning is pinned to an exact version.
pip install --no-cache-dir -U pytorch-lightning==1.9.4
# The next three packages are installed straight from their repositories
# (pyannote-audio "develop" archive, transformers default branch, NeMo "main"),
# so the installed versions change over time.
pip install --no-cache-dir -U https://github.com/pyannote/pyannote-audio/archive/develop.zip
pip install --no-cache-dir git+https://github.com/huggingface/transformers
pip install --no-cache-dir git+https://github.com/NVIDIA/NeMo.git@main
pip install sagemaker-experiments

## 3. parameter store 설정

In [None]:
import os
import boto3

class parameter_store():
    """Convenience wrapper around the AWS Systems Manager (SSM) Parameter Store.

    Every method goes through the single boto3 ``ssm`` client created in
    ``__init__`` for the requested region.
    """

    # The DeleteParameters API accepts at most this many names per request.
    _DELETE_BATCH_SIZE = 10

    def __init__(self, region_name="ap-northeast-2"):
        """Create the SSM client for ``region_name``."""
        self.ssm = boto3.client('ssm', region_name=region_name)

    def put_params(self, key, value, dtype="String", overwrite=False, enc=False) -> str:
        """Create or update a parameter.

        Args:
            key: Parameter name.
            value: Parameter value; converted with ``str()`` before it is stored.
            dtype: SSM parameter type. Ignored when ``enc`` is true.
            overwrite: Replace an existing parameter with the same name.
            enc: Store the value as a ``SecureString``.

        Returns:
            ``'Store suceess'`` when the service accepted the request,
            ``'Error'`` when it rejected it (for example the parameter already
            exists and ``overwrite`` is false). The spelling of the success
            string is kept as-is because callers may compare against it.
        """
        if enc:
            dtype = "SecureString"

        # Use the API client instead of shelling out to the AWS CLI: this works
        # for overwrite=False as well and avoids building a shell command from
        # caller-supplied strings.
        try:
            self.ssm.put_parameter(
                Name=str(key),
                Value=str(value),
                Type=str(dtype),
                Overwrite=bool(overwrite),
            )
        except self.ssm.exceptions.ClientError as err:
            print(f" failed to store parameter {key}: {err}")
            return 'Error'
        return 'Store suceess'

    def get_params(self, key, enc=False):
        """Return the value stored under ``key``.

        Args:
            key: Parameter name.
            enc: Ask the service to decrypt the value (for ``SecureString``).

        Raises:
            IndexError: The response contains no parameter for ``key``.
        """
        response = self.ssm.get_parameters(
            Names=[key],
            WithDecryption=bool(enc),
        )
        return response['Parameters'][0]['Value']

    def get_all_params(self, ):
        """Return the names of all parameters visible to this client.

        The service returns at most 50 entries per call, so keep requesting
        pages for as long as the response carries a ``NextToken``.
        """
        names = []
        request = {"MaxResults": 50}
        while True:
            response = self.ssm.describe_parameters(**request)
            names.extend(dicParam["Name"] for dicParam in response["Parameters"])
            next_token = response.get("NextToken")
            if not next_token:
                return names
            request["NextToken"] = next_token

    def delete_param(self, listParams):
        """Delete the named parameters and report the outcome.

        Args:
            listParams: Iterable of parameter names. Names are sent in batches
                because the API limits the number of names per request.
        """
        names = list(listParams)
        for start in range(0, len(names), self._DELETE_BATCH_SIZE):
            batch = names[start:start + self._DELETE_BATCH_SIZE]
            response = self.ssm.delete_parameters(Names=batch)
            deleted = response.get("DeletedParameters", [])
            invalid = response.get("InvalidParameters", [])
            if deleted:
                print (f" parameters: {deleted} is deleted successfully")
            if invalid:
                # Names the service did not delete must not be reported as deleted.
                print (f" parameters: {invalid} could not be deleted")

## 4. 기존 설정값 가져오기

In [None]:
# Use the default region of the current boto3 session for the Parameter Store client.
strRegionName=boto3.Session().region_name
pm = parameter_store(strRegionName)

In [None]:
# Read the settings stored in the Parameter Store under fixed keys.
prefix = pm.get_params(key="PREFIX")
endpoint_name = pm.get_params(key='ENDPOINTNAME-lg-ramp-cyj-staging')
monitor_output = pm.get_params(key='MONITOROUTPUT-lg-ramp-cyj-staging')
# The bucket key is derived from the stored prefix: "<prefix>-BUCKET".
bucket_name = pm.get_params(key=prefix + '-BUCKET')

## 5. 경로 설정하기

In [None]:

# Example of the data-capture layout the first URI below points into:
#s3://sagemaker-us-west-2-322537213286/nemo-prod/inference/monitor_output
#s3://sagemaker-us-west-2-322537213286/nemo-prod/inference/monitor_output/nemo-prod-nemo-experiments-0320-07331679297605/AllTraffic/2023/03/20/07/

# Captured requests/responses for the chosen endpoint, variant and date.
inference_output_s3uri = os.path.join(monitor_output, endpoint_name, variant_name, select_date)

# Ground-truth manifests and evaluation outputs live under the project prefix.
gtmanifest_s3uri = os.path.join(f"s3://{bucket_name}", prefix, "gt-manifest")
output_s3uri = os.path.join(f"s3://{bucket_name}", prefix, "pred-output")

for label, value in (
    ("bucket_name", bucket_name),
    ("endpoint_name", endpoint_name),
    ("monitor_output", monitor_output),
):
    print (f"{label}: {value}")
del label, value


In [None]:
# !aws s3 sync ./manifest $gtmanifest_s3uri

## 6. SageMaker Endpoint에서 설정한 Data Capture 정보 가져오기

In [None]:
# Download the data-capture files for the selected endpoint, variant and date
# into the local directory that main() scans for *.jsonl files.
!aws s3 sync $inference_output_s3uri 'opt/ml/processing/input/inference_data'

## 7. Evaluation 코드 가져오기

In [None]:
import os
import copy
import boto3
import logging
import json
import jsonlines # !pip install jsonlines 해주기
import torch
import tarfile

from tqdm.auto import tqdm
from omegaconf import open_dict

# import glob
import pickle
import sox
import time
import io
import soundfile as sf
import base64
import numpy as np
import pathlib
from sagemaker.s3 import S3Downloader
from datetime import datetime

from nemo.collections.asr.metrics.wer import word_error_rate
from nemo.collections.asr.parts.utils.transcribe_utils import PunctuationCapitalization
from nemo.core.config import hydra_runner
from nemo.utils import logging


# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

# logger = logging.getLogger()
# logger.setLevel(logging.INFO)
# logger.addHandler(logging.StreamHandler())


def find_checkpoint(model_dir):
    """Return the path of a file ending in ``last.ckpt`` below ``model_dir``.

    The whole tree is walked; when several files match, the one visited last
    is returned. Returns ``None`` when nothing matches.
    """
    matches = [
        folder + '/' + name
        for folder, _subdirs, names in os.walk(model_dir)
        for name in names
        if name.endswith('last.ckpt')
    ]
    return matches[-1] if matches else None


def find_files(jsonl_dir):
    """Collect every file below ``jsonl_dir`` whose name ends with ``jsonl``.

    Paths are returned in the order ``os.walk`` visits them, each built as
    ``<directory>/<file name>``.
    """
    return [
        folder + '/' + name
        for folder, _subdirs, names in os.walk(jsonl_dir)
        for name in names
        if name.endswith('jsonl')
    ]


def read_manifest(path):
    """Load a JSON-lines manifest into a list.

    Args:
        path: File containing one JSON document per line.

    Returns:
        The decoded documents, in file order. Blank lines (for example a
        trailing empty line) are skipped instead of failing the JSON decode.
    """
    manifest = []
    with open(path, 'r') as f:
        for line in tqdm(f, desc="Reading manifest data"):
            line = line.strip()
            if not line:
                continue
            manifest.append(json.loads(line))
    return manifest


def write_processed_manifest(data, original_path):
    """Write ``data`` as JSON lines next to ``original_path``.

    The output file name is the original file name with ``.json`` replaced by
    ``_processed.json``; the path of the written file is returned.
    """
    manifest_dir, manifest_name = os.path.split(original_path)
    processed_name = manifest_name.replace(".json", "_processed.json")
    filepath = os.path.join(manifest_dir, processed_name)

    with open(filepath, 'w') as out:
        for record in tqdm(data, desc="Writing manifest data"):
            serialized = json.dumps(record)
            out.write(f"{serialized}\n")

    print(f"Finished writing manifest: {filepath}")
    return filepath


def apply_preprocessors(manifest, preprocessors):
    """Run each preprocessor over every manifest entry, in place.

    Preprocessors are applied one after another; each entry is replaced by the
    value the preprocessor returns. The (mutated) manifest is returned.
    """
    for transform in preprocessors:
        positions = range(len(manifest))
        for position in tqdm(positions, desc=f"Applying {transform.__name__}"):
            entry = manifest[position]
            manifest[position] = transform(entry)

    print("Finished processing manifest !")
    return manifest


def change_dir(data):
    """Rewrite ``data['audio_filepath']`` from the manifest root to the wav root.

    Both roots come from the environment (``MANIFEST_PATH`` and ``WAV_PATH``).
    The entry is updated in place and returned.
    """
    source_root = os.environ['MANIFEST_PATH']
    target_root = os.environ['WAV_PATH']
    relocated = data['audio_filepath'].replace(source_root, target_root)
    data['audio_filepath'] = relocated
    return data


def predict(asr_model, predictions, targets, target_lengths, predictions_lengths=None):
    """Decode one batch and return the first reference/hypothesis pair.

    Every target row is truncated to its length and decoded with
    ``asr_model.decoding.decode_tokens_to_str``; the predictions are decoded
    with ``asr_model.decoding.ctc_decoder_predictions_tensor``. Only element 0
    of each decoded list is returned.
    """
    decoding = asr_model.decoding
    with torch.no_grad():
        cpu_targets = targets.long().cpu()
        cpu_lengths = target_lengths.long().cpu()

        # Decode each row of the batch, cut down to its true length.
        references = []
        for row in range(cpu_targets.shape[0]):
            length = cpu_lengths[row].item()
            token_ids = cpu_targets[row][:length].numpy().tolist()
            references.append(decoding.decode_tokens_to_str(token_ids))

        hypotheses, _ = decoding.ctc_decoder_predictions_tensor(
            predictions, predictions_lengths, fold_consecutive=True
        )
    return references[0], hypotheses[0]

def start_retraining_codepipeline():
    """Start the model-build CodePipeline of the newest completed SageMaker project.

    Projects are listed newest first and the first one whose status is
    ``CreateCompleted`` is used. The pipeline name is built as
    ``sagemaker-<ProjectName>-<ProjectId>-modelbuild``.

    Raises:
        RuntimeError: No project with status ``CreateCompleted`` was returned.
    """
    sm_client = boto3.client('sagemaker', region_name=strRegionName)
    pipeline_client = boto3.client('codepipeline', region_name=strRegionName)

    response = sm_client.list_projects(
        SortBy='CreationTime',
        SortOrder='Descending'
    )

    project_name = next(
        (
            summary['ProjectName']
            for summary in response['ProjectSummaryList']
            if summary['ProjectStatus'] == 'CreateCompleted'
        ),
        None,
    )
    if project_name is None:
        # Without this check the lookup below failed on an unbound variable.
        raise RuntimeError("No SageMaker project with status 'CreateCompleted' was found")

    des_response = sm_client.describe_project(ProjectName=project_name)

    code_pipeline_name = f"sagemaker-{des_response['ProjectName']}-{des_response['ProjectId']}-modelbuild"
    pipeline_client.start_pipeline_execution(name=code_pipeline_name)
    print("Start retraining ........")
 
def main():
    """Score captured endpoint predictions against the ground truth.

    Steps:
      1. Read every ``*.jsonl`` data-capture file under
         ``opt/ml/processing/input/inference_data``.
      2. For each record, decode the request audio to a wav file, decode the
         response transcript, and append a line to ``test_manifest.json``.
      3. Normalise references and predictions (lowercase, punctuation removed)
         and compute WER and CER.
      4. Start the retraining pipeline when WER exceeds the module-level
         ``tolerance``.

    Reads the module-level ``select_date`` and ``tolerance``.

    Raises:
        IndexError: There are more captured records than ground-truth entries.
    """
    # Sort so that records are paired with ground-truth entries in a
    # reproducible order; os.walk gives no ordering guarantee.
    # NOTE(review): presumably capture file names sort chronologically - confirm.
    output_list = sorted(find_files('opt/ml/processing/input/inference_data'))
    print(f"output_list: {output_list}")

    # NOTE(review): pickle.load can execute arbitrary code; only load a
    # ground-truth file produced by a trusted step.
    with open('opt/ml/processing/input/manifest/gt_manifest.pkl', 'rb') as f:
        gt_list = pickle.load(f)

    test_mount_dir = "opt/ml/input/data/testing/"
    manifest_path = f"opt/ml/processing/output/{select_date}/manifest"
    manifest_file = f"{manifest_path}/test_manifest.json"
    result_wav_file = f"opt/ml/processing/output/{select_date}/wav"

    pathlib.Path(manifest_path).mkdir(parents=True, exist_ok=True)
    pathlib.Path(result_wav_file).mkdir(parents=True, exist_ok=True)

    f_date = select_date.replace('/', '-')
    reference_list = []
    predicted_list = []

    seq = 0
    with open(manifest_file, 'w') as fout:
        for json_list in output_list:
            # Capture file name without directory or extension.
            fname = json_list.split('/')[-1].split('.')[0]

            with jsonlines.open(json_list) as read_file:
                for res in read_file.iter():
                    if seq >= len(gt_list):
                        raise IndexError(
                            f"Captured record #{seq} has no ground-truth entry "
                            f"(only {len(gt_list)} available)"
                        )
                    capture = res['captureData']

                    # Request payload: base64-encoded audio, re-saved as wav.
                    filename = f"{result_wav_file}/{f_date}-{fname}-{seq}.wav"
                    audio_bytes = base64.b64decode(capture['endpointInput']['data'])
                    sf_data, samplerate = sf.read(io.BytesIO(audio_bytes))
                    sf.write(file=filename, data=sf_data, samplerate=samplerate)

                    # Response payload: base64-encoded JSON with a "result" list.
                    np_val = json.loads(base64.b64decode(capture['endpointOutput']['data']))
                    transcript = ' '.join(np_val['result'])
                    predicted_list.append(transcript)
                    reference_list.append(gt_list[seq])

                    # Path the wav will have under the testing mount directory.
                    mounted_audio_path = os.path.join(test_mount_dir, os.path.basename(filename))
                    duration = sox.file_info.duration(filename)

                    # Write the metadata to the manifest
                    metadata = {"audio_filepath": mounted_audio_path, "duration": duration, "pred_text": transcript}
                    json.dump(metadata, fout)
                    fout.write('\n')
                    seq += 1

    # Printed once after the loop instead of once per record.
    print(f"predicted_list : {predicted_list}")
    print(f"reference_list : {reference_list}")

    if not predicted_list:
        # Nothing to score.
        # NOTE(review): NeMo's word_error_rate appears to return inf for empty
        # references, which would exceed any tolerance and trigger retraining
        # with no data - confirm.
        print("No captured inference records found; skipping evaluation.")
        return

    pc = PunctuationCapitalization('.,?')
    reference_list = pc.separate_punctuation(reference_list)
    reference_list = pc.do_lowercase(reference_list)
    predicted_list = pc.do_lowercase(predicted_list)
    reference_list = pc.rm_punctuation(reference_list)
    predicted_list = pc.rm_punctuation(predicted_list)

    # Compute both error rates; use_cer selects which one is compared to the tolerance.
    cer = word_error_rate(hypotheses=predicted_list, references=reference_list, use_cer=True)
    wer = word_error_rate(hypotheses=predicted_list, references=reference_list, use_cer=False)

    use_cer = False

    if use_cer:
        metric_name = 'CER'
        metric_value = cer
    else:
        metric_name = 'WER'
        metric_value = wer

    print(f" tolerance : {tolerance}")
    print(f" tolerance : {type(tolerance)}")
    print(f" metric_value : {metric_value}")
    print(f" metric_value : {type(metric_value)}")

    if tolerance is not None:
        if metric_value > tolerance:
            print(f"Got {metric_name} of {metric_value}, which was higher than tolerance={tolerance}")
            start_retraining_codepipeline()

        print(f'Got {metric_name} of {metric_value}. Tolerance was {tolerance}')
    else:
        print(f'Got {metric_name} of {metric_value}')

    print(f'Dataset WER/CER {round(100 * wer, 2)}%/{round(100 * cer, 2)}%')


if __name__ == '__main__':
    main()

In [None]:
# Display the S3 destination used by the upload in the next cell.
output_s3uri

In [None]:
# Upload the local processing input directory to the "pred-output" S3 prefix.
# NOTE(review): main() writes its manifest and wav files under
# opt/ml/processing/output, but this uploads opt/ml/processing/input - confirm
# which directory is intended.
!aws s3 sync opt/ml/processing/input $output_s3uri

In [None]:
## Retraining
# Record the "<prefix>-RETRAIN" flag in the Parameter Store, replacing any
# existing value. put_params converts the value with str(), so the stored
# text is "True".
pm.put_params(key="-".join([prefix, "RETRAIN"]), value=True, overwrite=True)
