diff --git a/dataset.py b/dataset.py index a9243c0..95b745b 100644 --- a/dataset.py +++ b/dataset.py @@ -5,6 +5,7 @@ import numpy as np import json import argparse +from sklearn import preprocessing _DATA_DIR = './data' _TRAIN = 'train.csv' @@ -85,6 +86,9 @@ def preprocess(self, do_val_split=False): # Preprocessing operations go here. df['log_sum_revenue'] = self._make_log_sum_revenue() + tmp_geoNetwork = self._convert_geoNetwork_domain() + geoNetwork_columns = tmp_geoNetwork.columns + df[geoNetwork_columns] = tmp_geoNetwork[geoNetwork_columns] return df @@ -107,15 +111,40 @@ def _make_log_sum_revenue(self): train_revenue_log_sum = (train_revenue_sum + 1).apply(np.log) return train_revenue_log_sum - - + def _make_json_converter(self, column_name): """Helper function to interpret columns in PANDAS.""" return lambda x: {column_name: json.loads(x)} - - + def _convert_geoNetwork_domain(self): + """Ont hot encode domain, location, region, subContinent in geoNetwork. + Missing value automatically imputed by one hot encoder. + Standardize using normal distribution. + Group by fullVisitorID. + + Returns: + A DataFrame containing preprocessed geoNetwork Data with one hot encoding. + """ + train_df = self.train.copy(deep=False) + train_df.set_index('fullVisitorId', inplace=True) + to_encode = ['networkDomain', 'networkLocation', 'region', 'subContinent'] + results = pd.DataFrame(index=train_df.index.copy()) + for index, row in train_df.iterrows(): + for item in to_encode: + individual_key = 'geoNetwork.' + item + row[individual_key] = row['geoNetwork']['geoNetwork'][item] + for item in to_encode: + individual_key = 'geoNetwork.' + item + encoded = pd.get_dummies(train_df[individual_key], prefix=individual_key) + results = pd.concat([results, encoded], axis=1) + columns = results.columns + scaler = preprocessing.StandardScaler() + results[columns] = scaler.fit_transform(results[columns]) + results = results.groupby(results.index).agg('mean') + results[columns] = preprocessing.normalize(results[columns].values.astype(float), norm='l2') + return results + if __name__ == '__main__': parser = argparse.ArgumentParser(