"""DynamoDB access helpers for join job records."""

import logging

from boto3.dynamodb.conditions import Key

from orchestrator.exceptions.ddb_client_exceptions import RecordAlreadyExistsException

logger = logging.getLogger(__name__)


class JoinDbClient:
    """Thin wrapper around a DynamoDB table handle that stores join job records.

    Every item is addressed by the key pair ``experiment_id`` /
    ``join_job_id``; all reads are issued with ``ConsistentRead=True``.
    """

    def __init__(self, table_session):
        """Store the table handle used for every request.

        Args:
            table_session: Object exposing ``query``, ``put_item``,
                ``update_item`` and ``batch_writer`` -- presumably a boto3
                DynamoDB ``Table`` resource (TODO confirm against callers).
        """
        self.table_session = table_session

    def check_join_job_record_exists(self, experiment_id, join_job_id):
        """Return True when a record exists for the given key pair."""
        return self.get_join_job_record(experiment_id, join_job_id) is not None

    def get_join_job_record(self, experiment_id, join_job_id):
        """Fetch a single join job record.

        Returns:
            The first matching item, or None when the query matches nothing.
        """
        response = self.table_session.query(
            ConsistentRead=True,
            KeyConditionExpression=Key("experiment_id").eq(experiment_id)
            & Key("join_job_id").eq(join_job_id),
        )
        items = response["Items"]
        return items[0] if items else None

    def create_new_join_job_record(self, record):
        """Insert ``record``, refusing to overwrite an existing one.

        Raises:
            RecordAlreadyExistsException: When the conditional write fails
                because an item with this key is already present.
        """
        try:
            self.table_session.put_item(
                Item=record, ConditionExpression="attribute_not_exists(join_job_id)"
            )
        except Exception as e:
            # The failure is recognised by its name in the error text, so no
            # extra dependency is needed to identify the exception type.
            if "ConditionalCheckFailedException" in str(e):
                raise RecordAlreadyExistsException() from e
            raise

    def update_join_job_record(self, record):
        """Write ``record`` unconditionally, replacing any item with the same key."""
        self.table_session.put_item(Item=record)

    def get_all_join_job_records_of_experiment(self, experiment_id):
        """Return every join job record stored for ``experiment_id``.

        The query is repeated while the response carries a
        ``LastEvaluatedKey`` so that results spanning several pages are not
        silently truncated.

        Returns:
            A non-empty list of items, or None when there are no records.
        """
        query_kwargs = {
            "ConsistentRead": True,
            "KeyConditionExpression": Key("experiment_id").eq(experiment_id),
        }
        records = []
        while True:
            response = self.table_session.query(**query_kwargs)
            records.extend(response["Items"])
            last_key = response.get("LastEvaluatedKey")
            if not last_key:
                break
            # Resume the next page right after the last key already read.
            query_kwargs["ExclusiveStartKey"] = last_key
        # Callers receive None (not an empty list) when nothing is stored.
        return records or None

    def batch_delete_items(self, experiment_id, join_job_id_list):
        """Delete the listed join job records of one experiment in a batch."""
        logger.warning("Deleting join job records of experiment...")
        with self.table_session.batch_writer() as batch:
            for join_job_id in join_job_id_list:
                logger.debug("Deleting join job record %s...", join_job_id)
                batch.delete_item(
                    Key={"experiment_id": experiment_id, "join_job_id": join_job_id}
                )

    def _update_attribute(self, experiment_id, join_job_id, attribute_name, value):
        """Set one attribute on the record identified by the key pair.

        ``attribute_name`` is interpolated into the update expression, so it
        must only ever be one of the fixed names used by the public
        ``update_join_job_*`` methods, never caller-supplied text.
        """
        self.table_session.update_item(
            Key={"experiment_id": experiment_id, "join_job_id": join_job_id},
            UpdateExpression=f"SET {attribute_name} = :val",
            ExpressionAttributeValues={":val": value},
        )

    def update_join_job_current_state(self, experiment_id, join_job_id, current_state):
        """Set the ``current_state`` attribute of a join job record."""
        self._update_attribute(experiment_id, join_job_id, "current_state", current_state)

    def update_join_job_input_obs_data_s3_path(
        self, experiment_id, join_job_id, input_obs_data_s3_path
    ):
        """Set the ``input_obs_data_s3_path`` attribute of a join job record."""
        self._update_attribute(
            experiment_id, join_job_id, "input_obs_data_s3_path", input_obs_data_s3_path
        )

    def update_join_job_input_reward_data_s3_path(
        self, experiment_id, join_job_id, input_reward_data_s3_path
    ):
        """Set the ``input_reward_data_s3_path`` attribute of a join job record."""
        self._update_attribute(
            experiment_id,
            join_job_id,
            "input_reward_data_s3_path",
            input_reward_data_s3_path,
        )

    def update_join_job_join_query_ids(self, experiment_id, join_job_id, join_query_ids):
        """Set the ``join_query_ids`` attribute of a join job record."""
        self._update_attribute(experiment_id, join_job_id, "join_query_ids", join_query_ids)

    def update_join_job_obs_end_time(self, experiment_id, join_job_id, obs_end_time):
        """Set the ``obs_end_time`` attribute of a join job record."""
        self._update_attribute(experiment_id, join_job_id, "obs_end_time", obs_end_time)

    def update_join_job_obs_start_time(self, experiment_id, join_job_id, obs_start_time):
        """Set the ``obs_start_time`` attribute of a join job record."""
        self._update_attribute(experiment_id, join_job_id, "obs_start_time", obs_start_time)

    def update_join_job_output_joined_eval_data_s3_path(
        self, experiment_id, join_job_id, output_joined_eval_data_s3_path
    ):
        """Set the ``output_joined_eval_data_s3_path`` attribute of a join job record."""
        self._update_attribute(
            experiment_id,
            join_job_id,
            "output_joined_eval_data_s3_path",
            output_joined_eval_data_s3_path,
        )

    def update_join_job_output_joined_train_data_s3_path(
        self, experiment_id, join_job_id, output_joined_train_data_s3_path
    ):
        """Set the ``output_joined_train_data_s3_path`` attribute of a join job record."""
        self._update_attribute(
            experiment_id,
            join_job_id,
            "output_joined_train_data_s3_path",
            output_joined_train_data_s3_path,
        )