Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions mlxtend/evaluate/counterfactual.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#
# License: BSD 3 clause


import warnings

import numpy as np
Expand All @@ -19,6 +20,7 @@ def create_counterfactual(
y_desired_proba=None,
lammbda=0.1,
random_seed=None,
feature_names_to_vary=None,
):
"""
Implementation of the counterfactual method by Wachter et al. 2017
Expand Down Expand Up @@ -79,19 +81,21 @@ class probability for `y_desired`.
)
else:
use_proba = False

if y_desired_proba is None:
# class label

y_to_be_annealed_to = y_desired
else:
# class proba corresponding to class label y_desired
y_to_be_annealed_to = y_desired_proba

y_to_be_annealed_to = y_desired_proba
# start with random counterfactual

rng = np.random.RandomState(random_seed)
x_counterfact = X_dataset[rng.randint(X_dataset.shape[0])]

# compute median absolute deviation

mad = np.abs(np.median(X_dataset, axis=0) - x_reference)

def dist(x_reference, x_counterfact):
Expand All @@ -105,7 +109,6 @@ def loss(x_counterfact, lammbda):
]
else:
y_predict = model.predict(x_counterfact.reshape(1, -1))

diff = lammbda * (y_predict - y_to_be_annealed_to) ** 2

return diff + dist(x_reference, x_counterfact)
Expand All @@ -114,7 +117,6 @@ def loss(x_counterfact, lammbda):

if not res["success"]:
warnings.warn(res["message"])

x_counterfact = res["x"]

return x_counterfact
5 changes: 2 additions & 3 deletions mlxtend/evaluate/feature_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,8 @@ def score_func(y_true, y_pred):
save_col = X[:, feat].copy()

if save_col.ndim > 1:
columns = save_col.shape[1]
for i in range(columns):
rng.shuffle(X[:, i])
shuffled_indices = rng.permutation(X.shape[0])
X[:, feat] = X[shuffled_indices][:, feat]
else:
rng.shuffle(X[:, feat])

Expand Down
65 changes: 56 additions & 9 deletions mlxtend/frequent_patterns/association_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@

def association_rules(
df: pd.DataFrame,
num_itemsets: Optional[int] = 1,
num_itemsets: Optional[int] = None,
df_orig: Optional[pd.DataFrame] = None,
null_values=False,
metric="confidence",
Expand Down Expand Up @@ -120,15 +120,20 @@ def association_rules(
raise TypeError("If null values exist, df_orig must be provided.")

# if null values exist, num_itemsets must be provided
if null_values and num_itemsets == 1:
if null_values and df_orig is None and num_itemsets is None:
raise TypeError("If null values exist, num_itemsets must be provided.")

if num_itemsets is None:
if df_orig is not None:
num_itemsets = len(df_orig)
else:
num_itemsets = 1
# check for valid input
fpc.valid_input_check(df_orig, null_values)

if not df.shape[0]:
raise ValueError(
"The input DataFrame `df` containing " "the frequent itemsets is empty."
"The input DataFrame `df` containing the frequent itemsets is empty."
)

# check for mandatory columns
Expand Down Expand Up @@ -188,8 +193,22 @@ def certainty_metric_helper(sAC, sA, sC, disAC, disA, disC, dis_int, dis_int_):

# metrics for association rules
metric_dict = {
"antecedent support": lambda _, sA, ___, ____, _____, ______, _______, ________: sA,
"consequent support": lambda _, __, sC, ____, _____, ______, _______, ________: sC,
"antecedent support": lambda _,
sA,
___,
____,
_____,
______,
_______,
________: sA,
"consequent support": lambda _,
__,
sC,
____,
_____,
______,
_______,
________: sC,
"support": lambda sAC, _, __, ___, ____, _____, ______, _______: sAC,
"confidence": lambda sAC, sA, _, disAC, disA, __, dis_int, ___: (
sAC * (num_itemsets - disAC)
Expand All @@ -207,19 +226,47 @@ def certainty_metric_helper(sAC, sA, sC, disAC, disA, disC, dis_int, dis_int_):
"support"
](sAC, sA, sC, disAC, disA, disC, dis_int, dis_int_)
- sA * sC,
"conviction": lambda sAC, sA, sC, disAC, disA, disC, dis_int, dis_int_: conviction_helper(
"conviction": lambda sAC,
sA,
sC,
disAC,
disA,
disC,
dis_int,
dis_int_: conviction_helper(
metric_dict["confidence"](
sAC, sA, sC, disAC, disA, disC, dis_int, dis_int_
),
sC,
),
"zhangs_metric": lambda sAC, sA, sC, disAC, disA, disC, dis_int, dis_int_: zhangs_metric_helper(
"zhangs_metric": lambda sAC,
sA,
sC,
disAC,
disA,
disC,
dis_int,
dis_int_: zhangs_metric_helper(
sAC, sA, sC, disAC, disA, disC, dis_int, dis_int_
),
"jaccard": lambda sAC, sA, sC, _, __, ____, _____, ______: jaccard_metric_helper(
"jaccard": lambda sAC,
sA,
sC,
_,
__,
____,
_____,
______: jaccard_metric_helper(
sAC, sA, sC, disAC, disA, disC, dis_int, dis_int_
),
"certainty": lambda sAC, sA, sC, _, __, ____, _____, ______: certainty_metric_helper(
"certainty": lambda sAC,
sA,
sC,
_,
__,
____,
_____,
______: certainty_metric_helper(
sAC, sA, sC, disAC, disA, disC, dis_int, dis_int_
),
"kulczynski": lambda sAC, sA, sC, _, __, ____, _____, ______: kulczynski_helper(
Expand Down
Loading