From 2e0662ab140fd213fbe5b21dc13e93a5a8544f9b Mon Sep 17 00:00:00 2001 From: Henry Schreiner Date: Fri, 13 Jun 2025 02:10:49 -0400 Subject: [PATCH 01/10] feat: add NumPy scalars Co-authored-by: Steve R. Sun <1638650145@qq.com> Signed-off-by: Henry Schreiner --- docs/advanced/pycpp/numpy.rst | 40 +++++++ include/pybind11/numpy.h | 196 +++++++++++++++++++++++++++------- tests/CMakeLists.txt | 1 + tests/test_numpy_scalars.cpp | 52 +++++++++ tests/test_numpy_scalars.py | 62 +++++++++++ 5 files changed, 315 insertions(+), 36 deletions(-) create mode 100644 tests/test_numpy_scalars.cpp create mode 100644 tests/test_numpy_scalars.py diff --git a/docs/advanced/pycpp/numpy.rst b/docs/advanced/pycpp/numpy.rst index d09a2cea2c..29638eb821 100644 --- a/docs/advanced/pycpp/numpy.rst +++ b/docs/advanced/pycpp/numpy.rst @@ -232,6 +232,46 @@ prevent many types of unsupported structures, it is still the user's responsibility to use only "plain" structures that can be safely manipulated as raw memory without violating invariants. +Scalar types +============ + +In some cases we may want to accept or return NumPy scalar values such as +``np.float32`` or ``np.float64``. We hope to be able to handle single-precision +and double-precision on the C-side. However, both are bound to Python's +double-precision builtin float by default, so they cannot be processed separately. +We used the ``py::buffer`` trick to implement the previous approach, which +will cause the readability of the code to drop significantly. + +Luckily, there's a helper type for this occasion - ``py::numpy_scalar``: + +.. code-block:: cpp + + m.def("add", [](py::numpy_scalar a, py::numpy_scalar b) { + return py::make_scalar(a + b); + }); + m.def("add", [](py::numpy_scalar a, py::numpy_scalar b) { + return py::make_scalar(a + b); + }); + +This type is trivially convertible to and from the type it wraps; currently +supported scalar types are NumPy arithmetic types: ``bool_``, ``int8``, +``int16``, ``int32``, ``int64``, ``uint8``, ``uint16``, ``uint32``, +``uint64``, ``float32``, ``float64``, ``complex64``, ``complex128``, all of +them mapping to respective C++ counterparts. + +.. note:: + + This is a strict type, it will only allows to specify NumPy type as input + arguments, and does not allow other types of input parameters (e.g., + ``py::numpy_scalar`` will not accept Python's builtin ``int`` ). + +.. note:: + + Native C types are mapped to NumPy types in a platform specific way: for + instance, ``char`` may be mapped to either ``np.int8`` or ``np.uint8`` + and ``long`` may use 4 or 8 bytes depending on the platform. Unless you + clearly understand the difference and your needs, please use ````. + Vectorizing functions ===================== diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index f65b5c9d7a..1e971fb311 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -49,6 +49,9 @@ PYBIND11_WARNING_DISABLE_MSVC(4127) class dtype; // Forward declaration class array; // Forward declaration +template +struct numpy_scalar; // Forward declaration + PYBIND11_NAMESPACE_BEGIN(detail) template <> @@ -245,6 +248,21 @@ struct npy_api { NPY_UINT64_ = platform_lookup( NPY_ULONG_, NPY_ULONGLONG_, NPY_UINT_), + NPY_FLOAT32_ = platform_lookup( + NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_), + NPY_FLOAT64_ = platform_lookup( + NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_), + NPY_COMPLEX64_ + = platform_lookup, + std::complex, + std::complex, + std::complex>(NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_), + NPY_COMPLEX128_ + = platform_lookup, + std::complex, + std::complex, + std::complex>(NPY_DOUBLE_, NPY_FLOAT_, NPY_LONGDOUBLE_), + NPY_CHAR_ = std::is_signed::value ? NPY_BYTE_ : NPY_UBYTE_, }; unsigned int PyArray_RUNTIME_VERSION_; @@ -268,6 +286,7 @@ struct npy_api { unsigned int (*PyArray_GetNDArrayCFeatureVersion_)(); PyObject *(*PyArray_DescrFromType_)(int); + PyObject *(*PyArray_TypeObjectFromType_)(int); PyObject *(*PyArray_NewFromDescr_)(PyTypeObject *, PyObject *, int, @@ -284,6 +303,8 @@ struct npy_api { PyTypeObject *PyVoidArrType_Type_; PyTypeObject *PyArrayDescr_Type_; PyObject *(*PyArray_DescrFromScalar_)(PyObject *); + PyObject *(*PyArray_Scalar_)(void *, PyObject *, PyObject *); + void (*PyArray_ScalarAsCtype_)(PyObject *, void *); PyObject *(*PyArray_FromAny_)(PyObject *, PyObject *, int, int, int, PyObject *); int (*PyArray_DescrConverter_)(PyObject *, PyObject **); bool (*PyArray_EquivTypes_)(PyObject *, PyObject *); @@ -301,7 +322,10 @@ struct npy_api { API_PyArrayDescr_Type = 3, API_PyVoidArrType_Type = 39, API_PyArray_DescrFromType = 45, + API_PyArray_TypeObjectFromType = 46, API_PyArray_DescrFromScalar = 57, + API_PyArray_Scalar = 60, + API_PyArray_ScalarAsCtype = 62, API_PyArray_FromAny = 69, API_PyArray_Resize = 80, // CopyInto was slot 82 and 50 was effectively an alias. NumPy 2 removed 82. @@ -336,7 +360,10 @@ struct npy_api { DECL_NPY_API(PyVoidArrType_Type); DECL_NPY_API(PyArrayDescr_Type); DECL_NPY_API(PyArray_DescrFromType); + DECL_NPY_API(PyArray_TypeObjectFromType); DECL_NPY_API(PyArray_DescrFromScalar); + DECL_NPY_API(PyArray_Scalar); + DECL_NPY_API(PyArray_ScalarAsCtype); DECL_NPY_API(PyArray_FromAny); DECL_NPY_API(PyArray_Resize); DECL_NPY_API(PyArray_CopyInto); @@ -355,6 +382,88 @@ struct npy_api { } }; +template +struct is_complex : std::false_type {}; +template +struct is_complex> : std::true_type {}; + +template +struct npy_format_descriptor_name; + +template +struct npy_format_descriptor_name::value>> { + static constexpr auto name = const_name::value>( + const_name("bool"), + const_name::value>("int", "uint") + const_name()); +}; + +template +struct npy_format_descriptor_name::value>> { + static constexpr auto name + = const_name < std::is_same::value + || std::is_same::value + > (const_name("float") + const_name(), const_name("longdouble")); +}; + +template +struct npy_format_descriptor_name::value>> { + static constexpr auto name + = const_name < std::is_same::value + || std::is_same::value + > (const_name("complex") + const_name(), + const_name("longcomplex")); +}; + +template +struct numpy_scalar_info {}; + +#define DECL_NPY_SCALAR(ctype_, typenum_) \ + template <> \ + struct numpy_scalar_info { \ + static constexpr auto name = npy_format_descriptor_name::name; \ + static constexpr int typenum = npy_api::typenum_##_; \ + } + +// boolean type +DECL_NPY_SCALAR(bool, NPY_BOOL); + +// character types +DECL_NPY_SCALAR(char, NPY_CHAR); +DECL_NPY_SCALAR(signed char, NPY_BYTE); +DECL_NPY_SCALAR(unsigned char, NPY_UBYTE); + +// signed integer types +DECL_NPY_SCALAR(std::int16_t, NPY_SHORT); +DECL_NPY_SCALAR(std::int32_t, NPY_INT); +DECL_NPY_SCALAR(std::int64_t, NPY_LONG); +#if defined(__linux__) +DECL_NPY_SCALAR(long long, NPY_LONG); +#else +DECL_NPY_SCALAR(long, NPY_LONG); +#endif + +// unsigned integer types +DECL_NPY_SCALAR(std::uint16_t, NPY_USHORT); +DECL_NPY_SCALAR(std::uint32_t, NPY_UINT); +DECL_NPY_SCALAR(std::uint64_t, NPY_ULONG); +#if defined(__linux__) +DECL_NPY_SCALAR(unsigned long long, NPY_ULONG); +#else +DECL_NPY_SCALAR(unsigned long, NPY_ULONG); +#endif + +// floating point types +DECL_NPY_SCALAR(float, NPY_FLOAT); +DECL_NPY_SCALAR(double, NPY_DOUBLE); +DECL_NPY_SCALAR(long double, NPY_LONGDOUBLE); + +// complex types +DECL_NPY_SCALAR(std::complex, NPY_CFLOAT); +DECL_NPY_SCALAR(std::complex, NPY_CDOUBLE); +DECL_NPY_SCALAR(std::complex, NPY_CLONGDOUBLE); + +#undef DECL_NPY_SCALAR + // This table normalizes typenums by mapping NPY_INT_, NPY_LONG, ... to NPY_INT32_, NPY_INT64, ... // This is needed to correctly handle situations where multiple typenums map to the same type, // e.g. NPY_LONG_ may be equivalent to NPY_INT_ or NPY_LONGLONG_ despite having a different @@ -453,10 +562,6 @@ template struct is_std_array : std::false_type {}; template struct is_std_array> : std::true_type {}; -template -struct is_complex : std::false_type {}; -template -struct is_complex> : std::true_type {}; template struct array_info_scalar { @@ -670,8 +775,59 @@ template struct type_caster> : type_caster> {}; +template +struct type_caster> { + using value_type = T; + using type_info = numpy_scalar_info; + + PYBIND11_TYPE_CASTER(numpy_scalar, type_info::name); + + static handle &target_type() { + static handle tp = npy_api::get().PyArray_TypeObjectFromType_(type_info::typenum); + return tp; + } + + static handle &target_dtype() { + static handle tp = npy_api::get().PyArray_DescrFromType_(type_info::typenum); + return tp; + } + + bool load(handle src, bool) { + if (isinstance(src, target_type())) { + npy_api::get().PyArray_ScalarAsCtype_(src.ptr(), &value.value); + return true; + } + return false; + } + + static handle cast(numpy_scalar src, return_value_policy, handle) { + return npy_api::get().PyArray_Scalar_(&src.value, target_dtype().ptr(), nullptr); + } +}; + PYBIND11_NAMESPACE_END(detail) +template +struct numpy_scalar { + using value_type = T; + + value_type value; + + numpy_scalar() = default; + numpy_scalar(value_type value) : value(value) {} + + operator value_type() { return value; } + numpy_scalar &operator=(value_type value) { + this->value = value; + return *this; + } +}; + +template +numpy_scalar make_scalar(T value) { + return numpy_scalar(value); +} + class dtype : public object { public: PYBIND11_OBJECT_DEFAULT(dtype, object, detail::npy_api::get().PyArrayDescr_Check_) @@ -1409,38 +1565,6 @@ struct compare_buffer_info::valu } }; -template -struct npy_format_descriptor_name; - -template -struct npy_format_descriptor_name::value>> { - static constexpr auto name = const_name::value>( - const_name("bool"), - const_name::value>("numpy.int", "numpy.uint") - + const_name()); -}; - -template -struct npy_format_descriptor_name::value>> { - static constexpr auto name = const_name < std::is_same::value - || std::is_same::value - || std::is_same::value - || std::is_same::value - > (const_name("numpy.float") + const_name(), - const_name("numpy.longdouble")); -}; - -template -struct npy_format_descriptor_name::value>> { - static constexpr auto name = const_name < std::is_same::value - || std::is_same::value - || std::is_same::value - || std::is_same::value - > (const_name("numpy.complex") - + const_name(), - const_name("numpy.longcomplex")); -}; - template struct npy_format_descriptor< T, diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 2cf18c3547..ebd3fff1c2 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -159,6 +159,7 @@ set(PYBIND11_TEST_FILES test_native_enum test_numpy_array test_numpy_dtypes + test_numpy_scalars test_numpy_vectorize test_opaque_types test_operator_overloading diff --git a/tests/test_numpy_scalars.cpp b/tests/test_numpy_scalars.cpp new file mode 100644 index 0000000000..046a9c07a9 --- /dev/null +++ b/tests/test_numpy_scalars.cpp @@ -0,0 +1,52 @@ +/* + tests/test_numpy_scalars.cpp -- strict NumPy scalars + + Copyright (c) 2021 Steve R. Sun + + All rights reserved. Use of this source code is governed by a + BSD-style license that can be found in the LICENSE file. +*/ + +#include + +#include "pybind11_tests.h" + +#include +#include + +namespace py = pybind11; + +template +struct add { + T x; + add(T x) : x(x) {} + T operator()(T y) const { return static_cast(x + y); } +}; + +template +void register_test(py::module &m, const char *name, F &&func) { + m.def((std::string("test_") + name).c_str(), + [=](py::numpy_scalar v) { + return std::make_tuple(name, py::make_scalar(static_cast(func(v.value)))); + }, + py::arg("x")); +} + +TEST_SUBMODULE(numpy_scalars, m) { + using cfloat = std::complex; + using cdouble = std::complex; + + register_test(m, "bool", [](bool x) { return !x; }); + register_test(m, "int8", add(-8)); + register_test(m, "int16", add(-16)); + register_test(m, "int32", add(-32)); + register_test(m, "int64", add(-64)); + register_test(m, "uint8", add(8)); + register_test(m, "uint16", add(16)); + register_test(m, "uint32", add(32)); + register_test(m, "uint64", add(64)); + register_test(m, "float32", add(0.125f)); + register_test(m, "float64", add(0.25f)); + register_test(m, "complex64", add({0, -0.125f})); + register_test(m, "complex128", add({0, -0.25f})); +} diff --git a/tests/test_numpy_scalars.py b/tests/test_numpy_scalars.py new file mode 100644 index 0000000000..52c2861a1c --- /dev/null +++ b/tests/test_numpy_scalars.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import sys + +import pytest + +from pybind11_tests import numpy_scalars as m + +np = pytest.importorskip("numpy") + +SCALAR_TYPES = { + np.bool_: False, + np.int8: -7, + np.int16: -15, + np.int32: -31, + np.int64: -63, + np.uint8: 9, + np.uint16: 17, + np.uint32: 33, + np.uint64: 65, + np.single: 1.125, + np.double: 1.25, + np.complex64: 1 - 0.125j, + np.complex128: 1 - 0.25j, +} +ALL_TYPES = [int, bool, float, bytes, str] + list(SCALAR_TYPES) + + +def type_name(tp): + try: + return tp.__name__.rstrip("_") + except BaseException: + # no numpy + return str(tp) + + +@pytest.fixture(scope="module", params=list(SCALAR_TYPES), ids=type_name) +def scalar_type(request): + return request.param + + +def expected_signature(tp): + s = "str" if sys.version_info[0] >= 3 else "unicode" + t = type_name(tp) + return f"test_{t}(x: {t}) -> tuple[{s}, {t}]\n" + + +def test_numpy_scalars(scalar_type): + expected = SCALAR_TYPES[scalar_type] + name = type_name(scalar_type) + func = getattr(m, "test_" + name) + assert func.__doc__ == expected_signature(scalar_type) + for tp in ALL_TYPES: + value = tp(1) + if tp is scalar_type: + result = func(value) + assert result[0] == name + assert isinstance(result[1], tp) + assert result[1] == tp(expected) + else: + with pytest.raises(TypeError): + func(value) From d435a02e55c0a4b624ee55a0731f8eb4ef9d1c68 Mon Sep 17 00:00:00 2001 From: Henry Schreiner Date: Fri, 13 Jun 2025 12:12:01 -0400 Subject: [PATCH 02/10] fix: fixes to make the tests pass --- include/pybind11/numpy.h | 27 ++++++++++++++++----------- tests/test_numpy_scalars.py | 2 +- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 1e971fb311..988d8cccac 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -393,25 +393,30 @@ struct npy_format_descriptor_name; template struct npy_format_descriptor_name::value>> { static constexpr auto name = const_name::value>( - const_name("bool"), - const_name::value>("int", "uint") + const_name()); + const_name("numpy.bool"), + const_name::value>("numpy.int", "numpy.uint") + + const_name()); }; template struct npy_format_descriptor_name::value>> { - static constexpr auto name - = const_name < std::is_same::value - || std::is_same::value - > (const_name("float") + const_name(), const_name("longdouble")); + static constexpr auto name = const_name < std::is_same::value + || std::is_same::value + || std::is_same::value + || std::is_same::value + > (const_name("numpy.float") + const_name(), + const_name("numpy.longdouble")); }; template struct npy_format_descriptor_name::value>> { - static constexpr auto name - = const_name < std::is_same::value - || std::is_same::value - > (const_name("complex") + const_name(), - const_name("longcomplex")); + static constexpr auto name = const_name < std::is_same::value + || std::is_same::value + || std::is_same::value + || std::is_same::value + > (const_name("numpy.complex") + + const_name(), + const_name("numpy.longcomplex")); }; template diff --git a/tests/test_numpy_scalars.py b/tests/test_numpy_scalars.py index 52c2861a1c..020465f2d1 100644 --- a/tests/test_numpy_scalars.py +++ b/tests/test_numpy_scalars.py @@ -42,7 +42,7 @@ def scalar_type(request): def expected_signature(tp): s = "str" if sys.version_info[0] >= 3 else "unicode" t = type_name(tp) - return f"test_{t}(x: {t}) -> tuple[{s}, {t}]\n" + return f"test_{t}(x: numpy.{t}) -> tuple[{s}, numpy.{t}]\n" def test_numpy_scalars(scalar_type): From e5e6522e5bfc4f0288e7d7d170cfdfa2683f0c22 Mon Sep 17 00:00:00 2001 From: Henry Schreiner Date: Fri, 13 Jun 2025 15:31:56 -0400 Subject: [PATCH 03/10] fix: use simpler definitions for ints --- include/pybind11/numpy.h | 26 ++++++++------------------ tests/test_numpy_scalars.cpp | 2 +- 2 files changed, 9 insertions(+), 19 deletions(-) diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 988d8cccac..cdfdcf85dc 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -438,24 +438,14 @@ DECL_NPY_SCALAR(signed char, NPY_BYTE); DECL_NPY_SCALAR(unsigned char, NPY_UBYTE); // signed integer types -DECL_NPY_SCALAR(std::int16_t, NPY_SHORT); -DECL_NPY_SCALAR(std::int32_t, NPY_INT); -DECL_NPY_SCALAR(std::int64_t, NPY_LONG); -#if defined(__linux__) -DECL_NPY_SCALAR(long long, NPY_LONG); -#else -DECL_NPY_SCALAR(long, NPY_LONG); -#endif +DECL_NPY_SCALAR(std::int16_t, NPY_INT16); +DECL_NPY_SCALAR(std::int32_t, NPY_INT32); +DECL_NPY_SCALAR(std::int64_t, NPY_INT64); // unsigned integer types -DECL_NPY_SCALAR(std::uint16_t, NPY_USHORT); -DECL_NPY_SCALAR(std::uint32_t, NPY_UINT); -DECL_NPY_SCALAR(std::uint64_t, NPY_ULONG); -#if defined(__linux__) -DECL_NPY_SCALAR(unsigned long long, NPY_ULONG); -#else -DECL_NPY_SCALAR(unsigned long, NPY_ULONG); -#endif +DECL_NPY_SCALAR(std::uint16_t, NPY_UINT16); +DECL_NPY_SCALAR(std::uint32_t, NPY_UINT32); +DECL_NPY_SCALAR(std::uint64_t, NPY_UINT64); // floating point types DECL_NPY_SCALAR(float, NPY_FLOAT); @@ -819,9 +809,9 @@ struct numpy_scalar { value_type value; numpy_scalar() = default; - numpy_scalar(value_type value) : value(value) {} + explicit numpy_scalar(value_type value) : value(value) {} - operator value_type() { return value; } + explicit operator value_type() { return value; } numpy_scalar &operator=(value_type value) { this->value = value; return *this; diff --git a/tests/test_numpy_scalars.cpp b/tests/test_numpy_scalars.cpp index 046a9c07a9..abedc0045f 100644 --- a/tests/test_numpy_scalars.cpp +++ b/tests/test_numpy_scalars.cpp @@ -19,7 +19,7 @@ namespace py = pybind11; template struct add { T x; - add(T x) : x(x) {} + explicit add(T x) : x(x) {} T operator()(T y) const { return static_cast(x + y); } }; From 8379d254784ccb8fcb6b7f920851b94b0715de55 Mon Sep 17 00:00:00 2001 From: Henry Schreiner Date: Tue, 17 Jun 2025 15:21:34 -0400 Subject: [PATCH 04/10] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- docs/advanced/pycpp/numpy.rst | 2 +- include/pybind11/numpy.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/advanced/pycpp/numpy.rst b/docs/advanced/pycpp/numpy.rst index 29638eb821..5293610a72 100644 --- a/docs/advanced/pycpp/numpy.rst +++ b/docs/advanced/pycpp/numpy.rst @@ -261,7 +261,7 @@ them mapping to respective C++ counterparts. .. note:: - This is a strict type, it will only allows to specify NumPy type as input + This is a strict type, it will only allow to specify NumPy type as input arguments, and does not allow other types of input parameters (e.g., ``py::numpy_scalar`` will not accept Python's builtin ``int`` ). diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index cdfdcf85dc..318aaaefc4 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -811,7 +811,7 @@ struct numpy_scalar { numpy_scalar() = default; explicit numpy_scalar(value_type value) : value(value) {} - explicit operator value_type() { return value; } + explicit operator value_type() const { return value; } numpy_scalar &operator=(value_type value) { this->value = value; return *this; From f928a5467bdd532873b666850ddf2675b186b6e9 Mon Sep 17 00:00:00 2001 From: "Ralf W. Grosse-Kunstleve" Date: Wed, 18 Jun 2025 14:06:34 -0700 Subject: [PATCH 05/10] Modernize test_numpy_scalars.py --- tests/test_numpy_scalars.py | 53 +++++++++++++------------------------ 1 file changed, 19 insertions(+), 34 deletions(-) diff --git a/tests/test_numpy_scalars.py b/tests/test_numpy_scalars.py index 020465f2d1..ca9e3a63b6 100644 --- a/tests/test_numpy_scalars.py +++ b/tests/test_numpy_scalars.py @@ -1,14 +1,12 @@ from __future__ import annotations -import sys - import pytest from pybind11_tests import numpy_scalars as m np = pytest.importorskip("numpy") -SCALAR_TYPES = { +NPY_SCALAR_TYPES = { np.bool_: False, np.int8: -7, np.int16: -15, @@ -23,40 +21,27 @@ np.complex64: 1 - 0.125j, np.complex128: 1 - 0.25j, } -ALL_TYPES = [int, bool, float, bytes, str] + list(SCALAR_TYPES) - - -def type_name(tp): - try: - return tp.__name__.rstrip("_") - except BaseException: - # no numpy - return str(tp) - - -@pytest.fixture(scope="module", params=list(SCALAR_TYPES), ids=type_name) -def scalar_type(request): - return request.param - -def expected_signature(tp): - s = "str" if sys.version_info[0] >= 3 else "unicode" - t = type_name(tp) - return f"test_{t}(x: numpy.{t}) -> tuple[{s}, numpy.{t}]\n" +ALL_SCALAR_TYPES = tuple(NPY_SCALAR_TYPES.keys()) + (int, bool, float, bytes, str) -def test_numpy_scalars(scalar_type): - expected = SCALAR_TYPES[scalar_type] - name = type_name(scalar_type) - func = getattr(m, "test_" + name) - assert func.__doc__ == expected_signature(scalar_type) - for tp in ALL_TYPES: +@pytest.mark.parametrize( + ("npy_scalar_type", "expected_value"), NPY_SCALAR_TYPES.items() +) +def test_numpy_scalars(npy_scalar_type, expected_value): + tpnm = npy_scalar_type.__name__.rstrip("_") + test_tpnm = getattr(m, "test_" + tpnm) + assert ( + test_tpnm.__doc__ + == f"test_{tpnm}(x: numpy.{tpnm}) -> tuple[str, numpy.{tpnm}]\n" + ) + for tp in ALL_SCALAR_TYPES: value = tp(1) - if tp is scalar_type: - result = func(value) - assert result[0] == name - assert isinstance(result[1], tp) - assert result[1] == tp(expected) + if tp is npy_scalar_type: + result_tpnm, result_value = test_tpnm(value) + assert result_tpnm == tpnm + assert isinstance(result_value, npy_scalar_type) + assert result_value == tp(expected_value) else: with pytest.raises(TypeError): - func(value) + test_tpnm(value) From 437f93117d2db46a3323a3e3a000d1486959170f Mon Sep 17 00:00:00 2001 From: "Ralf W. Grosse-Kunstleve" Date: Wed, 18 Jun 2025 14:08:22 -0700 Subject: [PATCH 06/10] Apply doc change suggested in review. --- docs/advanced/pycpp/numpy.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/advanced/pycpp/numpy.rst b/docs/advanced/pycpp/numpy.rst index 5293610a72..0c0447667a 100644 --- a/docs/advanced/pycpp/numpy.rst +++ b/docs/advanced/pycpp/numpy.rst @@ -261,9 +261,9 @@ them mapping to respective C++ counterparts. .. note:: - This is a strict type, it will only allow to specify NumPy type as input - arguments, and does not allow other types of input parameters (e.g., - ``py::numpy_scalar`` will not accept Python's builtin ``int`` ). + ``py::numpy_scalar`` strictly matches NumPy scalar types. For example, + ``py::numpy_scalar`` will accept ``np.int64(123)``, + but **not** a regular Python ``int`` like ``123``. .. note:: From b066f5f18e935ef5e4b172c0b3b981ea6f9c2ef5 Mon Sep 17 00:00:00 2001 From: "Ralf W. Grosse-Kunstleve" Date: Wed, 18 Jun 2025 14:11:50 -0700 Subject: [PATCH 07/10] =?UTF-8?q?Change=20DECL=5FNPY=5FSCALAR=20=E2=86=92?= =?UTF-8?q?=20PYBIND11=5FNUMPY=5FSCALAR=5FIMPL?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- include/pybind11/numpy.h | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 318aaaefc4..cd84ddc417 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -422,7 +422,7 @@ struct npy_format_descriptor_name::value>> { template struct numpy_scalar_info {}; -#define DECL_NPY_SCALAR(ctype_, typenum_) \ +#define PYBIND11_NUMPY_SCALAR_IMPL(ctype_, typenum_) \ template <> \ struct numpy_scalar_info { \ static constexpr auto name = npy_format_descriptor_name::name; \ @@ -430,34 +430,34 @@ struct numpy_scalar_info {}; } // boolean type -DECL_NPY_SCALAR(bool, NPY_BOOL); +PYBIND11_NUMPY_SCALAR_IMPL(bool, NPY_BOOL); // character types -DECL_NPY_SCALAR(char, NPY_CHAR); -DECL_NPY_SCALAR(signed char, NPY_BYTE); -DECL_NPY_SCALAR(unsigned char, NPY_UBYTE); +PYBIND11_NUMPY_SCALAR_IMPL(char, NPY_CHAR); +PYBIND11_NUMPY_SCALAR_IMPL(signed char, NPY_BYTE); +PYBIND11_NUMPY_SCALAR_IMPL(unsigned char, NPY_UBYTE); // signed integer types -DECL_NPY_SCALAR(std::int16_t, NPY_INT16); -DECL_NPY_SCALAR(std::int32_t, NPY_INT32); -DECL_NPY_SCALAR(std::int64_t, NPY_INT64); +PYBIND11_NUMPY_SCALAR_IMPL(std::int16_t, NPY_INT16); +PYBIND11_NUMPY_SCALAR_IMPL(std::int32_t, NPY_INT32); +PYBIND11_NUMPY_SCALAR_IMPL(std::int64_t, NPY_INT64); // unsigned integer types -DECL_NPY_SCALAR(std::uint16_t, NPY_UINT16); -DECL_NPY_SCALAR(std::uint32_t, NPY_UINT32); -DECL_NPY_SCALAR(std::uint64_t, NPY_UINT64); +PYBIND11_NUMPY_SCALAR_IMPL(std::uint16_t, NPY_UINT16); +PYBIND11_NUMPY_SCALAR_IMPL(std::uint32_t, NPY_UINT32); +PYBIND11_NUMPY_SCALAR_IMPL(std::uint64_t, NPY_UINT64); // floating point types -DECL_NPY_SCALAR(float, NPY_FLOAT); -DECL_NPY_SCALAR(double, NPY_DOUBLE); -DECL_NPY_SCALAR(long double, NPY_LONGDOUBLE); +PYBIND11_NUMPY_SCALAR_IMPL(float, NPY_FLOAT); +PYBIND11_NUMPY_SCALAR_IMPL(double, NPY_DOUBLE); +PYBIND11_NUMPY_SCALAR_IMPL(long double, NPY_LONGDOUBLE); // complex types -DECL_NPY_SCALAR(std::complex, NPY_CFLOAT); -DECL_NPY_SCALAR(std::complex, NPY_CDOUBLE); -DECL_NPY_SCALAR(std::complex, NPY_CLONGDOUBLE); +PYBIND11_NUMPY_SCALAR_IMPL(std::complex, NPY_CFLOAT); +PYBIND11_NUMPY_SCALAR_IMPL(std::complex, NPY_CDOUBLE); +PYBIND11_NUMPY_SCALAR_IMPL(std::complex, NPY_CLONGDOUBLE); -#undef DECL_NPY_SCALAR +#undef PYBIND11_NUMPY_SCALAR_IMPL // This table normalizes typenums by mapping NPY_INT_, NPY_LONG, ... to NPY_INT32_, NPY_INT64, ... // This is needed to correctly handle situations where multiple typenums map to the same type, From ac71c07daa8b7bb5bed902b4b50134bed449f354 Mon Sep 17 00:00:00 2001 From: "Ralf W. Grosse-Kunstleve" Date: Wed, 18 Jun 2025 14:22:00 -0700 Subject: [PATCH 08/10] Add numpy_scalar operator==, operator!= --- include/pybind11/numpy.h | 6 ++++++ tests/test_numpy_scalars.cpp | 3 +++ tests/test_numpy_scalars.py | 7 +++++++ 3 files changed, 16 insertions(+) diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index cd84ddc417..7f62157f5c 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -816,6 +816,12 @@ struct numpy_scalar { this->value = value; return *this; } + + friend bool operator==(const numpy_scalar &a, const numpy_scalar &b) { + return a.value == b.value; + } + + friend bool operator!=(const numpy_scalar &a, const numpy_scalar &b) { return !(a == b); } }; template diff --git a/tests/test_numpy_scalars.cpp b/tests/test_numpy_scalars.cpp index abedc0045f..6d73518188 100644 --- a/tests/test_numpy_scalars.cpp +++ b/tests/test_numpy_scalars.cpp @@ -49,4 +49,7 @@ TEST_SUBMODULE(numpy_scalars, m) { register_test(m, "float64", add(0.25f)); register_test(m, "complex64", add({0, -0.125f})); register_test(m, "complex128", add({0, -0.25f})); + + m.def("test_eq", [](py::numpy_scalar a, py::numpy_scalar b) { return a == b; }); + m.def("test_ne", [](py::numpy_scalar a, py::numpy_scalar b) { return a != b; }); } diff --git a/tests/test_numpy_scalars.py b/tests/test_numpy_scalars.py index ca9e3a63b6..fe9b71f22e 100644 --- a/tests/test_numpy_scalars.py +++ b/tests/test_numpy_scalars.py @@ -45,3 +45,10 @@ def test_numpy_scalars(npy_scalar_type, expected_value): else: with pytest.raises(TypeError): test_tpnm(value) + + +def test_eq_ne(): + assert m.test_eq(np.int32(3), np.int32(3)) + assert not m.test_eq(np.int32(3), np.int32(5)) + assert not m.test_ne(np.int32(3), np.int32(3)) + assert m.test_ne(np.int32(3), np.int32(5)) From c38245a96c8f1078d16418880e46660a63e6107f Mon Sep 17 00:00:00 2001 From: "Ralf W. Grosse-Kunstleve" Date: Wed, 18 Jun 2025 14:24:07 -0700 Subject: [PATCH 09/10] Move C++ test code into namespace pybind11_test_numpy_scalars --- tests/test_numpy_scalars.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_numpy_scalars.cpp b/tests/test_numpy_scalars.cpp index 6d73518188..739e6a15c8 100644 --- a/tests/test_numpy_scalars.cpp +++ b/tests/test_numpy_scalars.cpp @@ -16,6 +16,8 @@ namespace py = pybind11; +namespace pybind11_test_numpy_scalars { + template struct add { T x; @@ -32,6 +34,10 @@ void register_test(py::module &m, const char *name, F &&func) { py::arg("x")); } +} // namespace pybind11_test_numpy_scalars + +using namespace pybind11_test_numpy_scalars; + TEST_SUBMODULE(numpy_scalars, m) { using cfloat = std::complex; using cdouble = std::complex; From 59e2a6204f410c038762d27bad483be1efbbd883 Mon Sep 17 00:00:00 2001 From: "Ralf W. Grosse-Kunstleve" Date: Wed, 18 Jun 2025 14:32:54 -0700 Subject: [PATCH 10/10] =?UTF-8?q?Fix=20oversight=20(int=20=E2=86=92=20int3?= =?UTF-8?q?2=5Ft)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_numpy_scalars.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_numpy_scalars.cpp b/tests/test_numpy_scalars.cpp index 739e6a15c8..79393ebdd3 100644 --- a/tests/test_numpy_scalars.cpp +++ b/tests/test_numpy_scalars.cpp @@ -56,6 +56,8 @@ TEST_SUBMODULE(numpy_scalars, m) { register_test(m, "complex64", add({0, -0.125f})); register_test(m, "complex128", add({0, -0.25f})); - m.def("test_eq", [](py::numpy_scalar a, py::numpy_scalar b) { return a == b; }); - m.def("test_ne", [](py::numpy_scalar a, py::numpy_scalar b) { return a != b; }); + m.def("test_eq", + [](py::numpy_scalar a, py::numpy_scalar b) { return a == b; }); + m.def("test_ne", + [](py::numpy_scalar a, py::numpy_scalar b) { return a != b; }); }