# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the 'License'). You # may not use this file except in compliance with the License. A copy of # the License is located at # # http://aws.amazon.com/apache2.0/ # # or in the 'license' file accompanying this file. This file is # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. import cgi import csv import logging import os import shutil from typing import List, Union import mlio import numpy as np import pandas as pd import pyarrow.parquet as pq import xgboost as xgb from mlio.integ.arrow import as_arrow_file from mlio.integ.numpy import as_numpy from mlio.integ.scipy import to_coo_matrix from sagemaker_containers import _content_types from scipy.sparse import vstack as scipy_vstack from sagemaker_algorithm_toolkit import exceptions as exc from sagemaker_xgboost_container.constants import xgb_content_types BATCH_SIZE = 4000 CSV = "csv" LIBSVM = "libsvm" PARQUET = "parquet" RECORDIO_PROTOBUF = "recordio-protobuf" MAX_FOLDER_DEPTH = 3 VALID_CONTENT_TYPES = [ CSV, LIBSVM, PARQUET, RECORDIO_PROTOBUF, _content_types.CSV, xgb_content_types.LIBSVM, xgb_content_types.X_LIBSVM, xgb_content_types.X_PARQUET, xgb_content_types.X_RECORDIO_PROTOBUF, ] VALID_PIPED_CONTENT_TYPES = [ CSV, PARQUET, RECORDIO_PROTOBUF, _content_types.CSV, xgb_content_types.X_PARQUET, xgb_content_types.X_RECORDIO_PROTOBUF, ] INVALID_CONTENT_TYPE_ERROR = ( "{invalid_content_type} is not an accepted ContentType: " + ", ".join(["%s" % c for c in VALID_CONTENT_TYPES]) + "." ) INVALID_CONTENT_FORMAT_ERROR = ( "First line '{line_snippet}...' of file '{file_name}' is not " "'{content_type}' format. Please ensure the file is in '{content_type}' format." ) def _get_invalid_content_type_error_msg(invalid_content_type): return INVALID_CONTENT_TYPE_ERROR.format(invalid_content_type=invalid_content_type) def _get_invalid_libsvm_error_msg(line_snippet, file_name): return INVALID_CONTENT_FORMAT_ERROR.format(line_snippet=line_snippet, file_name=file_name, content_type="LIBSVM") def _get_invalid_csv_error_msg(line_snippet, file_name): return INVALID_CONTENT_FORMAT_ERROR.format(line_snippet=line_snippet, file_name=file_name, content_type="CSV") def get_content_type(content_type_cfg_val): """Get content type from data config. Assumes that training and validation data have the same content type. ['libsvm', 'text/libsvm ;charset=utf8', 'text/x-libsvm'] will return 'libsvm' ['csv', 'text/csv', 'text/csv; label_size=1'] will return 'csv' :param content_type_cfg_val :return: Parsed content type """ if content_type_cfg_val is None: return LIBSVM else: # cgi.parse_header extracts all arguments after ';' as key-value pairs # e.g. cgi.parse_header('text/csv;label_size=1;charset=utf8') returns # the tuple ('text/csv', {'label_size': '1', 'charset': 'utf8'}) content_type, params = cgi.parse_header(content_type_cfg_val.lower()) if content_type in [CSV, _content_types.CSV]: # CSV content type allows a label_size parameter # that should be 1 for XGBoost if params and "label_size" in params and params["label_size"] != "1": msg = ( "{} is not an accepted csv ContentType. " "Optional parameter label_size must be equal to 1".format(content_type_cfg_val) ) raise exc.UserError(msg) return CSV elif content_type in [LIBSVM, xgb_content_types.LIBSVM, xgb_content_types.X_LIBSVM]: return LIBSVM elif content_type in [PARQUET, xgb_content_types.X_PARQUET]: return PARQUET elif content_type in [RECORDIO_PROTOBUF, xgb_content_types.X_RECORDIO_PROTOBUF]: return RECORDIO_PROTOBUF else: raise exc.UserError(_get_invalid_content_type_error_msg(content_type_cfg_val)) def _is_data_file(file_path, file_name): """Return true if file name is a valid data file name. A file is valid if: * File name does not start with '.' or '_'. * File is not a XGBoost cache file. :param file_path: :param file_name: :return: bool """ if not os.path.isfile(os.path.join(file_path, file_name)): return False if file_name.startswith(".") or file_name.startswith("_"): return False # avoid XGB cache file if ".cache" in file_name: if "dtrain" in file_name or "dval" in file_name: return False return True def _get_csv_delimiter(sample_csv_line): try: delimiter = csv.Sniffer().sniff(sample_csv_line).delimiter logging.info("Determined delimiter of CSV input is '{}'".format(delimiter)) except Exception as e: raise exc.UserError("Could not determine delimiter on line {}:\n{}".format(sample_csv_line[:50], e)) return delimiter def _get_num_valid_libsvm_features(libsvm_line): """Get number of valid LIBSVM features. XGBoost expects the following LIBSVM format: