-
Notifications
You must be signed in to change notification settings - Fork 24
Improve dpnp.partition implementation
#2766
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
2da3758
bd67b44
3967a79
c90f0b2
680ae20
69b42d4
f70212a
06d3950
b939685
5f41788
9ebb038
6422635
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -39,8 +39,9 @@ | |
|
|
||
| """ | ||
|
|
||
| from collections.abc import Sequence | ||
|
|
||
| import dpctl.tensor as dpt | ||
| import numpy | ||
| from dpctl.tensor._numpy_helper import normalize_axis_index | ||
|
|
||
| import dpnp | ||
|
|
@@ -51,7 +52,6 @@ | |
| ) | ||
| from .dpnp_array import dpnp_array | ||
| from .dpnp_utils import ( | ||
| call_origin, | ||
| map_dtype_to_device, | ||
| ) | ||
|
|
||
|
|
@@ -147,7 +147,7 @@ def argsort( | |
|
|
||
| Limitations | ||
| ----------- | ||
| Parameters `order` is only supported with its default value. | ||
| Parameter `order` is only supported with its default value. | ||
| Otherwise ``NotImplementedError`` exception will be raised. | ||
| Sorting algorithms ``"quicksort"`` and ``"heapsort"`` are not supported. | ||
|
|
||
|
|
@@ -201,44 +201,128 @@ def argsort( | |
| ) | ||
|
|
||
|
|
||
| def partition(x1, kth, axis=-1, kind="introselect", order=None): | ||
| def partition(a, kth, axis=-1, kind="introselect", order=None): | ||
| """ | ||
| Return a partitioned copy of an array. | ||
|
|
||
| For full documentation refer to :obj:`numpy.partition`. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| a : {dpnp.ndarray, usm_ndarray} | ||
| Array to be sorted. | ||
| kth : {int, sequence of ints} | ||
| Element index to partition by. The k-th value of the element will be in | ||
| its final sorted position and all smaller elements will be moved before | ||
| it and all equal or greater elements behind it. The order of all | ||
| elements in the partitions is undefined. If provided with a sequence of | ||
| k-th it will partition all elements indexed by k-th of them into their | ||
| sorted position at once. | ||
| axis : {None, int}, optional | ||
| Axis along which to sort. If ``None``, the array is flattened before | ||
| sorting. The default is ``-1``, which sorts along the last axis. | ||
|
|
||
| Default: ``-1``. | ||
|
|
||
| Returns | ||
| ------- | ||
| out : dpnp.ndarray | ||
| Array of the same type and shape as `a`. | ||
|
|
||
| Limitations | ||
| ----------- | ||
| Input array is supported as :obj:`dpnp.ndarray`. | ||
| Input `kth` is supported as :obj:`int`. | ||
| Parameters `axis`, `kind` and `order` are supported only with default | ||
| values. | ||
| Parameters `kind` and `order` are only supported with its default value. | ||
| Otherwise ``NotImplementedError`` exception will be raised. | ||
|
|
||
| See Also | ||
| -------- | ||
| :obj:`dpnp.ndarray.partition` : Equivalent method. | ||
| :obj:`dpnp.argpartition` : Indirect partition. | ||
| :obj:`dpnp.sort` : Full sorting. | ||
|
|
||
| Examples | ||
| -------- | ||
| >>> import dpnp as np | ||
| >>> a = np.array([7, 1, 7, 7, 1, 5, 7, 2, 3, 2, 6, 2, 3, 0]) | ||
| >>> p = np.partition(a, 4) | ||
| >>> p | ||
| array([0, 1, 1, 2, 2, 2, 3, 3, 5, 7, 7, 7, 7, 6]) # may vary | ||
|
|
||
| ``p[4]`` is 2; all elements in ``p[:4]`` are less than or equal to | ||
| ``p[4]``, and all elements in ``p[5:]`` are greater than or equal to | ||
| ``p[4]``. The partition is:: | ||
|
|
||
| [0, 1, 1, 2], [2], [2, 3, 3, 5, 7, 7, 7, 7, 6] | ||
|
|
||
| The next example shows the use of multiple values passed to `kth`. | ||
|
|
||
| >>> p2 = np.partition(a, (4, 8)) | ||
| >>> p2 | ||
| array([0, 1, 1, 2, 2, 2, 3, 3, 5, 6, 7, 7, 7, 7]) | ||
|
|
||
| ``p2[4]`` is 2 and ``p2[8]`` is 5. All elements in ``p2[:4]`` are less | ||
| than or equal to ``p2[4]``, all elements in ``p2[5:8]`` are greater than or | ||
| equal to ``p2[4]`` and less than or equal to ``p2[8]``, and all elements in | ||
| ``p2[9:]`` are greater than or equal to ``p2[8]``. The partition is:: | ||
|
|
||
| [0, 1, 1, 2], [2], [2, 3, 3], [5], [6, 7, 7, 7, 7] | ||
|
|
||
| """ | ||
|
|
||
| x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False) | ||
| if x1_desc: | ||
| if dpnp.is_cuda_backend(x1_desc.get_array()): # pragma: no cover | ||
| raise NotImplementedError( | ||
| "Running on CUDA is currently not supported" | ||
| ) | ||
| dpnp.check_supported_arrays_type(a) | ||
|
|
||
| if not isinstance(kth, int): | ||
| pass | ||
| elif x1_desc.ndim == 0: | ||
| pass | ||
| elif kth >= x1_desc.shape[x1_desc.ndim - 1] or x1_desc.ndim + kth < 0: | ||
| pass | ||
| elif axis != -1: | ||
| pass | ||
| elif kind != "introselect": | ||
| pass | ||
| elif order is not None: | ||
| pass | ||
| else: | ||
| return dpnp_partition(x1_desc, kth, axis, kind, order).get_pyobj() | ||
| if kind != "introselect": | ||
| raise NotImplementedError( | ||
| "`kind` keyword argument is only supported with its default value." | ||
| ) | ||
| if order is not None: | ||
| raise NotImplementedError( | ||
| "`order` keyword argument is only supported with its default value." | ||
| ) | ||
|
|
||
| return call_origin(numpy.partition, x1, kth, axis, kind, order) | ||
| if axis is None: | ||
| a = dpnp.ravel(a) | ||
| axis = -1 | ||
|
|
||
| nd = a.ndim | ||
| axis = normalize_axis_index(axis, nd) | ||
| length = a.shape[axis] | ||
|
|
||
| if isinstance(kth, int): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks too strict because does not support integer-like scalars as
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I meant Should we use
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We could, but it's not mandatory, because we explicitly state the support of int and sequence of itns in the documentation. |
||
| kth = (kth,) | ||
| elif not isinstance(kth, Sequence): | ||
| raise TypeError( | ||
| f"kth must be int or sequence of ints, but got {type(kth)}" | ||
| ) | ||
| elif not all(isinstance(k, int) for k in kth): | ||
| raise TypeError("kth is a sequence, but not all elements are integers") | ||
|
|
||
| nkth = len(kth) | ||
| if nkth == 0 or a.size == 0: | ||
| return dpnp.copy(a) | ||
|
|
||
| # validate kth | ||
| kth = list(kth) | ||
| for i in range(nkth): | ||
| if kth[i] < 0: | ||
| kth[i] += length | ||
|
|
||
| if not 0 <= kth[i] < length: | ||
| raise ValueError(f"kth(={kth[i]}) out of bounds {length}") | ||
|
|
||
| dt = a.dtype | ||
| if ( | ||
| nd > 1 | ||
| or nkth > 1 | ||
| or dpnp.issubdtype(dt, dpnp.unsignedinteger) | ||
| or dt in (dpnp.int8, dpnp.int16) | ||
| or dpnp.is_cuda_backend(a.get_array()) | ||
| ): | ||
| # sort is a faster path in case of ndim > 1 | ||
| return dpnp.sort(a, axis=axis) | ||
|
|
||
| desc = dpnp.get_dpnp_descriptor(a, copy_when_nondefault_queue=False) | ||
| return dpnp_partition(desc, kth[0], axis, kind, order).get_pyobj() | ||
|
|
||
|
|
||
| def sort(a, axis=-1, kind=None, order=None, *, descending=False, stable=None): | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.