Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 131 additions & 0 deletions docs/sources/user_guide/utils/serialization.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"id": "7ca989ce",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Functions imported correctly!\n"
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import sys\n",
"import os\n",
"\n",
"path = os.path.abspath(os.path.join(os.getcwd(), \"..\", \"..\", \"..\", \"..\"))\n",
"if path not in sys.path:\n",
" sys.path.insert(0, path)\n",
"\n",
"from mlxtend.utils import save_model_to_json, load_model_from_json\n",
"print(\"Functions imported correctly!\")"
]
},
{
"cell_type": "markdown",
"id": "8f29bed8",
"metadata": {},
"source": [
"# Serialization - JSON Utilities\n",
"\n",
"The `serialization` module contains utility functions to export and import models in a human-readable JSON format.\n",
"\n",
"> from mlxtend.utils import save_model_to_json\n",
"> from mlxtend.utils import load_model_from_json\n",
"\n",
"## Overview\n",
"\n",
"While Python's `pickle` is a common way to serialize models, it has security risks and versioning issues. The JSON serialization utilities in `mlxtend` provide a transparent alternative that:\n",
"\n",
"1. Saves model parameters and fitted attributes in a readable text format.\n",
"2. Handles NumPy arrays and specialized types automatically.\n",
"3. Allows model reconstruction without manual class instantiation."
]
},
{
"cell_type": "markdown",
"id": "bdff3eac",
"metadata": {},
"source": [
"## Example 1 - Saving and Loading a Perceptron\n",
"\n",
"This example demonstrates training a simple `Perceptron` classifier, saving it to a JSON file, and then loading it back to verify that the state is preserved."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "182144de",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Original model weights: [[0.2 0.2]]\n",
"\n",
"[Model saved to model_data.json]\n",
"Loaded model weights: [[0.2 0.2]]\n",
"\n",
" Predictions and weights are identical!\n"
]
}
],
"source": [
"from sklearn.linear_model import Perceptron\n",
"import numpy as np\n",
"import os\n",
"\n",
"X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])\n",
"y = np.array([0, 0, 0, 1])\n",
"\n",
"ppn = Perceptron(eta0=0.1, random_state=1)\n",
"ppn.fit(X, y)\n",
"\n",
"print(\"Original model weights:\", ppn.coef_)\n",
"\n",
"save_model_to_json(ppn, 'model_data.json')\n",
"print(\"\\n[Model saved to model_data.json]\")\n",
"\n",
"ppn_loaded = load_model_from_json('model_data.json')\n",
"\n",
"print(\"Loaded model weights: \", ppn_loaded.coef_)\n",
"\n",
"if np.array_equal(ppn.coef_, ppn_loaded.coef_):\n",
" print(\"\\n Predictions and weights are identical!\")\n",
"\n",
"if os.path.exists('model_data.json'):\n",
" os.remove('model_data.json')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
11 changes: 10 additions & 1 deletion mlxtend/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,17 @@
#
# License: BSD 3 clause


from .checking import check_Xy, format_kwarg_dictionaries
from .counter import Counter
from .serialization import load_model_from_json, save_model_to_json
from .testing import assert_raises

__all__ = ["Counter", "assert_raises", "check_Xy", "format_kwarg_dictionaries"]
__all__ = [
"Counter",
"assert_raises",
"check_Xy",
"format_kwarg_dictionaries",
"save_model_to_json",
"load_model_from_json",
]
49 changes: 49 additions & 0 deletions mlxtend/utils/serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import importlib
import json

import numpy as np


class MlxtendEncoder(json.JSONEncoder):
"""Custom JSON encoder to handle numpy types and fallback to strings."""

def default(self, obj):
if isinstance(obj, np.ndarray):
return obj.tolist()
if isinstance(obj, (np.integer, np.int64, np.int32)):
return int(obj)
if isinstance(obj, (np.floating, np.float64, np.float32)):
return float(obj)
try:
return super(MlxtendEncoder, self).default(obj)
except TypeError:
return str(obj)


def save_model_to_json(model, filename):
"""Save an mlxtend estimator to a JSON file."""
model_data = model.__dict__.copy()
model_data["__module__"] = model.__class__.__module__
model_data["__class__"] = model.__class__.__name__

with open(filename, "w") as f:
json.dump(model_data, f, cls=MlxtendEncoder, indent=4)


def load_model_from_json(filename):
"""Load an mlxtend estimator from a JSON file."""
with open(filename, "r") as f:
model_data = json.load(f)
module_name = model_data.pop("__module__")
class_name = model_data.pop("__class__")

module = importlib.import_module(module_name)
class_ = getattr(module, class_name)
model_instance = class_()

for key, value in model_data.items():
if isinstance(value, list) and key.endswith("_"):
setattr(model_instance, key, np.array(value))
else:
setattr(model_instance, key, value)
return model_instance
42 changes: 42 additions & 0 deletions mlxtend/utils/tests/test_serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import os

import numpy as np
import pytest

from mlxtend.classifier import Perceptron
from mlxtend.utils.serialization import load_model_from_json, save_model_to_json


def test_serialization_perceptron():
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
y = np.array([0, 1, 1, 1])

ppn = Perceptron(epochs=5, eta=0.1)
ppn.fit(X, y)

filename = "temp_ppn_model.json"

try:
save_model_to_json(ppn, filename)
assert os.path.exists(filename)

ppn_loaded = load_model_from_json(filename)

assert ppn.__class__ == ppn_loaded.__class__
np.testing.assert_array_almost_equal(ppn.w_, ppn_loaded.w_)
orig_pred = ppn.predict(X)
load_pred = ppn_loaded.predict(X)
np.testing.assert_array_equal(orig_pred, load_pred)
finally:
if os.path.exists(filename):
os.remove(filename)


def test_encoder_fallback():
import json

from mlxtend.utils.serialization import MlxtendEncoder

data = {"complex_obj": open}
encoded = json.dumps(data, cls=MlxtendEncoder)
assert "built-in function open" in encoded
Loading