# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

from collections import defaultdict
from collections import namedtuple
from collections import OrderedDict

# Pair of positions recorded for each column name:
#   column_index  - position in the full header (target column included)
#   feature_index - position once the target column is removed; None for the
#                   target column itself
Indices = namedtuple("Indices", field_names=("column_index", "feature_index"))


class Header:
    """A utility class to manage the header and target column.

    The header contains the names for all columns in a dataset including the
    target column. This class validates the header, checking for presence of
    duplicate column names and absence of target column name.

    This class provides functionality to translate the column names to column
    indices (data set including target column) and feature indices (data set
    excluding target column) respectively.

    This class is used in the code generated by the SageMaker Pipeline
    Recommender algorithm.

    Usage
    ------
    >>> h = Header(column_names=['a', 'b', 'c'], target_column_name='b')
    >>> h.as_column_indices(['a', 'c'])
    [0, 2]

    >>> h.as_feature_indices(['a', 'c'])
    [0, 1]

    >>> h.target_column_name
    'b'

    >>> h.target_column_index
    1

    >>> h.as_column_indices(['b'])
    [1]
    """

    def __init__(self, column_names: list, target_column_name: str):
        """
        Parameters
        ----------
        column_names : iterable of the column names in the order of occurrence

        target_column_name : str, name of the target column

        Raises
        ------
        ValueError : target_column_name is not present in column_names or
            duplicate entries found in column_names
        """
        self.target_column_index = None
        self.target_column_name = target_column_name

        # maintaining a dict{column_name: Indices}, in header order
        self._column_name_indices = OrderedDict()

        # Becomes 1 once the target column has been passed, so that every
        # following column's feature index skips over the target position.
        feature_index_offset = 0

        # column_name -> positions of its repeated occurrences. The first
        # occurrence is stored in _column_name_indices and is not listed here.
        duplicate_column_indices = defaultdict(list)

        for i, column_name in enumerate(column_names):
            # already seen the column, add to duplicate_column_indices
            if column_name in self._column_name_indices:
                duplicate_column_indices[column_name].append(i)
            else:
                self._column_name_indices[column_name] = Indices(
                    column_index=i, feature_index=i - feature_index_offset
                )

            # if it's target column, setup target_index and adjust the feature
            # index offset for following features columns
            if column_name == target_column_name:
                self.target_column_index = i
                feature_index_offset = 1
                # The target column has no feature index.
                self._column_name_indices[column_name] = Indices(column_index=i, feature_index=None)

        if self.target_column_index is None:
            raise ValueError(
                "Specified target column '{target_column_name}' is "
                "not a valid column name.".format(target_column_name=target_column_name)
            )

        if duplicate_column_indices:
            raise ValueError(
                "Duplicate column names were found:\n{}".format(
                    "\n".join(
                        [
                            "{name} at index {index}".format(name=name, index=index)
                            for (name, index) in duplicate_column_indices.items()
                        ]
                    )
                )
            )

    def as_feature_indices(self, column_names: list) -> list:
        """Returns list of feature indices for the given column names.

        Parameters
        ----------
        column_names : iterable containing feature names

        Returns
        -------
        feature_indices : list containing the indices corresponding to
            column_names, assuming target column excluded.

        Raises
        ------
        ValueError : At least one of the items in column_names is not a
            feature name (it is either unknown or the target column).
        """

        def _index(name):
            if self.target_column_name == name:
                raise ValueError(
                    "'{}' is the target column name. "
                    "It cannot be converted to feature index.".format(name)
                )
            try:
                return self._column_name_indices[name].feature_index
            except KeyError:
                # The KeyError is only an internal lookup detail; suppress it
                # so the traceback shows just the ValueError.
                raise ValueError("'{}' is an unknown feature name".format(name)) from None

        return [_index(name) for name in column_names]

    def as_column_indices(self, column_names: list) -> list:
        """Returns list of indices for the given column names.

        Parameters
        ----------
        column_names : iterable containing column names

        Returns
        -------
        column_indices : list containing the indices corresponding to column
            names, assuming target column is included in the data.

        Raises
        ------
        ValueError : Unknown column name is found in column_names
        """

        def _index(name):
            try:
                return self._column_name_indices[name].column_index
            except KeyError:
                # The KeyError is only an internal lookup detail; suppress it
                # so the traceback shows just the ValueError.
                raise ValueError("'{}' is an unknown column name.".format(name)) from None

        return [_index(name) for name in column_names]

    @property
    def feature_column_indices(self) -> list:
        """Returns list of feature column indices in the order in which they
        were provided.

        The order of the indices is determined by the ``column_names``
        parameter. The target column is skipped.

        Returns
        -------
        feature_column_indices : list of int
        """
        return [
            index_instance.column_index
            for index_instance in self._column_name_indices.values()
            if index_instance.feature_index is not None
        ]

    @property
    def num_columns(self) -> int:
        """Returns number of columns including target column.

        Returns
        -------
        num_columns : integer, Number of columns.
        """
        return len(self._column_name_indices)

    @property
    def num_features(self) -> int:
        """Returns number of features, i.e. the number of columns excluding
        target column.

        Returns
        -------
        num_features : integer, Number of features.
        """
        # The constructor raises unless the target column is present, so
        # exactly one stored column is not a feature.
        return len(self._column_name_indices) - 1