numpy/
array_like.rs

1use std::marker::PhantomData;
2use std::ops::Deref;
3
4use ndarray::{Array1, Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
5use pyo3::{types::PyAnyMethods, Borrowed, FromPyObject, PyAny, PyErr, PyResult};
6
7use crate::npyffi::NPY_ARRAY_FORCECAST;
8use crate::{array::PyArrayMethods, PY_ARRAY_API};
9use crate::{Element, IntoPyArray, PyArray, PyReadonlyArray, PyUntypedArray};
10
11pub trait Coerce: Sealed {
12    const ALLOW_TYPE_CHANGE: bool;
13}
14
use self::sealed::Sealed;

/// Private home of the supertrait that seals [`Coerce`](self::Coerce).
///
/// `Sealed` itself is `pub` so that it may appear in the public bound `Coerce: Sealed`,
/// but because this module is private, code outside the enclosing module cannot name it
/// and therefore cannot implement `Coerce` for its own types.
mod sealed {
    /// Marker supertrait; implementable only where this private module is visible.
    pub trait Sealed {}
}
20
21/// Marker type to indicate that the element type received via [`PyArrayLike`] must match the specified type exactly.
22#[derive(Debug)]
23pub struct TypeMustMatch;
24
25impl Sealed for TypeMustMatch {}
26
27impl Coerce for TypeMustMatch {
28    const ALLOW_TYPE_CHANGE: bool = false;
29}
30
31/// Marker type to indicate that the element type received via [`PyArrayLike`] can be cast to the specified type by NumPy's [`asarray`](https://numpy.org/doc/stable/reference/generated/numpy.asarray.html).
32#[derive(Debug)]
33pub struct AllowTypeChange;
34
35impl Sealed for AllowTypeChange {}
36
37impl Coerce for AllowTypeChange {
38    const ALLOW_TYPE_CHANGE: bool = true;
39}
40
41/// Receiver for arrays or array-like types.
42///
43/// When building API using NumPy in Python, it is common for functions to additionally accept any array-like type such as `list[float]` as arguments.
44/// `PyArrayLike` enables the same pattern in Rust extensions, i.e. by taking this type as the argument of a `#[pyfunction]`,
45/// one will always get access to a [`PyReadonlyArray`] that will either reference to the NumPy array originally passed into the function
46/// or a temporary one created by converting the input type into a NumPy array.
47///
48/// Depending on whether [`TypeMustMatch`] or [`AllowTypeChange`] is used for the `C` type parameter,
49/// the element type must either match the specific type `T` exactly or will be cast to it by NumPy's [`asarray`](https://numpy.org/doc/stable/reference/generated/numpy.asarray.html).
50///
51/// # Example
52///
53/// `PyArrayLike1<'py, T, TypeMustMatch>` will enable you to receive both NumPy arrays and sequences
54///
55/// ```rust
56/// # use pyo3::prelude::*;
57/// use pyo3::py_run;
58/// use numpy::{get_array_module, PyArrayLike1, TypeMustMatch};
59///
60/// #[pyfunction]
61/// fn sum_up<'py>(py: Python<'py>, array: PyArrayLike1<'py, f64, TypeMustMatch>) -> f64 {
62///     array.as_array().sum()
63/// }
64///
65/// Python::attach(|py| {
66///     let np = get_array_module(py).unwrap();
67///     let sum_up = wrap_pyfunction!(sum_up)(py).unwrap();
68///
69///     py_run!(py, np sum_up, r"assert sum_up(np.array([1., 2., 3.])) == 6.");
70///     py_run!(py, np sum_up, r"assert sum_up((1., 2., 3.)) == 6.");
71/// });
72/// ```
73///
74/// but it will not cast the element type if that is required
75///
76/// ```rust,should_panic
77/// use pyo3::prelude::*;
78/// use pyo3::py_run;
79/// use numpy::{get_array_module, PyArrayLike1, TypeMustMatch};
80///
81/// #[pyfunction]
82/// fn sum_up<'py>(py: Python<'py>, array: PyArrayLike1<'py, i32, TypeMustMatch>) -> i32 {
83///     array.as_array().sum()
84/// }
85///
86/// Python::attach(|py| {
87///     let np = get_array_module(py).unwrap();
88///     let sum_up = wrap_pyfunction!(sum_up)(py).unwrap();
89///
90///     py_run!(py, np sum_up, r"assert sum_up((1., 2., 3.)) == 6");
91/// });
92/// ```
93///
94/// whereas `PyArrayLike1<'py, T, AllowTypeChange>` will do even at the cost loosing precision
95///
96/// ```rust
97/// use pyo3::prelude::*;
98/// use pyo3::py_run;
99/// use numpy::{get_array_module, AllowTypeChange, PyArrayLike1};
100///
101/// #[pyfunction]
102/// fn sum_up<'py>(py: Python<'py>, array: PyArrayLike1<'py, i32, AllowTypeChange>) -> i32 {
103///     array.as_array().sum()
104/// }
105///
106/// Python::attach(|py| {
107///     let np = get_array_module(py).unwrap();
108///     let sum_up = wrap_pyfunction!(sum_up)(py).unwrap();
109///
110///     py_run!(py, np sum_up, r"assert sum_up((1.5, 2.5)) == 3");
111/// });
112/// ```
113#[derive(Debug)]
114#[repr(transparent)]
115pub struct PyArrayLike<'py, T, D, C = TypeMustMatch>(PyReadonlyArray<'py, T, D>, PhantomData<C>)
116where
117    T: Element,
118    D: Dimension,
119    C: Coerce;
120
121impl<'py, T, D, C> Deref for PyArrayLike<'py, T, D, C>
122where
123    T: Element,
124    D: Dimension,
125    C: Coerce,
126{
127    type Target = PyReadonlyArray<'py, T, D>;
128
129    fn deref(&self) -> &Self::Target {
130        &self.0
131    }
132}
133
134impl<'a, 'py, T, D, C> FromPyObject<'a, 'py> for PyArrayLike<'py, T, D, C>
135where
136    T: Element + 'py,
137    D: Dimension + 'py,
138    C: Coerce,
139    Vec<T>: FromPyObject<'a, 'py>,
140{
141    type Error = PyErr;
142
143    fn extract(ob: Borrowed<'a, 'py, PyAny>) -> PyResult<Self> {
144        if let Ok(array) = ob.cast::<PyArray<T, D>>() {
145            return Ok(Self(array.readonly(), PhantomData));
146        }
147
148        let py = ob.py();
149
150        // If the input is already an ndarray and `TypeMustMatch` is used then no type conversion
151        // should be performed.
152        if (C::ALLOW_TYPE_CHANGE || ob.cast::<PyUntypedArray>().is_err())
153            && matches!(D::NDIM, None | Some(1))
154        {
155            if let Ok(vec) = ob.extract::<Vec<T>>() {
156                let array = Array1::from(vec)
157                    .into_dimensionality()
158                    .expect("D being compatible to Ix1")
159                    .into_pyarray(py)
160                    .readonly();
161                return Ok(Self(array, PhantomData));
162            }
163        }
164
165        let (dtype, flags) = if C::ALLOW_TYPE_CHANGE {
166            (Some(T::get_dtype(py)), NPY_ARRAY_FORCECAST)
167        } else {
168            (None, 0)
169        };
170
171        let newtype = dtype
172            .map(|dt| dt.into_ptr().cast())
173            .unwrap_or_else(std::ptr::null_mut);
174
175        let array = unsafe {
176            let ptr = PY_ARRAY_API.PyArray_FromAny(
177                py,
178                ob.as_ptr(),
179                newtype,
180                0,
181                0,
182                flags,
183                std::ptr::null_mut(),
184            );
185
186            pyo3::Bound::from_owned_ptr_or_err(py, ptr)?
187        };
188
189        Ok(Self(array.extract()?, PhantomData))
190    }
191}
192
193/// Receiver for zero-dimensional arrays or array-like types.
194pub type PyArrayLike0<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix0, C>;
195
196/// Receiver for one-dimensional arrays or array-like types.
197pub type PyArrayLike1<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix1, C>;
198
199/// Receiver for two-dimensional arrays or array-like types.
200pub type PyArrayLike2<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix2, C>;
201
202/// Receiver for three-dimensional arrays or array-like types.
203pub type PyArrayLike3<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix3, C>;
204
205/// Receiver for four-dimensional arrays or array-like types.
206pub type PyArrayLike4<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix4, C>;
207
208/// Receiver for five-dimensional arrays or array-like types.
209pub type PyArrayLike5<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix5, C>;
210
211/// Receiver for six-dimensional arrays or array-like types.
212pub type PyArrayLike6<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, Ix6, C>;
213
214/// Receiver for arrays or array-like types whose dimensionality is determined at runtime.
215pub type PyArrayLikeDyn<'py, T, C = TypeMustMatch> = PyArrayLike<'py, T, IxDyn, C>;