Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/pybind11/detail/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,10 @@
# define PYBIND11_HAS_STRING_VIEW 1
#endif

#if defined(PYBIND11_CPP20) && defined(__has_include) && __has_include(<span>)
# define PYBIND11_HAS_SPAN 1
#endif

#if (defined(PYPY_VERSION) || defined(GRAALVM_PYTHON)) && !defined(PYBIND11_SIMPLE_GIL_MANAGEMENT)
# define PYBIND11_SIMPLE_GIL_MANAGEMENT
#endif
Expand Down
18 changes: 18 additions & 0 deletions include/pybind11/numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@
#include <utility>
#include <vector>

#ifdef PYBIND11_HAS_SPAN
# include <span>
#endif

#if defined(PYBIND11_NUMPY_1_ONLY)
# error "PYBIND11_NUMPY_1_ONLY is no longer supported (see PR #5595)."
#endif
Expand Down Expand Up @@ -1143,6 +1147,13 @@ class array : public buffer {
/// Dimensions of the array
const ssize_t *shape() const { return detail::array_proxy(m_ptr)->dimensions; }

#ifdef PYBIND11_HAS_SPAN
/// Dimensions of the array as a span
std::span<const ssize_t, std::dynamic_extent> shape_span() const {
return std::span(shape(), static_cast<std::size_t>(ndim()));
}
#endif

/// Dimension along a given axis
ssize_t shape(ssize_t dim) const {
if (dim >= ndim()) {
Expand All @@ -1154,6 +1165,13 @@ class array : public buffer {
/// Strides of the array
const ssize_t *strides() const { return detail::array_proxy(m_ptr)->strides; }

#ifdef PYBIND11_HAS_SPAN
/// Strides of the array as a span
std::span<const ssize_t, std::dynamic_extent> strides_span() const {
return std::span(strides(), static_cast<std::size_t>(ndim()));
}
#endif

/// Stride along a given axis
ssize_t strides(ssize_t dim) const {
if (dim >= ndim()) {
Expand Down
17 changes: 17 additions & 0 deletions tests/test_numpy_array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include <cstdint>
#include <utility>
#include <vector>

// Size / dtype checks.
struct DtypeCheck {
Expand Down Expand Up @@ -246,6 +247,22 @@ TEST_SUBMODULE(numpy_array, sm) {
sm.def("nbytes", [](const arr &a) { return a.nbytes(); });
sm.def("owndata", [](const arr &a) { return a.owndata(); });

#ifdef PYBIND11_HAS_SPAN
// test_shape_strides_span
sm.def("shape_span", [](const arr &a) {
auto span = a.shape_span();
return std::vector<ssize_t>(span.begin(), span.end());
});
sm.def("strides_span", [](const arr &a) {
auto span = a.strides_span();
return std::vector<ssize_t>(span.begin(), span.end());
});
// Test that spans can be used to construct new arrays
sm.def("array_from_spans", [](const arr &a) {
return py::array(a.dtype(), a.shape_span(), a.strides_span(), a.data(), a);
});
#endif

// test_index_offset
def_index_fn(index_at, const arr &);
def_index_fn(index_at_t, const arr_t &);
Expand Down
39 changes: 39 additions & 0 deletions tests/test_numpy_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,45 @@ def test_array_attributes():
assert not m.owndata(a)


@pytest.mark.skipif(not hasattr(m, "shape_span"), reason="std::span not available")
def test_shape_strides_span():
# Test 0-dimensional array (scalar)
a = np.array(42, "f8")
assert m.ndim(a) == 0
assert m.shape_span(a) == []
assert m.strides_span(a) == []

# Test 1-dimensional array
a = np.array([1, 2, 3, 4], "u2")
assert m.ndim(a) == 1
assert m.shape_span(a) == [4]
assert m.strides_span(a) == [2]

# Test 2-dimensional array
a = np.array([[1, 2, 3], [4, 5, 6]], "u2").view()
a.flags.writeable = False
assert m.ndim(a) == 2
assert m.shape_span(a) == [2, 3]
assert m.strides_span(a) == [6, 2]

# Test 3-dimensional array
a = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], "i4")
assert m.ndim(a) == 3
assert m.shape_span(a) == [2, 2, 2]
# Verify spans match regular shape/strides
assert list(m.shape_span(a)) == list(m.shape(a))
assert list(m.strides_span(a)) == list(m.strides(a))

# Test that spans can be used to construct new arrays
original = np.array([[1, 2, 3], [4, 5, 6]], "f4")
new_array = m.array_from_spans(original)
assert new_array.shape == original.shape
assert new_array.strides == original.strides
assert new_array.dtype == original.dtype
# Verify data is shared (since we pass the same data pointer)
np.testing.assert_array_equal(new_array, original)


@pytest.mark.parametrize(
("args", "ret"), [([], 0), ([0], 0), ([1], 3), ([0, 1], 1), ([1, 2], 5)]
)
Expand Down
Loading