numpy/array_like.rs
1use std::marker::PhantomData;
2use std::ops::Deref;
3
4use ndarray::{Array1, Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
5use pyo3::{types::PyAnyMethods, Borrowed, FromPyObject, PyAny, PyErr, PyResult};
6
7use crate::npyffi::NPY_ARRAY_FORCECAST;
8use crate::{array::PyArrayMethods, PY_ARRAY_API};
9use crate::{Element, IntoPyArray, PyArray, PyReadonlyArray, PyUntypedArray};
10
11pub trait Coerce: Sealed {
12 const ALLOW_TYPE_CHANGE: bool;
13}
14
// Private module: the trait below is `pub` so it can appear as a supertrait bound,
// but the module itself is not exported.
mod sealed {
    /// Supertrait used to seal traits defined in the parent module.
    pub trait Sealed {}
}
18
19use sealed::Sealed;
20
21/// Marker type to indicate that the element type received via [`PyArrayLike`] must match the specified type exactly.
22#[derive(Debug)]
23pub struct TypeMustMatch;
24
25impl Sealed for TypeMustMatch {}
26
27impl Coerce for TypeMustMatch {
28 const ALLOW_TYPE_CHANGE: bool = false;
29}
30
31/// Marker type to indicate that the element type received via [`PyArrayLike`] can be cast to the specified type by NumPy's [`asarray`](https://numpy.org/doc/stable/reference/generated/numpy.asarray.html).
32#[derive(Debug)]
33pub struct AllowTypeChange;
34
35impl Sealed for AllowTypeChange {}
36
37impl Coerce for AllowTypeChange {
38 const ALLOW_TYPE_CHANGE: bool = true;
39}
40
41/// Receiver for arrays or array-like types.
42///
43/// When building API using NumPy in Python, it is common for functions to additionally accept any array-like type such as `list[float]` as arguments.
44/// `PyArrayLike` enables the same pattern in Rust extensions, i.e. by taking this type as the argument of a `#[pyfunction]`,
45/// one will always get access to a [`PyReadonlyArray`] that will either reference to the NumPy array originally passed into the function
46/// or a temporary one created by converting the input type into a NumPy array.
47///
48/// Depending on whether [`TypeMustMatch`] or [`AllowTypeChange`] is used for the `C` type parameter,
49/// the element type must either match the specific type `T` exactly or will be cast to it by NumPy's [`asarray`](https://numpy.org/doc/stable/reference/generated/numpy.asarray.html).
50///
51/// # Example
52///
53/// `PyArrayLike1<'py, T, TypeMustMatch>` will enable you to receive both NumPy arrays and sequences
54///
55/// ```rust
56/// # use pyo3::prelude::*;
57/// use pyo3::py_run;
58/// use numpy::{get_array_module, PyArrayLike1, TypeMustMatch};
59///
60/// #[pyfunction]
61/// fn sum_up<'py>(py: Python<'py>, array: PyArrayLike1<'py, f64, TypeMustMatch>) -> f64 {
62/// array.as_array().sum()
63/// }
64///
65/// Python::attach(|py| {
66/// let np = get_array_module(py).unwrap();
67/// let sum_up = wrap_pyfunction!(sum_up)(py).unwrap();
68///
69/// py_run!(py, np sum_up, r"assert sum_up(np.array([1., 2., 3.])) == 6.");
70/// py_run!(py, np sum_up, r"assert sum_up((1., 2., 3.)) == 6.");
71/// });
72/// ```
73///
74/// but it will not cast the element type if that is required
75///
76/// ```rust,should_panic
77/// use pyo3::prelude::*;
78/// use pyo3::py_run;
79/// use numpy::{get_array_module, PyArrayLike1, TypeMustMatch};
80///
81/// #[pyfunction]
82/// fn sum_up<'py>(py: Python<'py>, array: PyArrayLike1<'py, i32, TypeMustMatch>) -> i32 {
83/// array.as_array().sum()
84/// }
85///
86/// Python::attach(|py| {
87/// let np = get_array_module(py).unwrap();
88/// let sum_up = wrap_pyfunction!(sum_up)(py).unwrap();
89///
90/// py_run!(py, np sum_up, r"assert sum_up((1., 2., 3.)) == 6");
91/// });
92/// ```
93///
94/// whereas `PyArrayLike1<'py, T, AllowTypeChange>` will do even at the cost loosing precision
95///
96/// ```rust
97/// use pyo3::prelude::*;
98/// use pyo3::py_run;
99/// use numpy::{get_array_module, AllowTypeChange, PyArrayLike1};
100///
101/// #[pyfunction]
102/// fn sum_up<'py>(py: Python<'py>, array: PyArrayLike1<'py, i32, AllowTypeChange>) -> i32 {
103/// array.as_array().sum()
104/// }
105///
106/// Python::attach(|py| {
107/// let np = get_array_module(py).unwrap();
108/// let sum_up = wrap_pyfunction!(sum_up)(py).unwrap();
109///
110/// py_run!(py, np sum_up, r"assert sum_up((1.5, 2.5)) == 3");
111/// });
112/// ```
113#[derive(Debug)]
114#[repr(transparent)]
115pub struct PyArrayLike<'py, T, D, C = TypeMustMatch>(PyReadonlyArray<'py, T, D>, PhantomData<C>)
116where
117 T: Element,
118 D: Dimension,
119 C: Coerce;
120
121impl<'py, T, D, C> Deref for PyArrayLike<'py, T, D, C>
122where
123 T: Element,
124 D: Dimension,
125 C: Coerce,
126{
127 type Target = PyReadonlyArray<'py, T, D>;
128
129 fn deref(&self) -> &Self::Target {
130 &self.0
131 }
132}
133
134impl<'a, 'py, T, D, C> FromPyObject<'a, 'py> for PyArrayLike<'py, T, D, C>
135where
136 T: Element + 'py,
137 D: Dimension + 'py,
138 C: Coerce,
139 Vec<T>: FromPyObject<'a, 'py>,
140{
141 type Error = PyErr;
142
143 fn extract(ob: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
144 if let Ok(array) = ob.cast::<PyArray<T, D>>() {
145 return Ok(Self(array.readonly(), PhantomData));
146 }
147
148 let py = ob.py();
149
150 // If the input is already an ndarray and `TypeMustMatch` is used then no type conversion
151 // should be performed.
152 if (C::ALLOW_TYPE_CHANGE || ob.cast::<PyUntypedArray>().is_err())
153 && matches!(D::NDIM, None | Some(1))
154 {
155 if let Ok(vec) = ob.extract::<Vec<T>>() {
156 let array = Array1::from(vec)
157 .into_dimensionality()
158 .expect("D being compatible to Ix1")
159 .into_pyarray(py)
160 .readonly();
161 return Ok(Self(array, PhantomData));
162 }
163 }
164
165 let (dtype, flags) = if C::ALLOW_TYPE_CHANGE {
166 (Some(T::get_dtype(py)), NPY_ARRAY_FORCECAST)
167 } else {
168 (None, 0)
169 };
170
171 let newtype = dtype
172 .map(|dt| dt.into_ptr().cast())
173 .unwrap_or_else(std::ptr::null_mut);
174
175 let array = unsafe {
176 let ptr = PY_ARRAY_API.PyArray_FromAny(
177 py,
178 ob.as_ptr(),
179 newtype,
180 0,
181 0,
182 flags,
183 std::ptr::null_mut(),
184 );
185
186 pyo3::Bound::from_owned_ptr_or_err(py, ptr)?
187 };
188
189 Ok(Self(array.extract()?, PhantomData))
190 }
191}
192
193/// Receiver for zero-dimensional arrays or array-like types.
194pub type PyArrayLike0<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix0, C>;
195
196/// Receiver for one-dimensional arrays or array-like types.
197pub type PyArrayLike1<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix1, C>;
198
199/// Receiver for two-dimensional arrays or array-like types.
200pub type PyArrayLike2<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix2, C>;
201
202/// Receiver for three-dimensional arrays or array-like types.
203pub type PyArrayLike3<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix3, C>;
204
205/// Receiver for four-dimensional arrays or array-like types.
206pub type PyArrayLike4<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix4, C>;
207
208/// Receiver for five-dimensional arrays or array-like types.
209pub type PyArrayLike5<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix5, C>;
210
211/// Receiver for six-dimensional arrays or array-like types.
212pub type PyArrayLike6<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix6, C>;
213
214/// Receiver for arrays or array-like types whose dimensionality is determined at runtime.
215pub type PyArrayLikeDyn<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, IxDyn, C>;