// numpy/array_like.rs

1use std::marker::PhantomData;
2use std::ops::Deref;
3
4use ndarray::{Array1, Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
5use pyo3::{
6    intern,
7    sync::GILOnceCell,
8    types::{PyAnyMethods, PyDict},
9    Bound, FromPyObject, Py, PyAny, PyResult,
10};
11
12use crate::array::PyArrayMethods;
13use crate::{get_array_module, Element, IntoPyArray, PyArray, PyReadonlyArray};
14
15pub trait Coerce: Sealed {
16    const VAL: bool;
17}
18
/// Private home of the supertrait that seals [`Coerce`].
///
/// Since this module is not public, code outside the enclosing module cannot name
/// `Sealed` and therefore cannot implement `Coerce` for its own types.
mod sealed {
    /// Supertrait of `Coerce`; only the marker types of the enclosing module implement it.
    pub trait Sealed {}
}

use self::sealed::Sealed;
24
25/// Marker type to indicate that the element type received via [`PyArrayLike`] must match the specified type exactly.
26#[derive(Debug)]
27pub struct TypeMustMatch;
28
29impl Sealed for TypeMustMatch {}
30
31impl Coerce for TypeMustMatch {
32    const VAL: bool = false;
33}
34
35/// Marker type to indicate that the element type received via [`PyArrayLike`] can be cast to the specified type by NumPy's [`asarray`](https://numpy.org/doc/stable/reference/generated/numpy.asarray.html).
36#[derive(Debug)]
37pub struct AllowTypeChange;
38
39impl Sealed for AllowTypeChange {}
40
41impl Coerce for AllowTypeChange {
42    const VAL: bool = true;
43}
44
45/// Receiver for arrays or array-like types.
46///
47/// When building API using NumPy in Python, it is common for functions to additionally accept any array-like type such as `list[float]` as arguments.
48/// `PyArrayLike` enables the same pattern in Rust extensions, i.e. by taking this type as the argument of a `#[pyfunction]`,
49/// one will always get access to a [`PyReadonlyArray`] that will either reference to the NumPy array originally passed into the function
50/// or a temporary one created by converting the input type into a NumPy array.
51///
52/// Depending on whether [`TypeMustMatch`] or [`AllowTypeChange`] is used for the `C` type parameter,
53/// the element type must either match the specific type `T` exactly or will be cast to it by NumPy's [`asarray`](https://numpy.org/doc/stable/reference/generated/numpy.asarray.html).
54///
55/// # Example
56///
57/// `PyArrayLike1<'py, T, TypeMustMatch>` will enable you to receive both NumPy arrays and sequences
58///
59/// ```rust
60/// # use pyo3::prelude::*;
61/// use pyo3::py_run;
62/// use numpy::{get_array_module, PyArrayLike1, TypeMustMatch};
63///
64/// #[pyfunction]
65/// fn sum_up<'py>(py: Python<'py>, array: PyArrayLike1<'py, f64, TypeMustMatch>) -> f64 {
66///     array.as_array().sum()
67/// }
68///
69/// Python::with_gil(|py| {
70///     let np = get_array_module(py).unwrap();
71///     let sum_up = wrap_pyfunction!(sum_up)(py).unwrap();
72///
73///     py_run!(py, np sum_up, r"assert sum_up(np.array([1., 2., 3.])) == 6.");
74///     py_run!(py, np sum_up, r"assert sum_up((1., 2., 3.)) == 6.");
75/// });
76/// ```
77///
78/// but it will not cast the element type if that is required
79///
80/// ```rust,should_panic
81/// use pyo3::prelude::*;
82/// use pyo3::py_run;
83/// use numpy::{get_array_module, PyArrayLike1, TypeMustMatch};
84///
85/// #[pyfunction]
86/// fn sum_up<'py>(py: Python<'py>, array: PyArrayLike1<'py, i32, TypeMustMatch>) -> i32 {
87///     array.as_array().sum()
88/// }
89///
90/// Python::with_gil(|py| {
91///     let np = get_array_module(py).unwrap();
92///     let sum_up = wrap_pyfunction!(sum_up)(py).unwrap();
93///
94///     py_run!(py, np sum_up, r"assert sum_up((1., 2., 3.)) == 6");
95/// });
96/// ```
97///
98/// whereas `PyArrayLike1<'py, T, AllowTypeChange>` will do even at the cost loosing precision
99///
100/// ```rust
101/// use pyo3::prelude::*;
102/// use pyo3::py_run;
103/// use numpy::{get_array_module, AllowTypeChange, PyArrayLike1};
104///
105/// #[pyfunction]
106/// fn sum_up<'py>(py: Python<'py>, array: PyArrayLike1<'py, i32, AllowTypeChange>) -> i32 {
107///     array.as_array().sum()
108/// }
109///
110/// Python::with_gil(|py| {
111///     let np = get_array_module(py).unwrap();
112///     let sum_up = wrap_pyfunction!(sum_up)(py).unwrap();
113///
114///     py_run!(py, np sum_up, r"assert sum_up((1.5, 2.5)) == 3");
115/// });
116/// ```
117#[derive(Debug)]
118#[repr(transparent)]
119pub struct PyArrayLike<'py, T, D, C = TypeMustMatch>(PyReadonlyArray<'py, T, D>, PhantomData<C>)
120where
121    T: Element,
122    D: Dimension,
123    C: Coerce;
124
125impl<'py, T, D, C> Deref for PyArrayLike<'py, T, D, C>
126where
127    T: Element,
128    D: Dimension,
129    C: Coerce,
130{
131    type Target = PyReadonlyArray<'py, T, D>;
132
133    fn deref(&self) -> &Self::Target {
134        &self.0
135    }
136}
137
138impl<'py, T, D, C> FromPyObject<'py> for PyArrayLike<'py, T, D, C>
139where
140    T: Element + 'py,
141    D: Dimension + 'py,
142    C: Coerce,
143    Vec<T>: FromPyObject<'py>,
144{
145    fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
146        if let Ok(array) = ob.downcast::<PyArray<T, D>>() {
147            return Ok(Self(array.readonly(), PhantomData));
148        }
149
150        let py = ob.py();
151
152        if matches!(D::NDIM, None | Some(1)) {
153            if let Ok(vec) = ob.extract::<Vec<T>>() {
154                let array = Array1::from(vec)
155                    .into_dimensionality()
156                    .expect("D being compatible to Ix1")
157                    .into_pyarray(py)
158                    .readonly();
159                return Ok(Self(array, PhantomData));
160            }
161        }
162
163        static AS_ARRAY: GILOnceCell<Py<PyAny>> = GILOnceCell::new();
164
165        let as_array = AS_ARRAY
166            .get_or_try_init(py, || {
167                get_array_module(py)?.getattr("asarray").map(Into::into)
168            })?
169            .bind(py);
170
171        let kwargs = if C::VAL {
172            let kwargs = PyDict::new(py);
173            kwargs.set_item(intern!(py, "dtype"), T::get_dtype(py))?;
174            Some(kwargs)
175        } else {
176            None
177        };
178
179        let array = as_array.call((ob,), kwargs.as_ref())?.extract()?;
180        Ok(Self(array, PhantomData))
181    }
182}
183
184/// Receiver for zero-dimensional arrays or array-like types.
185pub type PyArrayLike0<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix0, C>;
186
187/// Receiver for one-dimensional arrays or array-like types.
188pub type PyArrayLike1<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix1, C>;
189
190/// Receiver for two-dimensional arrays or array-like types.
191pub type PyArrayLike2<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix2, C>;
192
193/// Receiver for three-dimensional arrays or array-like types.
194pub type PyArrayLike3<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix3, C>;
195
196/// Receiver for four-dimensional arrays or array-like types.
197pub type PyArrayLike4<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix4, C>;
198
199/// Receiver for five-dimensional arrays or array-like types.
200pub type PyArrayLike5<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix5, C>;
201
202/// Receiver for six-dimensional arrays or array-like types.
203pub type PyArrayLike6<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix6, C>;
204
205/// Receiver for arrays or array-like types whose dimensionality is determined at runtime.
206pub type PyArrayLikeDyn<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, IxDyn, C>;