# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the 'License'). You # may not use this file except in compliance with the License. A copy of # the License is located at # # http://aws.amazon.com/apache2.0/ # # or in the 'license' file accompanying this file. This file is # distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. from sagemaker_algorithm_toolkit import exceptions as exc from sagemaker_algorithm_toolkit import hyperparameter_validation as hpv from sagemaker_xgboost_container.constants.xgb_constants import ( XGB_MAXIMIZE_METRICS, XGB_MINIMIZE_METRICS, ) def initialize(metrics): @hpv.range_validator(["auto", "exact", "approx", "hist", "gpu_hist"]) def tree_method_range_validator(CATEGORIES, value): return value in CATEGORIES @hpv.dependencies_validator(["booster", "process_type"]) def updater_validator(value, dependencies): valid_tree_plugins = [ "grow_colmaker", "distcol", "grow_histmaker", "grow_skmaker", "sync", "refresh", "prune", "grow_quantile_histmaker", ] valid_tree_build_plugins = [ "grow_colmaker", "distcol", "grow_histmaker", "grow_colmaker", "grow_quantile_histmaker", ] valid_linear_plugins = ["shotgun", "coord_descent"] valid_process_update_plugins = ["refresh", "prune"] if dependencies.get("booster") == "gblinear": # validate only one linear updater is selected if not (len(value) == 1 and value[0] in valid_linear_plugins): raise exc.UserError( "Linear updater should be one of these options: {}.".format( ", ".join("'{0}'".format(valid_updater for valid_updater in valid_linear_plugins)) ) ) elif dependencies.get("process_type") == "update": if not all(x in valid_process_update_plugins for x in value): raise exc.UserError("process_type 'update' can only be used with updater 'refresh' and 'prune'") else: if not all(x in valid_tree_plugins for x in value): raise exc.UserError( "Tree updater should be selected from these options: 'grow_colmaker', 'distcol', 'grow_histmaker', " "'grow_skmaker', 'grow_quantile_histmaker', 'sync', 'refresh', 'prune', " "'shortgun', 'coord_descent'." ) # validate only one tree updater is selected counter = 0 for tmp in value: if tmp in valid_tree_build_plugins: counter += 1 if counter > 1: raise exc.UserError( "Only one tree grow plugin can be selected. Choose one from the" "following: 'grow_colmaker', 'distcol', 'grow_histmaker', " "'grow_skmaker'" ) @hpv.range_validator(["auto", "cpu_predictor", "gpu_predictor"]) def predictor_validator(CATEGORIES, value): return value in CATEGORIES @hpv.dependencies_validator(["num_class"]) def objective_validator(value, dependencies): num_class = dependencies.get("num_class") if value in ("multi:softmax", "multi:softprob") and num_class is None: raise exc.UserError("Require input for parameter 'num_class' for multi-classification") if value is None and num_class is not None: raise exc.UserError( "Do not need to setup parameter 'num_class' for learning task other than " "multi-classification." ) @hpv.range_validator(XGB_MAXIMIZE_METRICS + XGB_MINIMIZE_METRICS) def eval_metric_range_validator(SUPPORTED_METRIC, metric): if "