numpy/
dtype.rs

1use std::mem::size_of;
2use std::os::raw::{c_int, c_long, c_longlong, c_short, c_uint, c_ulong, c_ulonglong, c_ushort};
3use std::ptr;
4
5#[cfg(feature = "half")]
6use half::{bf16, f16};
7use num_traits::{Bounded, Zero};
8#[cfg(feature = "half")]
9use pyo3::sync::GILOnceCell;
10use pyo3::{
11    conversion::IntoPyObject,
12    exceptions::{PyIndexError, PyValueError},
13    ffi::{self, PyTuple_Size},
14    pyobject_native_type_named,
15    types::{PyAnyMethods, PyDict, PyDictMethods, PyTuple, PyType},
16    Borrowed, Bound, Py, PyAny, PyObject, PyResult, PyTypeInfo, Python,
17};
18
19use crate::npyffi::{
20    NpyTypes, PyArray_Descr, PyDataType_ALIGNMENT, PyDataType_ELSIZE, PyDataType_FIELDS,
21    PyDataType_FLAGS, PyDataType_NAMES, PyDataType_SUBARRAY, NPY_ALIGNED_STRUCT,
22    NPY_BYTEORDER_CHAR, NPY_ITEM_HASOBJECT, NPY_TYPES, PY_ARRAY_API,
23};
24
25pub use num_complex::{Complex32, Complex64};
26
/// Binding of [`numpy.dtype`][dtype].
///
/// Descriptors are handled as `Bound<'py, PyArrayDescr>`; the descriptor-specific
/// accessors (item size, byte order, fields, sub-array shape, ...) are provided by the
/// [`PyArrayDescrMethods`] trait.
///
/// # Example
///
/// ```
/// use numpy::{dtype, get_array_module, PyArrayDescr, PyArrayDescrMethods};
/// use numpy::pyo3::{types::{IntoPyDict, PyAnyMethods}, Python, ffi::c_str};
///
/// # fn main() -> pyo3::PyResult<()> {
/// Python::with_gil(|py| {
///     let locals = [("np", get_array_module(py)?)].into_py_dict(py)?;
///
///     let dt = py
///         .eval(c_str!("np.array([1, 2, 3.0]).dtype"), Some(&locals), None)?
///         .downcast_into::<PyArrayDescr>()?;
///
///     assert!(dt.is_equiv_to(&dtype::<f64>(py)));
/// #   Ok(())
/// })
/// # }
/// ```
///
/// [dtype]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.html
// Transparent newtype over `PyAny`, so the wrapper has exactly the layout of `PyAny`.
#[repr(transparent)]
pub struct PyArrayDescr(PyAny);

// PyO3 helper macro generating the standard glue for a named native Python type.
// NOTE(review): the exact set of generated impls depends on the PyO3 version in use.
pyobject_native_type_named!(PyArrayDescr);
54
55unsafe impl PyTypeInfo for PyArrayDescr {
56    const NAME: &'static str = "PyArrayDescr";
57    const MODULE: Option<&'static str> = Some("numpy");
58
59    #[inline]
60    fn type_object_raw<'py>(py: Python<'py>) -> *mut ffi::PyTypeObject {
61        unsafe { PY_ARRAY_API.get_type_object(py, NpyTypes::PyArrayDescr_Type) }
62    }
63}
64
65/// Returns the type descriptor ("dtype") for a registered type.
66#[inline]
67pub fn dtype<'py, T: Element>(py: Python<'py>) -> Bound<'py, PyArrayDescr> {
68    T::get_dtype(py)
69}
70
71impl PyArrayDescr {
72    /// Creates a new type descriptor ("dtype") object from an arbitrary object.
73    ///
74    /// Equivalent to invoking the constructor of [`numpy.dtype`][dtype].
75    ///
76    /// [dtype]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.html
77    #[inline]
78    pub fn new<'a, 'py, T>(py: Python<'py>, ob: T) -> PyResult<Bound<'py, Self>>
79    where
80        T: IntoPyObject<'py>,
81    {
82        fn inner<'py>(
83            py: Python<'py>,
84            obj: Borrowed<'_, 'py, PyAny>,
85        ) -> PyResult<Bound<'py, PyArrayDescr>> {
86            let mut descr: *mut PyArray_Descr = ptr::null_mut();
87            unsafe {
88                // None is an invalid input here and is not converted to NPY_DEFAULT_TYPE
89                PY_ARRAY_API.PyArray_DescrConverter2(py, obj.as_ptr(), &mut descr);
90                Bound::from_owned_ptr_or_err(py, descr.cast())
91                    .map(|any| any.downcast_into_unchecked())
92            }
93        }
94
95        inner(
96            py,
97            ob.into_pyobject(py)
98                .map_err(Into::into)?
99                .into_any()
100                .as_borrowed(),
101        )
102    }
103
104    /// Shortcut for creating a type descriptor of `object` type.
105    #[inline]
106    pub fn object(py: Python<'_>) -> Bound<'_, Self> {
107        Self::from_npy_type(py, NPY_TYPES::NPY_OBJECT)
108    }
109
110    /// Returns the type descriptor for a registered type.
111    #[inline]
112    pub fn of<'py, T: Element>(py: Python<'py>) -> Bound<'py, Self> {
113        T::get_dtype(py)
114    }
115
116    fn from_npy_type<'py>(py: Python<'py>, npy_type: NPY_TYPES) -> Bound<'py, Self> {
117        unsafe {
118            let descr = PY_ARRAY_API.PyArray_DescrFromType(py, npy_type as _);
119            Bound::from_owned_ptr(py, descr.cast()).downcast_into_unchecked()
120        }
121    }
122
123    pub(crate) fn new_from_npy_type<'py>(py: Python<'py>, npy_type: NPY_TYPES) -> Bound<'py, Self> {
124        unsafe {
125            let descr = PY_ARRAY_API.PyArray_DescrNewFromType(py, npy_type as _);
126            Bound::from_owned_ptr(py, descr.cast()).downcast_into_unchecked()
127        }
128    }
129}
130
131/// Implementation of functionality for [`PyArrayDescr`].
132#[doc(alias = "PyArrayDescr")]
133pub trait PyArrayDescrMethods<'py>: Sealed {
134    /// Returns `self` as `*mut PyArray_Descr`.
135    fn as_dtype_ptr(&self) -> *mut PyArray_Descr;
136
137    /// Returns `self` as `*mut PyArray_Descr` while increasing the reference count.
138    ///
139    /// Useful in cases where the descriptor is stolen by the API.
140    fn into_dtype_ptr(self) -> *mut PyArray_Descr;
141
142    /// Returns true if two type descriptors are equivalent.
143    fn is_equiv_to(&self, other: &Self) -> bool;
144
145    /// Returns the [array scalar][arrays-scalars] corresponding to this type descriptor.
146    ///
147    /// Equivalent to [`numpy.dtype.type`][dtype-type].
148    ///
149    /// [arrays-scalars]: https://numpy.org/doc/stable/reference/arrays.scalars.html
150    /// [dtype-type]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.type.html
151    fn typeobj(&self) -> Bound<'py, PyType>;
152
153    /// Returns a unique number for each of the 21 different built-in
154    /// [enumerated types][enumerated-types].
155    ///
156    /// These are roughly ordered from least-to-most precision.
157    ///
158    /// Equivalent to [`numpy.dtype.num`][dtype-num].
159    ///
160    /// [enumerated-types]: https://numpy.org/doc/stable/reference/c-api/dtype.html#enumerated-types
161    /// [dtype-num]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.num.html
162    fn num(&self) -> c_int {
163        unsafe { &*self.as_dtype_ptr() }.type_num
164    }
165
166    /// Returns the element size of this type descriptor.
167    ///
168    /// Equivalent to [`numpy.dtype.itemsize`][dtype-itemsize].
169    ///
170    /// [dtype-itemsiize]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.itemsize.html
171    fn itemsize(&self) -> usize;
172
173    /// Returns the required alignment (bytes) of this type descriptor according to the compiler.
174    ///
175    /// Equivalent to [`numpy.dtype.alignment`][dtype-alignment].
176    ///
177    /// [dtype-alignment]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.alignment.html
178    fn alignment(&self) -> usize;
179
180    /// Returns an ASCII character indicating the byte-order of this type descriptor object.
181    ///
182    /// All built-in data-type objects have byteorder either `=` or `|`.
183    ///
184    /// Equivalent to [`numpy.dtype.byteorder`][dtype-byteorder].
185    ///
186    /// [dtype-byteorder]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.byteorder.html
187    fn byteorder(&self) -> u8 {
188        unsafe { &*self.as_dtype_ptr() }.byteorder.max(0) as _
189    }
190
191    /// Returns a unique ASCII character for each of the 21 different built-in types.
192    ///
193    /// Note that structured data types are categorized as `V` (void).
194    ///
195    /// Equivalent to [`numpy.dtype.char`][dtype-char].
196    ///
197    /// [dtype-char]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.char.html
198    fn char(&self) -> u8 {
199        unsafe { &*self.as_dtype_ptr() }.type_.max(0) as _
200    }
201
202    /// Returns an ASCII character (one of `biufcmMOSUV`) identifying the general kind of data.
203    ///
204    /// Note that structured data types are categorized as `V` (void).
205    ///
206    /// Equivalent to [`numpy.dtype.kind`][dtype-kind].
207    ///
208    /// [dtype-kind]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.kind.html
209    fn kind(&self) -> u8 {
210        unsafe { &*self.as_dtype_ptr() }.kind.max(0) as _
211    }
212
213    /// Returns bit-flags describing how this type descriptor is to be interpreted.
214    ///
215    /// Equivalent to [`numpy.dtype.flags`][dtype-flags].
216    ///
217    /// [dtype-flags]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.flags.html
218    fn flags(&self) -> u64;
219
220    /// Returns the number of dimensions if this type descriptor represents a sub-array, and zero otherwise.
221    ///
222    /// Equivalent to [`numpy.dtype.ndim`][dtype-ndim].
223    ///
224    /// [dtype-ndim]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.ndim.html
225    fn ndim(&self) -> usize;
226
227    /// Returns the type descriptor for the base element of subarrays, regardless of their dimension or shape.
228    ///
229    /// If the dtype is not a subarray, returns self.
230    ///
231    /// Equivalent to [`numpy.dtype.base`][dtype-base].
232    ///
233    /// [dtype-base]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.base.html
234    fn base(&self) -> Bound<'py, PyArrayDescr>;
235
236    /// Returns the shape of the sub-array.
237    ///
238    /// If the dtype is not a sub-array, an empty vector is returned.
239    ///
240    /// Equivalent to [`numpy.dtype.shape`][dtype-shape].
241    ///
242    /// [dtype-shape]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.shape.html
243    fn shape(&self) -> Vec<usize>;
244
245    /// Returns true if the type descriptor contains any reference-counted objects in any fields or sub-dtypes.
246    ///
247    /// Equivalent to [`numpy.dtype.hasobject`][dtype-hasobject].
248    ///
249    /// [dtype-hasobject]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.hasobject.html
250    fn has_object(&self) -> bool {
251        self.flags() & NPY_ITEM_HASOBJECT != 0
252    }
253
254    /// Returns true if the type descriptor is a struct which maintains field alignment.
255    ///
256    /// This flag is sticky, so when combining multiple structs together, it is preserved
257    /// and produces new dtypes which are also aligned.
258    ///
259    /// Equivalent to [`numpy.dtype.isalignedstruct`][dtype-isalignedstruct].
260    ///
261    /// [dtype-isalignedstruct]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.isalignedstruct.html
262    fn is_aligned_struct(&self) -> bool {
263        self.flags() & NPY_ALIGNED_STRUCT != 0
264    }
265
266    /// Returns true if the type descriptor is a sub-array.
267    ///
268    /// Equivalent to PyDataType_HASSUBARRAY(self).
269    fn has_subarray(&self) -> bool;
270
271    /// Returns true if the type descriptor is a structured type.
272    ///
273    /// Equivalent to PyDataType_HASFIELDS(self).
274    fn has_fields(&self) -> bool;
275
276    /// Returns true if type descriptor byteorder is native, or `None` if not applicable.
277    fn is_native_byteorder(&self) -> Option<bool> {
278        // based on PyArray_ISNBO(self->byteorder)
279        match self.byteorder() {
280            b'=' => Some(true),
281            b'|' => None,
282            byteorder => Some(byteorder == NPY_BYTEORDER_CHAR::NPY_NATBYTE as u8),
283        }
284    }
285
286    /// Returns an ordered list of field names, or `None` if there are no fields.
287    ///
288    /// The names are ordered according to increasing byte offset.
289    ///
290    /// Equivalent to [`numpy.dtype.names`][dtype-names].
291    ///
292    /// [dtype-names]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.names.html
293    fn names(&self) -> Option<Vec<String>>;
294
295    /// Returns the type descriptor and offset of the field with the given name.
296    ///
297    /// This method will return an error if this type descriptor is not structured,
298    /// or if it does not contain a field with a given name.
299    ///
300    /// The list of all names can be found via [`PyArrayDescr::names`].
301    ///
302    /// Equivalent to retrieving a single item from [`numpy.dtype.fields`][dtype-fields].
303    ///
304    /// [dtype-fields]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.fields.html
305    fn get_field(&self, name: &str) -> PyResult<(Bound<'py, PyArrayDescr>, usize)>;
306}
307
// `sealed` is a private module, so `Sealed` cannot be named outside this crate. Because
// `PyArrayDescrMethods` has `Sealed` as a supertrait, downstream crates cannot implement it.
mod sealed {
    // Marker supertrait; implemented in this file only for `Bound<'_, PyArrayDescr>`.
    pub trait Sealed {}
}

use sealed::Sealed;
313
314impl<'py> PyArrayDescrMethods<'py> for Bound<'py, PyArrayDescr> {
315    fn as_dtype_ptr(&self) -> *mut PyArray_Descr {
316        self.as_ptr() as _
317    }
318
319    fn into_dtype_ptr(self) -> *mut PyArray_Descr {
320        self.into_ptr() as _
321    }
322
323    fn is_equiv_to(&self, other: &Self) -> bool {
324        let self_ptr = self.as_dtype_ptr();
325        let other_ptr = other.as_dtype_ptr();
326
327        unsafe {
328            self_ptr == other_ptr
329                || PY_ARRAY_API.PyArray_EquivTypes(self.py(), self_ptr, other_ptr) != 0
330        }
331    }
332
333    fn typeobj(&self) -> Bound<'py, PyType> {
334        let dtype_type_ptr = unsafe { &*self.as_dtype_ptr() }.typeobj;
335        unsafe { PyType::from_borrowed_type_ptr(self.py(), dtype_type_ptr) }
336    }
337
338    fn itemsize(&self) -> usize {
339        unsafe { PyDataType_ELSIZE(self.py(), self.as_dtype_ptr()).max(0) as _ }
340    }
341
342    fn alignment(&self) -> usize {
343        unsafe { PyDataType_ALIGNMENT(self.py(), self.as_dtype_ptr()).max(0) as _ }
344    }
345
346    fn flags(&self) -> u64 {
347        unsafe { PyDataType_FLAGS(self.py(), self.as_dtype_ptr()) as _ }
348    }
349
350    fn ndim(&self) -> usize {
351        let subarray = unsafe { PyDataType_SUBARRAY(self.py(), self.as_dtype_ptr()).as_ref() };
352        match subarray {
353            None => 0,
354            Some(subarray) => unsafe { PyTuple_Size(subarray.shape) }.max(0) as _,
355        }
356    }
357
358    fn base(&self) -> Bound<'py, PyArrayDescr> {
359        let subarray = unsafe { PyDataType_SUBARRAY(self.py(), self.as_dtype_ptr()).as_ref() };
360        match subarray {
361            None => self.clone(),
362            Some(subarray) => unsafe {
363                Bound::from_borrowed_ptr(self.py(), subarray.base.cast()).downcast_into_unchecked()
364            },
365        }
366    }
367
368    fn shape(&self) -> Vec<usize> {
369        let subarray = unsafe { PyDataType_SUBARRAY(self.py(), self.as_dtype_ptr()).as_ref() };
370        match subarray {
371            None => Vec::new(),
372            Some(subarray) => {
373                // NumPy guarantees that shape is a tuple of non-negative integers so this should never panic.
374                let shape = unsafe { Borrowed::from_ptr(self.py(), subarray.shape) };
375                shape.extract().unwrap()
376            }
377        }
378    }
379
380    fn has_subarray(&self) -> bool {
381        unsafe { !PyDataType_SUBARRAY(self.py(), self.as_dtype_ptr()).is_null() }
382    }
383
384    fn has_fields(&self) -> bool {
385        unsafe { !PyDataType_NAMES(self.py(), self.as_dtype_ptr()).is_null() }
386    }
387
388    fn names(&self) -> Option<Vec<String>> {
389        if !self.has_fields() {
390            return None;
391        }
392        let names = unsafe {
393            Borrowed::from_ptr(self.py(), PyDataType_NAMES(self.py(), self.as_dtype_ptr()))
394        };
395        names.extract().ok()
396    }
397
398    fn get_field(&self, name: &str) -> PyResult<(Bound<'py, PyArrayDescr>, usize)> {
399        if !self.has_fields() {
400            return Err(PyValueError::new_err(
401                "cannot get field information: type descriptor has no fields",
402            ));
403        }
404        let dict = unsafe {
405            Borrowed::from_ptr(self.py(), PyDataType_FIELDS(self.py(), self.as_dtype_ptr()))
406        };
407        let dict = unsafe { dict.downcast_unchecked::<PyDict>() };
408        // NumPy guarantees that fields are tuples of proper size and type, so this should never panic.
409        let tuple = dict
410            .get_item(name)?
411            .ok_or_else(|| PyIndexError::new_err(name.to_owned()))?
412            .downcast_into::<PyTuple>()
413            .unwrap();
414        // Note that we cannot just extract the entire tuple since the third element can be a title.
415        let dtype = tuple
416            .get_item(0)
417            .unwrap()
418            .downcast_into::<PyArrayDescr>()
419            .unwrap();
420        let offset = tuple.get_item(1).unwrap().extract().unwrap();
421        Ok((dtype, offset))
422    }
423}
424
425impl Sealed for Bound<'_, PyArrayDescr> {}
426
427/// Represents that a type can be an element of `PyArray`.
428///
429/// Currently, only integer/float/complex/object types are supported. The [NumPy documentation][enumerated-types]
430/// list the other built-in types which we are not yet implemented.
431///
432/// Note that NumPy's integer types like `numpy.int_` and `numpy.uint` are based on C's integer hierarchy
433/// which implies that their widths change depending on the platform's [data model][data-models].
434/// For example, `numpy.int_` matches C's `long` which is 32 bits wide on Windows (using the LLP64 data model)
435/// but 64 bits wide on Linux (using the LP64 data model).
436///
437/// In contrast, Rust's [`isize`] and [`usize`] types are defined to have the same width as a pointer
438/// and are therefore always 64 bits wide on 64-bit platforms. If you want to match NumPy's behaviour,
439/// consider using the [`c_long`][std::ffi::c_long] and [`c_ulong`][std::ffi::c_ulong] type aliases.
440///
441/// # Safety
442///
443/// A type `T` that implements this trait should be safe when managed by a NumPy
444/// array, thus implementing this trait is marked unsafe. Data types that don't
445/// contain Python objects (i.e., either the object type itself or record types
446/// containing object-type fields) are assumed to be trivially copyable, which
447/// is reflected in the `IS_COPY` flag. Furthermore, it is assumed that for
448/// the object type the elements are pointers into the Python heap and that the
449/// corresponding `Clone` implemenation will never panic as it only increases
450/// the reference count.
451///
452/// # Custom element types
453///
454/// Note that we cannot safely store `Py<T>` where `T: PyClass`, because the type information would be
455/// eliminated in the resulting NumPy array.
456/// In other words, objects are always treated as `Py<PyAny>` (a.k.a. `PyObject`) by Python code,
457/// and only `Py<PyAny>` can be stored in a type safe manner.
458///
459/// You can however create [`Array<Py<T>, D>`][ndarray::Array] and turn that into a NumPy array
460/// safely and efficiently using [`from_owned_object_array`][crate::PyArray::from_owned_object_array].
461///
462/// [enumerated-types]: https://numpy.org/doc/stable/reference/c-api/dtype.html#enumerated-types
463/// [data-models]: https://en.wikipedia.org/wiki/64-bit_computing#64-bit_data_models
464pub unsafe trait Element: Sized + Send + Sync {
465    /// Flag that indicates whether this type is trivially copyable.
466    ///
467    /// It should be set to true for all trivially copyable types (like scalar types
468    /// and record/array types only containing trivially copyable fields and elements).
469    ///
470    /// This flag should *always* be set to `false` for object types or record types
471    /// that contain object-type fields.
472    const IS_COPY: bool;
473
474    /// Returns the associated type descriptor ("dtype") for the given element type.
475    fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr>;
476
477    /// Create a clone of the value while the GIL is guaranteed to be held.
478    fn clone_ref(&self, py: Python<'_>) -> Self;
479
480    /// Create an owned copy of the slice while the GIL is guaranteed to be held.
481    ///
482    /// Some types may provide implementations of this method that are more efficient
483    /// than simply mapping the `py_clone` method to each element in the slice.
484    #[inline]
485    fn vec_from_slice(py: Python<'_>, slc: &[Self]) -> Vec<Self> {
486        slc.iter().map(|elem| elem.clone_ref(py)).collect()
487    }
488
489    /// Create an owned copy of the array while the GIL is guaranteed to be held.
490    ///
491    /// Some types may provide implementations of this method that are more efficient
492    /// than simply mapping the `py_clone` method to each element in the view.
493    #[inline]
494    fn array_from_view<D>(
495        py: Python<'_>,
496        view: ::ndarray::ArrayView<'_, Self, D>,
497    ) -> ::ndarray::Array<Self, D>
498    where
499        D: ::ndarray::Dimension,
500    {
501        view.map(|elem| elem.clone_ref(py))
502    }
503}
504
505fn npy_int_type_lookup<T, T0, T1, T2>(npy_types: [NPY_TYPES; 3]) -> NPY_TYPES {
506    // `npy_common.h` defines the integer aliases. In order, it checks:
507    // NPY_BITSOF_LONG, NPY_BITSOF_LONGLONG, NPY_BITSOF_INT, NPY_BITSOF_SHORT, NPY_BITSOF_CHAR
508    // and assigns the alias to the first matching size, so we should check in this order.
509    match size_of::<T>() {
510        x if x == size_of::<T0>() => npy_types[0],
511        x if x == size_of::<T1>() => npy_types[1],
512        x if x == size_of::<T2>() => npy_types[2],
513        _ => panic!("Unable to match integer type descriptor: {:?}", npy_types),
514    }
515}
516
517fn npy_int_type<T: Bounded + Zero + Sized + PartialEq>() -> NPY_TYPES {
518    let is_unsigned = T::min_value() == T::zero();
519    let bit_width = 8 * size_of::<T>();
520
521    match (is_unsigned, bit_width) {
522        (false, 8) => NPY_TYPES::NPY_BYTE,
523        (false, 16) => NPY_TYPES::NPY_SHORT,
524        (false, 32) => npy_int_type_lookup::<i32, c_long, c_int, c_short>([
525            NPY_TYPES::NPY_LONG,
526            NPY_TYPES::NPY_INT,
527            NPY_TYPES::NPY_SHORT,
528        ]),
529        (false, 64) => npy_int_type_lookup::<i64, c_long, c_longlong, c_int>([
530            NPY_TYPES::NPY_LONG,
531            NPY_TYPES::NPY_LONGLONG,
532            NPY_TYPES::NPY_INT,
533        ]),
534        (true, 8) => NPY_TYPES::NPY_UBYTE,
535        (true, 16) => NPY_TYPES::NPY_USHORT,
536        (true, 32) => npy_int_type_lookup::<u32, c_ulong, c_uint, c_ushort>([
537            NPY_TYPES::NPY_ULONG,
538            NPY_TYPES::NPY_UINT,
539            NPY_TYPES::NPY_USHORT,
540        ]),
541        (true, 64) => npy_int_type_lookup::<u64, c_ulong, c_ulonglong, c_uint>([
542            NPY_TYPES::NPY_ULONG,
543            NPY_TYPES::NPY_ULONGLONG,
544            NPY_TYPES::NPY_UINT,
545        ]),
546        _ => unreachable!(),
547    }
548}
549
// Expands to the three cloning methods of `Element` for a type that implements `Clone`.
// The methods use plain `Clone`, slice `to_vec` and ndarray's `to_owned`, ignoring the GIL
// token, instead of the trait's default per-element `clone_ref` mapping.
// Intended to be invoked inside an `unsafe impl Element for ...` block.
macro_rules! clone_methods_impl {
    ($elem:ty) => {
        #[inline]
        fn clone_ref(&self, _py: ::pyo3::Python<'_>) -> $elem {
            <$elem as ::std::clone::Clone>::clone(self)
        }

        #[inline]
        fn vec_from_slice(_py: ::pyo3::Python<'_>, items: &[$elem]) -> Vec<$elem> {
            <[$elem]>::to_vec(items)
        }

        #[inline]
        fn array_from_view<D>(
            _py: ::pyo3::Python<'_>,
            view: ::ndarray::ArrayView<'_, $elem, D>,
        ) -> ::ndarray::Array<$elem, D>
        where
            D: ::ndarray::Dimension,
        {
            ::ndarray::ArrayBase::to_owned(&view)
        }
    };
}
pub(crate) use clone_methods_impl;
577use pyo3::BoundObject;
578
// Implements `Element` (with `IS_COPY = true`) for plain scalar types. Accepted forms:
//   * `@impl: TYPE, DESCR_EXPR, #[attr]...` — internal rule that emits the impl itself.
//   * `TYPE => NPY_VARIANT, #[attr]...`      — uses the fixed `NPY_TYPES` variant.
//   * `TYPE, TYPE, ...`                       — integer types whose variant is chosen by
//                                               `npy_int_type` when `get_dtype` runs.
macro_rules! impl_element_scalar {
    (@impl: $elem:ty, $descr:expr $(,#[$attr:meta])*) => {
        $(#[$attr])*
        unsafe impl Element for $elem {
            const IS_COPY: bool = true;

            fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
                PyArrayDescr::from_npy_type(py, $descr)
            }

            clone_methods_impl!($elem);
        }
    };
    ($elem:ty => $variant:ident $(,#[$attr:meta])*) => {
        impl_element_scalar!(@impl: $elem, NPY_TYPES::$variant $(,#[$attr])*);
    };
    ($($elems:ty),+) => {
        $(impl_element_scalar!(@impl: $elems, npy_int_type::<$elems>());)+
    };
}
599
600impl_element_scalar!(bool => NPY_BOOL);
601
602impl_element_scalar!(i8, i16, i32, i64);
603impl_element_scalar!(u8, u16, u32, u64);
604
605impl_element_scalar!(f32 => NPY_FLOAT);
606impl_element_scalar!(f64 => NPY_DOUBLE);
607
608#[cfg(feature = "half")]
609impl_element_scalar!(f16 => NPY_HALF);
610
#[cfg(feature = "half")]
unsafe impl Element for bf16 {
    const IS_COPY: bool = true;

    fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
        // `bfloat16` is looked up by name, since an external package must provide it. The
        // descriptor is resolved once and cached in a process-wide static.
        static DTYPE: GILOnceCell<Py<PyArrayDescr>> = GILOnceCell::new();

        let cached = DTYPE.get_or_init(py, || {
            let descr = PyArrayDescr::new(py, "bfloat16").expect("A package which provides a `bfloat16` data type for NumPy is required to use the `half::bf16` element type.");
            descr.unbind()
        });
        cached.bind(py).clone()
    }

    clone_methods_impl!(Self);
}
628
629impl_element_scalar!(Complex32 => NPY_CFLOAT,
630    #[doc = "Complex type with `f32` components which maps to `numpy.csingle` (`numpy.complex64`)."]);
631impl_element_scalar!(Complex64 => NPY_CDOUBLE,
632    #[doc = "Complex type with `f64` components which maps to `numpy.cdouble` (`numpy.complex128`)."]);
633
634#[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))]
635impl_element_scalar!(usize, isize);
636
637unsafe impl Element for PyObject {
638    const IS_COPY: bool = false;
639
640    fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
641        PyArrayDescr::object(py)
642    }
643
644    #[inline]
645    fn clone_ref(&self, py: Python<'_>) -> Self {
646        Py::clone_ref(self, py)
647    }
648}
649
#[cfg(test)]
mod tests {
    use super::*;

    use pyo3::types::PyString;
    use pyo3::{py_run, types::PyTypeMethods};

    use crate::npyffi::{is_numpy_2, NPY_NEEDS_PYAPI};

    /// Fetches the `dtype` variable that a `py_run!` snippet assigned into `locals`.
    fn dtype_from_locals<'py>(locals: &Bound<'py, PyDict>) -> Bound<'py, PyArrayDescr> {
        locals
            .get_item("dtype")
            .unwrap()
            .unwrap()
            .downcast_into::<PyArrayDescr>()
            .unwrap()
    }

    #[test]
    fn test_dtype_new() {
        Python::with_gil(|py| {
            let float64 = PyArrayDescr::new(py, "float64").unwrap();
            assert!(float64.is(dtype::<f64>(py)));

            let record = PyArrayDescr::new(py, [("a", "O"), ("b", "?")].as_ref()).unwrap();
            assert_eq!(record.names(), Some(vec!["a".to_owned(), "b".to_owned()]));
            assert!(record.has_object());
            assert!(record.get_field("a").unwrap().0.is(dtype::<PyObject>(py)));
            assert!(record.get_field("b").unwrap().0.is(dtype::<bool>(py)));

            assert!(PyArrayDescr::new(py, 123_usize).is_err());
        });
    }

    #[test]
    fn test_dtype_names() {
        fn type_name<T: Element>(py: Python<'_>) -> Bound<'_, PyString> {
            dtype::<T>(py).typeobj().qualname().unwrap()
        }

        Python::with_gil(|py| {
            // The boolean scalar's qualified name depends on the NumPy major version.
            let bool_name = if is_numpy_2(py) { "bool" } else { "bool_" };
            assert_eq!(type_name::<bool>(py), bool_name);

            assert_eq!(type_name::<i8>(py), "int8");
            assert_eq!(type_name::<i16>(py), "int16");
            assert_eq!(type_name::<i32>(py), "int32");
            assert_eq!(type_name::<i64>(py), "int64");
            assert_eq!(type_name::<u8>(py), "uint8");
            assert_eq!(type_name::<u16>(py), "uint16");
            assert_eq!(type_name::<u32>(py), "uint32");
            assert_eq!(type_name::<u64>(py), "uint64");
            assert_eq!(type_name::<f32>(py), "float32");
            assert_eq!(type_name::<f64>(py), "float64");

            assert_eq!(type_name::<Complex32>(py), "complex64");
            assert_eq!(type_name::<Complex64>(py), "complex128");

            #[cfg(target_pointer_width = "32")]
            {
                assert_eq!(type_name::<usize>(py), "uint32");
                assert_eq!(type_name::<isize>(py), "int32");
            }

            #[cfg(target_pointer_width = "64")]
            {
                assert_eq!(type_name::<usize>(py), "uint64");
                assert_eq!(type_name::<isize>(py), "int64");
            }
        });
    }

    #[test]
    fn test_dtype_methods_scalar() {
        Python::with_gil(|py| {
            let descr = dtype::<f64>(py);

            assert_eq!(descr.num(), NPY_TYPES::NPY_DOUBLE as c_int);
            assert_eq!(descr.flags(), 0);
            assert_eq!(descr.typeobj().qualname().unwrap(), "float64");
            assert_eq!(descr.char(), b'd');
            assert_eq!(descr.kind(), b'f');
            assert_eq!(descr.byteorder(), b'=');
            assert_eq!(descr.is_native_byteorder(), Some(true));
            assert_eq!(descr.itemsize(), 8);
            assert_eq!(descr.alignment(), 8);
            assert!(!descr.has_object());
            assert!(descr.names().is_none());
            assert!(!descr.has_fields());
            assert!(!descr.is_aligned_struct());
            assert!(!descr.has_subarray());
            assert!(descr.base().is_equiv_to(&descr));
            assert_eq!(descr.ndim(), 0);
            assert_eq!(descr.shape(), Vec::<usize>::new());
        });
    }

    #[test]
    fn test_dtype_methods_subarray() {
        Python::with_gil(|py| {
            let locals = PyDict::new(py);
            py_run!(
                py,
                *locals,
                "dtype = __import__('numpy').dtype(('f8', (2, 3)))"
            );
            let descr = dtype_from_locals(&locals);

            assert_eq!(descr.num(), NPY_TYPES::NPY_VOID as c_int);
            assert_eq!(descr.flags(), 0);
            assert_eq!(descr.typeobj().qualname().unwrap(), "void");
            assert_eq!(descr.char(), b'V');
            assert_eq!(descr.kind(), b'V');
            assert_eq!(descr.byteorder(), b'|');
            assert_eq!(descr.is_native_byteorder(), None);
            assert_eq!(descr.itemsize(), 48);
            assert_eq!(descr.alignment(), 8);
            assert!(!descr.has_object());
            assert!(descr.names().is_none());
            assert!(!descr.has_fields());
            assert!(!descr.is_aligned_struct());
            assert!(descr.has_subarray());
            assert_eq!(descr.ndim(), 2);
            assert_eq!(descr.shape(), vec![2, 3]);
            assert!(descr.base().is_equiv_to(&dtype::<f64>(py)));
        });
    }

    #[test]
    fn test_dtype_methods_record() {
        Python::with_gil(|py| {
            let locals = PyDict::new(py);
            py_run!(
                py,
                *locals,
                "dtype = __import__('numpy').dtype([('x', 'u1'), ('y', 'f8'), ('z', 'O')], align=True)"
            );
            let descr = dtype_from_locals(&locals);

            assert_eq!(descr.num(), NPY_TYPES::NPY_VOID as c_int);
            assert_ne!(descr.flags() & NPY_ITEM_HASOBJECT, 0);
            assert_ne!(descr.flags() & NPY_NEEDS_PYAPI, 0);
            assert_ne!(descr.flags() & NPY_ALIGNED_STRUCT, 0);
            assert_eq!(descr.typeobj().qualname().unwrap(), "void");
            assert_eq!(descr.char(), b'V');
            assert_eq!(descr.kind(), b'V');
            assert_eq!(descr.byteorder(), b'|');
            assert_eq!(descr.is_native_byteorder(), None);
            assert_eq!(descr.itemsize(), 24);
            assert_eq!(descr.alignment(), 8);
            assert!(descr.has_object());
            assert_eq!(
                descr.names(),
                Some(vec!["x".to_owned(), "y".to_owned(), "z".to_owned()])
            );
            assert!(descr.has_fields());
            assert!(descr.is_aligned_struct());
            assert!(!descr.has_subarray());
            assert_eq!(descr.ndim(), 0);
            assert_eq!(descr.shape(), Vec::<usize>::new());
            assert!(descr.base().is_equiv_to(&descr));

            let (x_descr, x_offset) = descr.get_field("x").unwrap();
            assert!(x_descr.is_equiv_to(&dtype::<u8>(py)));
            assert_eq!(x_offset, 0);
            let (y_descr, y_offset) = descr.get_field("y").unwrap();
            assert!(y_descr.is_equiv_to(&dtype::<f64>(py)));
            assert_eq!(y_offset, 8);
            let (z_descr, z_offset) = descr.get_field("z").unwrap();
            assert!(z_descr.is_equiv_to(&dtype::<PyObject>(py)));
            assert_eq!(z_offset, 16);
        });
    }
}