Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions docs/sources/user_guide/plotting/plot_decision_regions_3d.ipynb

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions mlxtend/plotting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
#
# License: BSD 3 clause


from .checkerboard import checkerboard_plot
from .decision_regions import plot_decision_regions
from .decision_regions_3d import plot_decision_regions_3d
from .ecdf import ecdf
from .enrichment_plot import enrichment_plot
from .heatmap import heatmap
Expand All @@ -23,6 +25,7 @@
__all__ = [
"plot_learning_curves",
"plot_decision_regions",
"plot_decision_regions_3d",
"plot_confusion_matrix",
"plot_sequential_feature_selection",
"plot_linear_regression",
Expand Down
54 changes: 54 additions & 0 deletions mlxtend/plotting/decision_regions_3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import matplotlib.pyplot as plt
import numpy as np


def plot_decision_regions_3d(
X,
y,
clf,
z_slices,
feature_index=(0, 1, 2),
ax=None,
res=0.02,
scatter_points=True,
alpha=0.3,
):
"""
Stack 2D decision regions in a 3D space.
"""
if ax is None:
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection="3d")
colors = ("red", "blue", "lightgreen", "gray", "cyan")
markers = ("s", "x", "o", "^", "v")

x_min, x_max = X[:, feature_index[0]].min() - 1, X[:, feature_index[0]].max() + 1
y_min, y_max = X[:, feature_index[1]].min() - 1, X[:, feature_index[1]].max() + 1

xx, yy = np.meshgrid(np.arange(x_min, x_max, res), np.arange(y_min, y_max, res))

for z_val in z_slices:
n_points = xx.ravel().shape[0]
grid_points = np.c_[xx.ravel(), yy.ravel(), np.full(n_points, z_val)]

Z = clf.predict(grid_points)
Z = Z.reshape(xx.shape)

ax.contourf(xx, yy, Z, zdir="z", offset=z_val, alpha=alpha, cmap="RdYlBu")
if scatter_points:
for idx, cl in enumerate(np.unique(y)):
ax.scatter(
X[y == cl, feature_index[0]],
X[y == cl, feature_index[1]],
X[y == cl, feature_index[2]],
alpha=0.8,
c=colors[idx % len(colors)],
marker=markers[idx % len(markers)],
label=f"Class {cl}",
)
ax.set_xlabel(f"Feature {feature_index[0]}")
ax.set_ylabel(f"Feature {feature_index[1]}")
ax.set_zlabel(f"Feature {feature_index[2]}")
ax.legend(loc="upper left")

return ax
18 changes: 18 additions & 0 deletions mlxtend/plotting/tests/test_decision_regions_3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import matplotlib.pyplot as plt
import numpy as np
import pytest
from sklearn.svm import SVC

from mlxtend.plotting import plot_decision_regions_3d


def test_plot_decision_regions_3d():
X = np.array([[1, 2, 3], [4, 5, 6], [1, 1, 1], [5, 5, 5]])
y = np.array([0, 1, 0, 1])
clf = SVC().fit(X, y)

try:
plot_decision_regions_3d(X, y, clf, z_slices=[1, 3, 5])
plt.close()
except Exception as e:
pytest.fail(f"3D plotting failed: {e}")
Loading