diff --git a/mlxtend/utils/checking.py b/mlxtend/utils/checking.py index 38ef1e635..35ba53cd7 100644 --- a/mlxtend/utils/checking.py +++ b/mlxtend/utils/checking.py @@ -16,7 +16,7 @@ def check_Xy(X, y, y_int=True): if not isinstance(y, np.ndarray): raise ValueError("y must be a NumPy array. Found %s" % type(y)) - if "int" not in str(y.dtype): + if y_int and "int" not in str(y.dtype): raise ValueError( "y must be an integer array. Found %s. " "Try passing the array as y.astype(np.int_)" % y.dtype