numpy/array_like.rs
1use std::marker::PhantomData;
2use std::ops::Deref;
3
4use ndarray::{Array1, Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
5use pyo3::{
6 intern,
7 sync::GILOnceCell,
8 types::{PyAnyMethods, PyDict},
9 Bound, FromPyObject, Py, PyAny, PyResult,
10};
11
12use crate::array::PyArrayMethods;
13use crate::{get_array_module, Element, IntoPyArray, PyArray, PyReadonlyArray};
14
/// Coercion policy used by [`PyArrayLike`] to decide whether the element type of the input may be changed.
///
/// This trait is sealed via the private `sealed` module below: code that cannot name
/// `sealed::Sealed` cannot implement `Coerce`. It is implemented here for
/// [`TypeMustMatch`] and [`AllowTypeChange`].
pub trait Coerce: Sealed {
    /// `true` if the input may be cast to the requested element type,
    /// `false` if the element type must already match.
    const VAL: bool;
}

// Private module so that `Sealed` (and therefore `Coerce`) cannot be implemented
// for types defined outside of this module.
mod sealed {
    pub trait Sealed {}
}

use sealed::Sealed;

/// Marker type to indicate that the element type received via [`PyArrayLike`] must match the specified type exactly.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct TypeMustMatch;

impl Sealed for TypeMustMatch {}

impl Coerce for TypeMustMatch {
    // No `dtype` is requested, so NumPy is not asked to cast the elements.
    const VAL: bool = false;
}

/// Marker type to indicate that the element type received via [`PyArrayLike`] can be cast to the specified type by NumPy's [`asarray`](https://numpy.org/doc/stable/reference/generated/numpy.asarray.html).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct AllowTypeChange;

impl Sealed for AllowTypeChange {}

impl Coerce for AllowTypeChange {
    // The target `dtype` is requested, allowing NumPy to cast the elements.
    const VAL: bool = true;
}
44
45/// Receiver for arrays or array-like types.
46///
47/// When building API using NumPy in Python, it is common for functions to additionally accept any array-like type such as `list[float]` as arguments.
48/// `PyArrayLike` enables the same pattern in Rust extensions, i.e. by taking this type as the argument of a `#[pyfunction]`,
49/// one will always get access to a [`PyReadonlyArray`] that will either reference to the NumPy array originally passed into the function
50/// or a temporary one created by converting the input type into a NumPy array.
51///
52/// Depending on whether [`TypeMustMatch`] or [`AllowTypeChange`] is used for the `C` type parameter,
53/// the element type must either match the specific type `T` exactly or will be cast to it by NumPy's [`asarray`](https://numpy.org/doc/stable/reference/generated/numpy.asarray.html).
54///
55/// # Example
56///
57/// `PyArrayLike1<'py, T, TypeMustMatch>` will enable you to receive both NumPy arrays and sequences
58///
59/// ```rust
60/// # use pyo3::prelude::*;
61/// use pyo3::py_run;
62/// use numpy::{get_array_module, PyArrayLike1, TypeMustMatch};
63///
64/// #[pyfunction]
65/// fn sum_up<'py>(py: Python<'py>, array: PyArrayLike1<'py, f64, TypeMustMatch>) -> f64 {
66/// array.as_array().sum()
67/// }
68///
69/// Python::with_gil(|py| {
70/// let np = get_array_module(py).unwrap();
71/// let sum_up = wrap_pyfunction!(sum_up)(py).unwrap();
72///
73/// py_run!(py, np sum_up, r"assert sum_up(np.array([1., 2., 3.])) == 6.");
74/// py_run!(py, np sum_up, r"assert sum_up((1., 2., 3.)) == 6.");
75/// });
76/// ```
77///
78/// but it will not cast the element type if that is required
79///
80/// ```rust,should_panic
81/// use pyo3::prelude::*;
82/// use pyo3::py_run;
83/// use numpy::{get_array_module, PyArrayLike1, TypeMustMatch};
84///
85/// #[pyfunction]
86/// fn sum_up<'py>(py: Python<'py>, array: PyArrayLike1<'py, i32, TypeMustMatch>) -> i32 {
87/// array.as_array().sum()
88/// }
89///
90/// Python::with_gil(|py| {
91/// let np = get_array_module(py).unwrap();
92/// let sum_up = wrap_pyfunction!(sum_up)(py).unwrap();
93///
94/// py_run!(py, np sum_up, r"assert sum_up((1., 2., 3.)) == 6");
95/// });
96/// ```
97///
98/// whereas `PyArrayLike1<'py, T, AllowTypeChange>` will do even at the cost loosing precision
99///
100/// ```rust
101/// use pyo3::prelude::*;
102/// use pyo3::py_run;
103/// use numpy::{get_array_module, AllowTypeChange, PyArrayLike1};
104///
105/// #[pyfunction]
106/// fn sum_up<'py>(py: Python<'py>, array: PyArrayLike1<'py, i32, AllowTypeChange>) -> i32 {
107/// array.as_array().sum()
108/// }
109///
110/// Python::with_gil(|py| {
111/// let np = get_array_module(py).unwrap();
112/// let sum_up = wrap_pyfunction!(sum_up)(py).unwrap();
113///
114/// py_run!(py, np sum_up, r"assert sum_up((1.5, 2.5)) == 3");
115/// });
116/// ```
117#[derive(Debug)]
118#[repr(transparent)]
119pub struct PyArrayLike<'py, T, D, C = TypeMustMatch>(PyReadonlyArray<'py, T, D>, PhantomData<C>)
120where
121 T: Element,
122 D: Dimension,
123 C: Coerce;
124
125impl<'py, T, D, C> Deref for PyArrayLike<'py, T, D, C>
126where
127 T: Element,
128 D: Dimension,
129 C: Coerce,
130{
131 type Target = PyReadonlyArray<'py, T, D>;
132
133 fn deref(&self) -> &Self::Target {
134 &self.0
135 }
136}
137
138impl<'py, T, D, C> FromPyObject<'py> for PyArrayLike<'py, T, D, C>
139where
140 T: Element + 'py,
141 D: Dimension + 'py,
142 C: Coerce,
143 Vec<T>: FromPyObject<'py>,
144{
145 fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
146 if let Ok(array) = ob.downcast::<PyArray<T, D>>() {
147 return Ok(Self(array.readonly(), PhantomData));
148 }
149
150 let py = ob.py();
151
152 if matches!(D::NDIM, None | Some(1)) {
153 if let Ok(vec) = ob.extract::<Vec<T>>() {
154 let array = Array1::from(vec)
155 .into_dimensionality()
156 .expect("D being compatible to Ix1")
157 .into_pyarray(py)
158 .readonly();
159 return Ok(Self(array, PhantomData));
160 }
161 }
162
163 static AS_ARRAY: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
164
165 let as_array = AS_ARRAY
166 .get_or_try_init(py, || {
167 get_array_module(py)?.getattr("asarray").map(Into::into)
168 })?
169 .bind(py);
170
171 let kwargs = if C::VAL {
172 let kwargs = PyDict::new(py);
173 kwargs.set_item(intern!(py, "dtype"), T::get_dtype(py))?;
174 Some(kwargs)
175 } else {
176 None
177 };
178
179 let array = as_array.call((ob,), kwargs.as_ref())?.extract()?;
180 Ok(Self(array, PhantomData))
181 }
182}
183
184/// Receiver for zero-dimensional arrays or array-like types.
185pub type PyArrayLike0<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix0, C>;
186
187/// Receiver for one-dimensional arrays or array-like types.
188pub type PyArrayLike1<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix1, C>;
189
190/// Receiver for two-dimensional arrays or array-like types.
191pub type PyArrayLike2<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix2, C>;
192
193/// Receiver for three-dimensional arrays or array-like types.
194pub type PyArrayLike3<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix3, C>;
195
196/// Receiver for four-dimensional arrays or array-like types.
197pub type PyArrayLike4<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix4, C>;
198
199/// Receiver for five-dimensional arrays or array-like types.
200pub type PyArrayLike5<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix5, C>;
201
202/// Receiver for six-dimensional arrays or array-like types.
203pub type PyArrayLike6<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix6, C>;
204
205/// Receiver for arrays or array-like types whose dimensionality is determined at runtime.
206pub type PyArrayLikeDyn<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, IxDyn, C>;