// numpy/untyped_array.rs
//! Safe, untyped interface for NumPy's [N-dimensional arrays][ndarray].
//!
//! [ndarray]: https://numpy.org/doc/stable/reference/arrays.ndarray.html
use std::slice;

use pyo3::{
    ffi, pyobject_native_type_named, types::PyAnyMethods, Bound, PyAny, PyTypeInfo, Python,
};

use crate::array::{PyArray, PyArrayMethods};
use crate::cold;
use crate::dtype::PyArrayDescr;
use crate::npyffi;
14
15/// A safe, untyped wrapper for NumPy's [`ndarray`] class.
16///
17/// Unlike [`PyArray<T,D>`][crate::PyArray], this type does not constrain either element type `T` nor the dimensionality `D`.
18/// This can be useful to inspect function arguments, but it prevents operating on the elements without further downcasts.
19///
20/// When both element type `T` and index type `D` are known, these values can be downcast to `PyArray<T, D>`. In addition,
21/// `PyArray<T, D>` can be dereferenced to a `PyUntypedArray` and can therefore automatically access its methods.
22///
23/// # Example
24///
25/// Taking `PyUntypedArray` can be helpful to implement polymorphic entry points:
26///
27/// ```
28/// # use pyo3::prelude::*;
29/// use pyo3::exceptions::PyTypeError;
30/// use numpy::{Element, PyUntypedArray, PyArray1, dtype};
31/// use numpy::{PyUntypedArrayMethods, PyArrayMethods, PyArrayDescrMethods};
32///
33/// #[pyfunction]
34/// fn entry_point(py: Python<'_>, array: &Bound<'_, PyUntypedArray>) -> PyResult<()> {
35///     fn implementation<T: Element>(array: &Bound<'_, PyArray1<T>>) -> PyResult<()> {
36///         /* .. */
37///
38///         Ok(())
39///     }
40///
41///     let element_type = array.dtype();
42///
43///     if element_type.is_equiv_to(&dtype::<f32>(py)) {
44///         let array = array.downcast::<PyArray1<f32>>()?;
45///
46///         implementation(array)
47///     } else if element_type.is_equiv_to(&dtype::<f64>(py)) {
48///         let array = array.downcast::<PyArray1<f64>>()?;
49///
50///         implementation(array)
51///     } else {
52///         Err(PyTypeError::new_err(format!("Unsupported element type: {}", element_type)))
53///     }
54/// }
55/// #
56/// # Python::with_gil(|py| {
57/// #   let array = PyArray1::<f64>::zeros(py, 42, false);
58/// #   entry_point(py, array.as_untyped())
59/// # }).unwrap();
60/// ```
61#[repr(transparent)]
62pub struct PyUntypedArray(PyAny);
63
64unsafe impl PyTypeInfo for PyUntypedArray {
65    const NAME: &'static str = "PyUntypedArray";
66    const MODULE: Option<&'static str> = Some("numpy");
67
68    fn type_object_raw<'py>(py: Python<'py>) -> *mut ffi::PyTypeObject {
69        unsafe { npyffi::PY_ARRAY_API.get_type_object(py, npyffi::NpyTypes::PyArray_Type) }
70    }
71
72    fn is_type_of(ob: &Bound<'_, PyAny>) -> bool {
73        unsafe { npyffi::PyArray_Check(ob.py(), ob.as_ptr()) != 0 }
74    }
75}
76
77pyobject_native_type_named!(PyUntypedArray);
78
79/// Implementation of functionality for [`PyUntypedArray`].
80#[doc(alias = "PyUntypedArray")]
81pub trait PyUntypedArrayMethods<'py>: Sealed {
82    /// Returns a raw pointer to the underlying [`PyArrayObject`][npyffi::PyArrayObject].
83    fn as_array_ptr(&self) -> *mut npyffi::PyArrayObject;
84
85    /// Returns the `dtype` of the array.
86    ///
87    /// See also [`ndarray.dtype`][ndarray-dtype] and [`PyArray_DTYPE`][PyArray_DTYPE].
88    ///
89    /// # Example
90    ///
91    /// ```
92    /// use numpy::prelude::*;
93    /// use numpy::{dtype, PyArray};
94    /// use pyo3::Python;
95    ///
96    /// Python::with_gil(|py| {
97    ///    let array = PyArray::from_vec(py, vec![1_i32, 2, 3]);
98    ///
99    ///    assert!(array.dtype().is_equiv_to(&dtype::<i32>(py)));
100    /// });
101    /// ```
102    ///
103    /// [ndarray-dtype]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.dtype.html
104    /// [PyArray_DTYPE]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_DTYPE
105    fn dtype(&self) -> Bound<'py, PyArrayDescr>;
106
107    /// Returns `true` if the internal data of the array is contiguous,
108    /// indepedently of whether C-style/row-major or Fortran-style/column-major.
109    ///
110    /// # Example
111    ///
112    /// ```
113    /// use numpy::{PyArray1, PyUntypedArrayMethods};
114    /// use pyo3::{types::{IntoPyDict, PyAnyMethods}, Python, ffi::c_str};
115    ///
116    /// # fn main() -> pyo3::PyResult<()> {
117    /// Python::with_gil(|py| {
118    ///     let array = PyArray1::arange(py, 0, 10, 1);
119    ///     assert!(array.is_contiguous());
120    ///
121    ///     let view = py
122    ///         .eval(c_str!("array[::2]"), None, Some(&[("array", array)].into_py_dict(py)?))?
123    ///         .downcast_into::<PyArray1<i32>>()?;
124    ///     assert!(!view.is_contiguous());
125    /// #   Ok(())
126    /// })
127    /// # }
128    /// ```
129    fn is_contiguous(&self) -> bool {
130        unsafe {
131            check_flags(
132                &*self.as_array_ptr(),
133                npyffi::NPY_ARRAY_C_CONTIGUOUS | npyffi::NPY_ARRAY_F_CONTIGUOUS,
134            )
135        }
136    }
137
138    /// Returns `true` if the internal data of the array is Fortran-style/column-major contiguous.
139    fn is_fortran_contiguous(&self) -> bool {
140        unsafe { check_flags(&*self.as_array_ptr(), npyffi::NPY_ARRAY_F_CONTIGUOUS) }
141    }
142
143    /// Returns `true` if the internal data of the array is C-style/row-major contiguous.
144    fn is_c_contiguous(&self) -> bool {
145        unsafe { check_flags(&*self.as_array_ptr(), npyffi::NPY_ARRAY_C_CONTIGUOUS) }
146    }
147
148    /// Returns the number of dimensions of the array.
149    ///
150    /// See also [`ndarray.ndim`][ndarray-ndim] and [`PyArray_NDIM`][PyArray_NDIM].
151    ///
152    /// # Example
153    ///
154    /// ```
155    /// use numpy::{PyArray3, PyUntypedArrayMethods};
156    /// use pyo3::Python;
157    ///
158    /// Python::with_gil(|py| {
159    ///     let arr = PyArray3::<f64>::zeros(py, [4, 5, 6], false);
160    ///
161    ///     assert_eq!(arr.ndim(), 3);
162    /// });
163    /// ```
164    ///
165    /// [ndarray-ndim]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.ndim.html
166    /// [PyArray_NDIM]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_NDIM
167    #[inline]
168    fn ndim(&self) -> usize {
169        unsafe { (*self.as_array_ptr()).nd as usize }
170    }
171
172    /// Returns a slice indicating how many bytes to advance when iterating along each axis.
173    ///
174    /// See also [`ndarray.strides`][ndarray-strides] and [`PyArray_STRIDES`][PyArray_STRIDES].
175    ///
176    /// # Example
177    ///
178    /// ```
179    /// use numpy::{PyArray3, PyUntypedArrayMethods};
180    /// use pyo3::Python;
181    ///
182    /// Python::with_gil(|py| {
183    ///     let arr = PyArray3::<f64>::zeros(py, [4, 5, 6], false);
184    ///
185    ///     assert_eq!(arr.strides(), &[240, 48, 8]);
186    /// });
187    /// ```
188    /// [ndarray-strides]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.strides.html
189    /// [PyArray_STRIDES]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_STRIDES
190    #[inline]
191    fn strides(&self) -> &[isize] {
192        let n = self.ndim();
193        if n == 0 {
194            cold();
195            return &[];
196        }
197        let ptr = self.as_array_ptr();
198        unsafe {
199            let p = (*ptr).strides;
200            slice::from_raw_parts(p, n)
201        }
202    }
203
204    /// Returns a slice which contains dimmensions of the array.
205    ///
206    /// See also [`ndarray.shape`][ndaray-shape] and [`PyArray_DIMS`][PyArray_DIMS].
207    ///
208    /// # Example
209    ///
210    /// ```
211    /// use numpy::{PyArray3, PyUntypedArrayMethods};
212    /// use pyo3::Python;
213    ///
214    /// Python::with_gil(|py| {
215    ///     let arr = PyArray3::<f64>::zeros(py, [4, 5, 6], false);
216    ///
217    ///     assert_eq!(arr.shape(), &[4, 5, 6]);
218    /// });
219    /// ```
220    ///
221    /// [ndarray-shape]: https://numpy.org/doc/stable/reference/generated/numpy.ndarray.shape.html
222    /// [PyArray_DIMS]: https://numpy.org/doc/stable/reference/c-api/array.html#c.PyArray_DIMS
223    #[inline]
224    fn shape(&self) -> &[usize] {
225        let n = self.ndim();
226        if n == 0 {
227            cold();
228            return &[];
229        }
230        let ptr = self.as_array_ptr();
231        unsafe {
232            let p = (*ptr).dimensions as *mut usize;
233            slice::from_raw_parts(p, n)
234        }
235    }
236
237    /// Calculates the total number of elements in the array.
238    fn len(&self) -> usize {
239        self.shape().iter().product()
240    }
241
242    /// Returns `true` if the there are no elements in the array.
243    fn is_empty(&self) -> bool {
244        self.shape().contains(&0)
245    }
246}
247
use self::sealed::Sealed;

/// Private home of the `Sealed` marker trait.
///
/// This module is not public, so code outside the crate cannot name `Sealed`. As a result it
/// cannot implement `PyUntypedArrayMethods`, which requires `Sealed` as a supertrait.
mod sealed {
    /// Marker supertrait that restricts who may implement `PyUntypedArrayMethods`.
    pub trait Sealed {}
}
253
254fn check_flags(obj: &npyffi::PyArrayObject, flags: i32) -> bool {
255    obj.flags & flags != 0
256}
257
258impl<'py> PyUntypedArrayMethods<'py> for Bound<'py, PyUntypedArray> {
259    #[inline]
260    fn as_array_ptr(&self) -> *mut npyffi::PyArrayObject {
261        self.as_ptr().cast()
262    }
263
264    fn dtype(&self) -> Bound<'py, PyArrayDescr> {
265        unsafe {
266            let descr_ptr = (*self.as_array_ptr()).descr;
267            Bound::from_borrowed_ptr(self.py(), descr_ptr.cast()).downcast_into_unchecked()
268        }
269    }
270}
271
272impl Sealed for Bound<'_, PyUntypedArray> {}
273
274// We won't be able to provide a `Deref` impl from `Bound<'_, PyArray<T, D>>` to
275// `Bound<'_, PyUntypedArray>`, so this seems to be the next best thing to do
276impl<'py, T, D> PyUntypedArrayMethods<'py> for Bound<'py, PyArray<T, D>> {
277    #[inline]
278    fn as_array_ptr(&self) -> *mut npyffi::PyArrayObject {
279        self.as_untyped().as_array_ptr()
280    }
281
282    #[inline]
283    fn dtype(&self) -> Bound<'py, PyArrayDescr> {
284        self.as_untyped().dtype()
285    }
286}
287
288impl<T, D> Sealed for Bound<'_, PyArray<T, D>> {}