numpy/array_like.rs
1use std::marker::PhantomData;
2use std::ops::Deref;
3
4use ndarray::{Array1, Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
5use pyo3::{
6 intern,
7 sync::PyOnceLock,
8 types::{PyAnyMethods, PyDict},
9 Borrowed, FromPyObject, Py, PyAny, PyErr, PyResult,
10};
11
12use crate::array::PyArrayMethods;
13use crate::{get_array_module, Element, IntoPyArray, PyArray, PyReadonlyArray};
14
mod sealed {
    /// Private supertrait: because this module is not reachable from outside,
    /// only types in the enclosing module can implement [`Coerce`](super::Coerce).
    pub trait Sealed {}
}

use sealed::Sealed;

/// Selects how [`PyArrayLike`] treats the element type of its input.
///
/// This trait is sealed; its implementors are [`TypeMustMatch`] and [`AllowTypeChange`].
pub trait Coerce: Sealed {
    /// `true` if NumPy may cast the input to the requested element type,
    /// `false` if the element type must already match.
    const VAL: bool;
}
24
25/// Marker type to indicate that the element type received via [`PyArrayLike`] must match the specified type exactly.
26#[derive(Debug)]
27pub struct TypeMustMatch;
28
29impl Sealed for TypeMustMatch {}
30
31impl Coerce for TypeMustMatch {
32 const VAL: bool = false;
33}
34
35/// Marker type to indicate that the element type received via [`PyArrayLike`] can be cast to the specified type by NumPy's [`asarray`](https://numpy.org/doc/stable/reference/generated/numpy.asarray.html).
36#[derive(Debug)]
37pub struct AllowTypeChange;
38
39impl Sealed for AllowTypeChange {}
40
41impl Coerce for AllowTypeChange {
42 const VAL: bool = true;
43}
44
45/// Receiver for arrays or array-like types.
46///
47/// When building API using NumPy in Python, it is common for functions to additionally accept any array-like type such as `list[float]` as arguments.
48/// `PyArrayLike` enables the same pattern in Rust extensions, i.e. by taking this type as the argument of a `#[pyfunction]`,
49/// one will always get access to a [`PyReadonlyArray`] that will either reference to the NumPy array originally passed into the function
50/// or a temporary one created by converting the input type into a NumPy array.
51///
52/// Depending on whether [`TypeMustMatch`] or [`AllowTypeChange`] is used for the `C` type parameter,
53/// the element type must either match the specific type `T` exactly or will be cast to it by NumPy's [`asarray`](https://numpy.org/doc/stable/reference/generated/numpy.asarray.html).
54///
55/// # Example
56///
57/// `PyArrayLike1<'py, T, TypeMustMatch>` will enable you to receive both NumPy arrays and sequences
58///
59/// ```rust
60/// # use pyo3::prelude::*;
61/// use pyo3::py_run;
62/// use numpy::{get_array_module, PyArrayLike1, TypeMustMatch};
63///
64/// #[pyfunction]
65/// fn sum_up<'py>(py: Python<'py>, array: PyArrayLike1<'py, f64, TypeMustMatch>) -> f64 {
66/// array.as_array().sum()
67/// }
68///
69/// Python::attach(|py| {
70/// let np = get_array_module(py).unwrap();
71/// let sum_up = wrap_pyfunction!(sum_up)(py).unwrap();
72///
73/// py_run!(py, np sum_up, r"assert sum_up(np.array([1., 2., 3.])) == 6.");
74/// py_run!(py, np sum_up, r"assert sum_up((1., 2., 3.)) == 6.");
75/// });
76/// ```
77///
78/// but it will not cast the element type if that is required
79///
80/// ```rust,should_panic
81/// use pyo3::prelude::*;
82/// use pyo3::py_run;
83/// use numpy::{get_array_module, PyArrayLike1, TypeMustMatch};
84///
85/// #[pyfunction]
86/// fn sum_up<'py>(py: Python<'py>, array: PyArrayLike1<'py, i32, TypeMustMatch>) -> i32 {
87/// array.as_array().sum()
88/// }
89///
90/// Python::attach(|py| {
91/// let np = get_array_module(py).unwrap();
92/// let sum_up = wrap_pyfunction!(sum_up)(py).unwrap();
93///
94/// py_run!(py, np sum_up, r"assert sum_up((1., 2., 3.)) == 6");
95/// });
96/// ```
97///
98/// whereas `PyArrayLike1<'py, T, AllowTypeChange>` will do even at the cost loosing precision
99///
100/// ```rust
101/// use pyo3::prelude::*;
102/// use pyo3::py_run;
103/// use numpy::{get_array_module, AllowTypeChange, PyArrayLike1};
104///
105/// #[pyfunction]
106/// fn sum_up<'py>(py: Python<'py>, array: PyArrayLike1<'py, i32, AllowTypeChange>) -> i32 {
107/// array.as_array().sum()
108/// }
109///
110/// Python::attach(|py| {
111/// let np = get_array_module(py).unwrap();
112/// let sum_up = wrap_pyfunction!(sum_up)(py).unwrap();
113///
114/// py_run!(py, np sum_up, r"assert sum_up((1.5, 2.5)) == 3");
115/// });
116/// ```
117#[derive(Debug)]
118#[repr(transparent)]
119pub struct PyArrayLike<'py, T, D, C = TypeMustMatch>(PyReadonlyArray<'py, T, D>, PhantomData<C>)
120where
121 T: Element,
122 D: Dimension,
123 C: Coerce;
124
125impl<'py, T, D, C> Deref for PyArrayLike<'py, T, D, C>
126where
127 T: Element,
128 D: Dimension,
129 C: Coerce,
130{
131 type Target = PyReadonlyArray<'py, T, D>;
132
133 fn deref(&self) -> &Self::Target {
134 &self.0
135 }
136}
137
138impl<'a, 'py, T, D, C> FromPyObject<'a, 'py> for PyArrayLike<'py, T, D, C>
139where
140 T: Element + 'py,
141 D: Dimension + 'py,
142 C: Coerce,
143 Vec<T>: FromPyObject<'a, 'py>,
144{
145 type Error = PyErr;
146
147 fn extract(ob: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
148 if let Ok(array) = ob.cast::<PyArray<T, D>>() {
149 return Ok(Self(array.readonly(), PhantomData));
150 }
151
152 let py = ob.py();
153
154 if matches!(D::NDIM, None | Some(1)) {
155 if let Ok(vec) = ob.extract::<Vec<T>>() {
156 let array = Array1::from(vec)
157 .into_dimensionality()
158 .expect("D being compatible to Ix1")
159 .into_pyarray(py)
160 .readonly();
161 return Ok(Self(array, PhantomData));
162 }
163 }
164
165 static AS_ARRAY: PyOnceLock<Py<PyAny>> = PyOnceLock::new();
166
167 let as_array = AS_ARRAY
168 .get_or_try_init(py, || {
169 get_array_module(py)?.getattr("asarray").map(Into::into)
170 })?
171 .bind(py);
172
173 let kwargs = if C::VAL {
174 let kwargs = PyDict::new(py);
175 kwargs.set_item(intern!(py, "dtype"), T::get_dtype(py))?;
176 Some(kwargs)
177 } else {
178 None
179 };
180
181 let array = as_array.call((ob,), kwargs.as_ref())?.extract()?;
182 Ok(Self(array, PhantomData))
183 }
184}
185
186/// Receiver for zero-dimensional arrays or array-like types.
187pub type PyArrayLike0<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix0, C>;
188
189/// Receiver for one-dimensional arrays or array-like types.
190pub type PyArrayLike1<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix1, C>;
191
192/// Receiver for two-dimensional arrays or array-like types.
193pub type PyArrayLike2<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix2, C>;
194
195/// Receiver for three-dimensional arrays or array-like types.
196pub type PyArrayLike3<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix3, C>;
197
198/// Receiver for four-dimensional arrays or array-like types.
199pub type PyArrayLike4<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix4, C>;
200
201/// Receiver for five-dimensional arrays or array-like types.
202pub type PyArrayLike5<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix5, C>;
203
204/// Receiver for six-dimensional arrays or array-like types.
205pub type PyArrayLike6<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix6, C>;
206
207/// Receiver for arrays or array-like types whose dimensionality is determined at runtime.
208pub type PyArrayLikeDyn<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, IxDyn, C>;