// numpy/dtype.rs

1use std::mem::size_of;
2use std::os::raw::{c_int, c_long, c_longlong, c_short, c_uint, c_ulong, c_ulonglong, c_ushort};
3use std::ptr;
4
5#[cfg(feature = "half")]
6use half::{bf16, f16};
7use num_traits::{Bounded, Zero};
8#[cfg(feature = "half")]
9use pyo3::sync::PyOnceLock;
10use pyo3::{
11    conversion::IntoPyObject,
12    exceptions::{PyIndexError, PyValueError},
13    ffi::{self, PyTuple_Size},
14    pyobject_native_type_named,
15    types::{PyAnyMethods, PyDict, PyDictMethods, PyTuple, PyType},
16    Borrowed, Bound, Py, PyAny, PyResult, PyTypeInfo, Python,
17};
18
19use crate::npyffi::{
20    NpyTypes, PyArray_Descr, PyDataType_ALIGNMENT, PyDataType_ELSIZE, PyDataType_FIELDS,
21    PyDataType_FLAGS, PyDataType_NAMES, PyDataType_SUBARRAY, NPY_ALIGNED_STRUCT,
22    NPY_BYTEORDER_CHAR, NPY_ITEM_HASOBJECT, NPY_TYPES, PY_ARRAY_API,
23};
24
25pub use num_complex::{Complex32, Complex64};
26
/// Binding of [`numpy.dtype`][dtype].
///
/// Descriptors are handled through pyo3 smart pointers such as [`Bound`]; the accessor
/// methods are provided by the [`PyArrayDescrMethods`] trait, which is implemented for
/// `Bound<'py, PyArrayDescr>`.
///
/// # Example
///
/// ```
/// use numpy::{dtype, get_array_module, PyArrayDescr, PyArrayDescrMethods};
/// use numpy::pyo3::{types::{IntoPyDict, PyAnyMethods}, Python, ffi::c_str};
///
/// # fn main() -> pyo3::PyResult<()> {
/// Python::attach(|py| {
///     let locals = [("np", get_array_module(py)?)].into_py_dict(py)?;
///
///     let dt = py
///         .eval(c_str!("np.array([1, 2, 3.0]).dtype"), Some(&locals), None)?
///         .cast_into::<PyArrayDescr>()?;
///
///     assert!(dt.is_equiv_to(&dtype::<f64>(py)));
/// #   Ok(())
/// })
/// # }
/// ```
///
/// [dtype]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.html
// `repr(transparent)`: the wrapper has exactly the same layout as the `PyAny` it wraps.
#[repr(transparent)]
pub struct PyArrayDescr(PyAny);

// pyo3 helper macro that generates the standard boilerplate for a named native type.
// NOTE(review): the exact set of generated impls is defined by pyo3 — confirm there.
pyobject_native_type_named!(PyArrayDescr);
54
// SAFETY: `type_object_raw` returns NumPy's `PyArrayDescr_Type`, looked up through the NumPy
// C-API table `PY_ARRAY_API`. That is the type object of `numpy.dtype`, whose instances are
// exactly what `PyArrayDescr` wraps.
unsafe impl PyTypeInfo for PyArrayDescr {
    // Name and module reported to pyo3 for this type (presumably surfaced in error messages).
    const NAME: &'static str = "PyArrayDescr";
    const MODULE: Option<&'static str> = Some("numpy");

    /// Returns the raw type object that `numpy.dtype` instances belong to.
    #[inline]
    fn type_object_raw<'py>(py: Python<'py>) -> *mut ffi::PyTypeObject {
        unsafe { PY_ARRAY_API.get_type_object(py, NpyTypes::PyArrayDescr_Type) }
    }
}
64
65/// Returns the type descriptor ("dtype") for a registered type.
66#[inline]
67pub fn dtype<'py, T: Element>(py: Python<'py>) -> Bound<'py, PyArrayDescr> {
68    T::get_dtype(py)
69}
70
71impl PyArrayDescr {
72    /// Creates a new type descriptor ("dtype") object from an arbitrary object.
73    ///
74    /// Equivalent to invoking the constructor of [`numpy.dtype`][dtype].
75    ///
76    /// [dtype]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.html
77    #[inline]
78    pub fn new<'a, 'py, T>(py: Python<'py>, ob: T) -> PyResult<Bound<'py, Self>>
79    where
80        T: IntoPyObject<'py>,
81    {
82        fn inner<'py>(
83            py: Python<'py>,
84            obj: Borrowed<'_, 'py, PyAny>,
85        ) -> PyResult<Bound<'py, PyArrayDescr>> {
86            let mut descr: *mut PyArray_Descr = ptr::null_mut();
87            unsafe {
88                // None is an invalid input here and is not converted to NPY_DEFAULT_TYPE
89                PY_ARRAY_API.PyArray_DescrConverter2(py, obj.as_ptr(), &mut descr);
90                Bound::from_owned_ptr_or_err(py, descr.cast()).map(|any| any.cast_into_unchecked())
91            }
92        }
93
94        inner(
95            py,
96            ob.into_pyobject(py)
97                .map_err(Into::into)?
98                .into_any()
99                .as_borrowed(),
100        )
101    }
102
103    /// Shortcut for creating a type descriptor of `object` type.
104    #[inline]
105    pub fn object(py: Python<'_>) -> Bound<'_, Self> {
106        Self::from_npy_type(py, NPY_TYPES::NPY_OBJECT)
107    }
108
109    /// Returns the type descriptor for a registered type.
110    #[inline]
111    pub fn of<'py, T: Element>(py: Python<'py>) -> Bound<'py, Self> {
112        T::get_dtype(py)
113    }
114
115    fn from_npy_type<'py>(py: Python<'py>, npy_type: NPY_TYPES) -> Bound<'py, Self> {
116        unsafe {
117            let descr = PY_ARRAY_API.PyArray_DescrFromType(py, npy_type as _);
118            Bound::from_owned_ptr(py, descr.cast()).cast_into_unchecked()
119        }
120    }
121
122    pub(crate) fn new_from_npy_type<'py>(py: Python<'py>, npy_type: NPY_TYPES) -> Bound<'py, Self> {
123        unsafe {
124            let descr = PY_ARRAY_API.PyArray_DescrNewFromType(py, npy_type as _);
125            Bound::from_owned_ptr(py, descr.cast()).cast_into_unchecked()
126        }
127    }
128}
129
130/// Implementation of functionality for [`PyArrayDescr`].
131#[doc(alias = "PyArrayDescr")]
132pub trait PyArrayDescrMethods<'py>: Sealed {
133    /// Returns `self` as `*mut PyArray_Descr`.
134    fn as_dtype_ptr(&self) -> *mut PyArray_Descr;
135
136    /// Returns `self` as `*mut PyArray_Descr` while increasing the reference count.
137    ///
138    /// Useful in cases where the descriptor is stolen by the API.
139    fn into_dtype_ptr(self) -> *mut PyArray_Descr;
140
141    /// Returns true if two type descriptors are equivalent.
142    fn is_equiv_to(&self, other: &Self) -> bool;
143
144    /// Returns the [array scalar][arrays-scalars] corresponding to this type descriptor.
145    ///
146    /// Equivalent to [`numpy.dtype.type`][dtype-type].
147    ///
148    /// [arrays-scalars]: https://numpy.org/doc/stable/reference/arrays.scalars.html
149    /// [dtype-type]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.type.html
150    fn typeobj(&self) -> Bound<'py, PyType>;
151
152    /// Returns a unique number for each of the 21 different built-in
153    /// [enumerated types][enumerated-types].
154    ///
155    /// These are roughly ordered from least-to-most precision.
156    ///
157    /// Equivalent to [`numpy.dtype.num`][dtype-num].
158    ///
159    /// [enumerated-types]: https://numpy.org/doc/stable/reference/c-api/dtype.html#enumerated-types
160    /// [dtype-num]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.num.html
161    fn num(&self) -> c_int {
162        unsafe { &*self.as_dtype_ptr() }.type_num
163    }
164
165    /// Returns the element size of this type descriptor.
166    ///
167    /// Equivalent to [`numpy.dtype.itemsize`][dtype-itemsize].
168    ///
169    /// [dtype-itemsiize]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.itemsize.html
170    fn itemsize(&self) -> usize;
171
172    /// Returns the required alignment (bytes) of this type descriptor according to the compiler.
173    ///
174    /// Equivalent to [`numpy.dtype.alignment`][dtype-alignment].
175    ///
176    /// [dtype-alignment]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.alignment.html
177    fn alignment(&self) -> usize;
178
179    /// Returns an ASCII character indicating the byte-order of this type descriptor object.
180    ///
181    /// All built-in data-type objects have byteorder either `=` or `|`.
182    ///
183    /// Equivalent to [`numpy.dtype.byteorder`][dtype-byteorder].
184    ///
185    /// [dtype-byteorder]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.byteorder.html
186    fn byteorder(&self) -> u8 {
187        unsafe { &*self.as_dtype_ptr() }.byteorder.max(0) as _
188    }
189
190    /// Returns a unique ASCII character for each of the 21 different built-in types.
191    ///
192    /// Note that structured data types are categorized as `V` (void).
193    ///
194    /// Equivalent to [`numpy.dtype.char`][dtype-char].
195    ///
196    /// [dtype-char]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.char.html
197    fn char(&self) -> u8 {
198        unsafe { &*self.as_dtype_ptr() }.type_.max(0) as _
199    }
200
201    /// Returns an ASCII character (one of `biufcmMOSUV`) identifying the general kind of data.
202    ///
203    /// Note that structured data types are categorized as `V` (void).
204    ///
205    /// Equivalent to [`numpy.dtype.kind`][dtype-kind].
206    ///
207    /// [dtype-kind]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.kind.html
208    fn kind(&self) -> u8 {
209        unsafe { &*self.as_dtype_ptr() }.kind.max(0) as _
210    }
211
212    /// Returns bit-flags describing how this type descriptor is to be interpreted.
213    ///
214    /// Equivalent to [`numpy.dtype.flags`][dtype-flags].
215    ///
216    /// [dtype-flags]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.flags.html
217    fn flags(&self) -> u64;
218
219    /// Returns the number of dimensions if this type descriptor represents a sub-array, and zero otherwise.
220    ///
221    /// Equivalent to [`numpy.dtype.ndim`][dtype-ndim].
222    ///
223    /// [dtype-ndim]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.ndim.html
224    fn ndim(&self) -> usize;
225
226    /// Returns the type descriptor for the base element of subarrays, regardless of their dimension or shape.
227    ///
228    /// If the dtype is not a subarray, returns self.
229    ///
230    /// Equivalent to [`numpy.dtype.base`][dtype-base].
231    ///
232    /// [dtype-base]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.base.html
233    fn base(&self) -> Bound<'py, PyArrayDescr>;
234
235    /// Returns the shape of the sub-array.
236    ///
237    /// If the dtype is not a sub-array, an empty vector is returned.
238    ///
239    /// Equivalent to [`numpy.dtype.shape`][dtype-shape].
240    ///
241    /// [dtype-shape]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.shape.html
242    fn shape(&self) -> Vec<usize>;
243
244    /// Returns true if the type descriptor contains any reference-counted objects in any fields or sub-dtypes.
245    ///
246    /// Equivalent to [`numpy.dtype.hasobject`][dtype-hasobject].
247    ///
248    /// [dtype-hasobject]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.hasobject.html
249    fn has_object(&self) -> bool {
250        self.flags() & NPY_ITEM_HASOBJECT != 0
251    }
252
253    /// Returns true if the type descriptor is a struct which maintains field alignment.
254    ///
255    /// This flag is sticky, so when combining multiple structs together, it is preserved
256    /// and produces new dtypes which are also aligned.
257    ///
258    /// Equivalent to [`numpy.dtype.isalignedstruct`][dtype-isalignedstruct].
259    ///
260    /// [dtype-isalignedstruct]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.isalignedstruct.html
261    fn is_aligned_struct(&self) -> bool {
262        self.flags() & NPY_ALIGNED_STRUCT != 0
263    }
264
265    /// Returns true if the type descriptor is a sub-array.
266    ///
267    /// Equivalent to PyDataType_HASSUBARRAY(self).
268    fn has_subarray(&self) -> bool;
269
270    /// Returns true if the type descriptor is a structured type.
271    ///
272    /// Equivalent to PyDataType_HASFIELDS(self).
273    fn has_fields(&self) -> bool;
274
275    /// Returns true if type descriptor byteorder is native, or `None` if not applicable.
276    fn is_native_byteorder(&self) -> Option<bool> {
277        // based on PyArray_ISNBO(self->byteorder)
278        match self.byteorder() {
279            b'=' => Some(true),
280            b'|' => None,
281            byteorder => Some(byteorder == NPY_BYTEORDER_CHAR::NPY_NATBYTE as u8),
282        }
283    }
284
285    /// Returns an ordered list of field names, or `None` if there are no fields.
286    ///
287    /// The names are ordered according to increasing byte offset.
288    ///
289    /// Equivalent to [`numpy.dtype.names`][dtype-names].
290    ///
291    /// [dtype-names]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.names.html
292    fn names(&self) -> Option<Vec<String>>;
293
294    /// Returns the type descriptor and offset of the field with the given name.
295    ///
296    /// This method will return an error if this type descriptor is not structured,
297    /// or if it does not contain a field with a given name.
298    ///
299    /// The list of all names can be found via [`PyArrayDescr::names`].
300    ///
301    /// Equivalent to retrieving a single item from [`numpy.dtype.fields`][dtype-fields].
302    ///
303    /// [dtype-fields]: https://numpy.org/doc/stable/reference/generated/numpy.dtype.fields.html
304    fn get_field(&self, name: &str) -> PyResult<(Bound<'py, PyArrayDescr>, usize)>;
305}
306
// Private module: because `Sealed` cannot be named outside this crate, downstream crates
// cannot implement `PyArrayDescrMethods` (which requires `Sealed`) for their own types.
mod sealed {
    pub trait Sealed {}
}
310
311use sealed::Sealed;
312
313impl<'py> PyArrayDescrMethods<'py> for Bound<'py, PyArrayDescr> {
314    fn as_dtype_ptr(&self) -> *mut PyArray_Descr {
315        self.as_ptr() as _
316    }
317
318    fn into_dtype_ptr(self) -> *mut PyArray_Descr {
319        self.into_ptr() as _
320    }
321
322    fn is_equiv_to(&self, other: &Self) -> bool {
323        let self_ptr = self.as_dtype_ptr();
324        let other_ptr = other.as_dtype_ptr();
325
326        unsafe {
327            self_ptr == other_ptr
328                || PY_ARRAY_API.PyArray_EquivTypes(self.py(), self_ptr, other_ptr) != 0
329        }
330    }
331
332    fn typeobj(&self) -> Bound<'py, PyType> {
333        let dtype_type_ptr = unsafe { &*self.as_dtype_ptr() }.typeobj;
334        unsafe { PyType::from_borrowed_type_ptr(self.py(), dtype_type_ptr) }
335    }
336
337    fn itemsize(&self) -> usize {
338        unsafe { PyDataType_ELSIZE(self.py(), self.as_dtype_ptr()).max(0) as _ }
339    }
340
341    fn alignment(&self) -> usize {
342        unsafe { PyDataType_ALIGNMENT(self.py(), self.as_dtype_ptr()).max(0) as _ }
343    }
344
345    fn flags(&self) -> u64 {
346        unsafe { PyDataType_FLAGS(self.py(), self.as_dtype_ptr()) as _ }
347    }
348
349    fn ndim(&self) -> usize {
350        let subarray = unsafe { PyDataType_SUBARRAY(self.py(), self.as_dtype_ptr()).as_ref() };
351        match subarray {
352            None => 0,
353            Some(subarray) => unsafe { PyTuple_Size(subarray.shape) }.max(0) as _,
354        }
355    }
356
357    fn base(&self) -> Bound<'py, PyArrayDescr> {
358        let subarray = unsafe { PyDataType_SUBARRAY(self.py(), self.as_dtype_ptr()).as_ref() };
359        match subarray {
360            None => self.clone(),
361            Some(subarray) => unsafe {
362                Bound::from_borrowed_ptr(self.py(), subarray.base.cast()).cast_into_unchecked()
363            },
364        }
365    }
366
367    fn shape(&self) -> Vec<usize> {
368        let subarray = unsafe { PyDataType_SUBARRAY(self.py(), self.as_dtype_ptr()).as_ref() };
369        match subarray {
370            None => Vec::new(),
371            Some(subarray) => {
372                // NumPy guarantees that shape is a tuple of non-negative integers so this should never panic.
373                let shape = unsafe { Borrowed::from_ptr(self.py(), subarray.shape) };
374                shape.extract().unwrap()
375            }
376        }
377    }
378
379    fn has_subarray(&self) -> bool {
380        unsafe { !PyDataType_SUBARRAY(self.py(), self.as_dtype_ptr()).is_null() }
381    }
382
383    fn has_fields(&self) -> bool {
384        unsafe { !PyDataType_NAMES(self.py(), self.as_dtype_ptr()).is_null() }
385    }
386
387    fn names(&self) -> Option<Vec<String>> {
388        if !self.has_fields() {
389            return None;
390        }
391        let names = unsafe {
392            Borrowed::from_ptr(self.py(), PyDataType_NAMES(self.py(), self.as_dtype_ptr()))
393        };
394        names.extract().ok()
395    }
396
397    fn get_field(&self, name: &str) -> PyResult<(Bound<'py, PyArrayDescr>, usize)> {
398        if !self.has_fields() {
399            return Err(PyValueError::new_err(
400                "cannot get field information: type descriptor has no fields",
401            ));
402        }
403        let dict = unsafe {
404            Borrowed::from_ptr(self.py(), PyDataType_FIELDS(self.py(), self.as_dtype_ptr()))
405        };
406        let dict = unsafe { dict.cast_unchecked::<PyDict>() };
407        // NumPy guarantees that fields are tuples of proper size and type, so this should never panic.
408        let tuple = dict
409            .get_item(name)?
410            .ok_or_else(|| PyIndexError::new_err(name.to_owned()))?
411            .cast_into::<PyTuple>()
412            .unwrap();
413        // Note that we cannot just extract the entire tuple since the third element can be a title.
414        let dtype = tuple
415            .get_item(0)
416            .unwrap()
417            .cast_into::<PyArrayDescr>()
418            .unwrap();
419        let offset = tuple.get_item(1).unwrap().extract().unwrap();
420        Ok((dtype, offset))
421    }
422}
423
// Only `Bound<PyArrayDescr>` may implement `PyArrayDescrMethods`, which requires `Sealed`.
impl Sealed for Bound<'_, PyArrayDescr> {}
425
426/// Represents that a type can be an element of `PyArray`.
427///
428/// Currently, only integer/float/complex/object types are supported. The [NumPy documentation][enumerated-types]
429/// list the other built-in types which we are not yet implemented.
430///
431/// Note that NumPy's integer types like `numpy.int_` and `numpy.uint` are based on C's integer hierarchy
432/// which implies that their widths change depending on the platform's [data model][data-models].
433/// For example, `numpy.int_` matches C's `long` which is 32 bits wide on Windows (using the LLP64 data model)
434/// but 64 bits wide on Linux (using the LP64 data model).
435///
436/// In contrast, Rust's [`isize`] and [`usize`] types are defined to have the same width as a pointer
437/// and are therefore always 64 bits wide on 64-bit platforms. If you want to match NumPy's behaviour,
438/// consider using the [`c_long`][std::ffi::c_long] and [`c_ulong`][std::ffi::c_ulong] type aliases.
439///
440/// # Safety
441///
442/// A type `T` that implements this trait should be safe when managed by a NumPy
443/// array, thus implementing this trait is marked unsafe. Data types that don't
444/// contain Python objects (i.e., either the object type itself or record types
445/// containing object-type fields) are assumed to be trivially copyable, which
446/// is reflected in the `IS_COPY` flag. Furthermore, it is assumed that for
447/// the object type the elements are pointers into the Python heap and that the
448/// corresponding `Clone` implemenation will never panic as it only increases
449/// the reference count.
450///
451/// # Custom element types
452///
453/// Note that we cannot safely store `Py<T>` where `T: PyClass`, because the type information would be
454/// eliminated in the resulting NumPy array.
455/// In other words, objects are always treated as `Py<PyAny>` (a.k.a. `PyObject`) by Python code,
456/// and only `Py<PyAny>` can be stored in a type safe manner.
457///
458/// You can however create [`Array<Py<T>, D>`][ndarray::Array] and turn that into a NumPy array
459/// safely and efficiently using [`from_owned_object_array`][crate::PyArray::from_owned_object_array].
460///
461/// [enumerated-types]: https://numpy.org/doc/stable/reference/c-api/dtype.html#enumerated-types
462/// [data-models]: https://en.wikipedia.org/wiki/64-bit_computing#64-bit_data_models
463pub unsafe trait Element: Sized + Send + Sync {
464    /// Flag that indicates whether this type is trivially copyable.
465    ///
466    /// It should be set to true for all trivially copyable types (like scalar types
467    /// and record/array types only containing trivially copyable fields and elements).
468    ///
469    /// This flag should *always* be set to `false` for object types or record types
470    /// that contain object-type fields.
471    const IS_COPY: bool;
472
473    /// Returns the associated type descriptor ("dtype") for the given element type.
474    fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr>;
475
476    /// Create a clone of the value while the GIL is guaranteed to be held.
477    fn clone_ref(&self, py: Python<'_>) -> Self;
478
479    /// Create an owned copy of the slice while the GIL is guaranteed to be held.
480    ///
481    /// Some types may provide implementations of this method that are more efficient
482    /// than simply mapping the `py_clone` method to each element in the slice.
483    #[inline]
484    fn vec_from_slice(py: Python<'_>, slc: &[Self]) -> Vec<Self> {
485        slc.iter().map(|elem| elem.clone_ref(py)).collect()
486    }
487
488    /// Create an owned copy of the array while the GIL is guaranteed to be held.
489    ///
490    /// Some types may provide implementations of this method that are more efficient
491    /// than simply mapping the `py_clone` method to each element in the view.
492    #[inline]
493    fn array_from_view<D>(
494        py: Python<'_>,
495        view: ::ndarray::ArrayView<'_, Self, D>,
496    ) -> ::ndarray::Array<Self, D>
497    where
498        D: ::ndarray::Dimension,
499    {
500        view.map(|elem| elem.clone_ref(py))
501    }
502}
503
504fn npy_int_type_lookup<T, T0, T1, T2>(npy_types: [NPY_TYPES; 3]) -> NPY_TYPES {
505    // `npy_common.h` defines the integer aliases. In order, it checks:
506    // NPY_BITSOF_LONG, NPY_BITSOF_LONGLONG, NPY_BITSOF_INT, NPY_BITSOF_SHORT, NPY_BITSOF_CHAR
507    // and assigns the alias to the first matching size, so we should check in this order.
508    match size_of::<T>() {
509        x if x == size_of::<T0>() => npy_types[0],
510        x if x == size_of::<T1>() => npy_types[1],
511        x if x == size_of::<T2>() => npy_types[2],
512        _ => panic!("Unable to match integer type descriptor: {npy_types:?}"),
513    }
514}
515
516fn npy_int_type<T: Bounded + Zero + Sized + PartialEq>() -> NPY_TYPES {
517    let is_unsigned = T::min_value() == T::zero();
518    let bit_width = 8 * size_of::<T>();
519
520    match (is_unsigned, bit_width) {
521        (false, 8) => NPY_TYPES::NPY_BYTE,
522        (false, 16) => NPY_TYPES::NPY_SHORT,
523        (false, 32) => npy_int_type_lookup::<i32, c_long, c_int, c_short>([
524            NPY_TYPES::NPY_LONG,
525            NPY_TYPES::NPY_INT,
526            NPY_TYPES::NPY_SHORT,
527        ]),
528        (false, 64) => npy_int_type_lookup::<i64, c_long, c_longlong, c_int>([
529            NPY_TYPES::NPY_LONG,
530            NPY_TYPES::NPY_LONGLONG,
531            NPY_TYPES::NPY_INT,
532        ]),
533        (true, 8) => NPY_TYPES::NPY_UBYTE,
534        (true, 16) => NPY_TYPES::NPY_USHORT,
535        (true, 32) => npy_int_type_lookup::<u32, c_ulong, c_uint, c_ushort>([
536            NPY_TYPES::NPY_ULONG,
537            NPY_TYPES::NPY_UINT,
538            NPY_TYPES::NPY_USHORT,
539        ]),
540        (true, 64) => npy_int_type_lookup::<u64, c_ulong, c_ulonglong, c_uint>([
541            NPY_TYPES::NPY_ULONG,
542            NPY_TYPES::NPY_ULONGLONG,
543            NPY_TYPES::NPY_UINT,
544        ]),
545        _ => unreachable!(),
546    }
547}
548
// Invoke within the `Element` impl for a `Clone` type to provide an efficient
// implementation of the cloning methods.
//
// `$Self` is the implementing type. The generated methods ignore the `Python` token and rely
// on the type's own `Clone`/`ToOwned` implementations instead of per-element `clone_ref` calls.
macro_rules! clone_methods_impl {
    ($Self:ty) => {
        // Plain `Clone`; no interaction with Python is needed.
        #[inline]
        fn clone_ref(&self, _py: ::pyo3::Python<'_>) -> $Self {
            ::std::clone::Clone::clone(self)
        }

        // Copy the whole slice in one call via `ToOwned`.
        #[inline]
        fn vec_from_slice(_py: ::pyo3::Python<'_>, slc: &[$Self]) -> Vec<$Self> {
            ::std::borrow::ToOwned::to_owned(slc)
        }

        // Let `ndarray` produce an owned copy of the view directly.
        #[inline]
        fn array_from_view<D>(
            _py: ::pyo3::Python<'_>,
            view: ::ndarray::ArrayView<'_, $Self, D>,
        ) -> ::ndarray::Array<$Self, D>
        where
            D: ::ndarray::Dimension,
        {
            ::ndarray::ArrayView::to_owned(&view)
        }
    };
}
// Re-exported crate-wide so `Element` impls in other modules can use the macro too.
pub(crate) use clone_methods_impl;
576use pyo3::BoundObject;
577
// Generates `Element` impls for trivially copyable scalar types (`IS_COPY = true`).
//
// Forms:
// - `impl_element_scalar!(@impl: Ty, expr, #[attr]...)`: internal arm that emits the impl;
//   `expr` evaluates to the `NPY_TYPES` value used to look up the descriptor.
// - `impl_element_scalar!(Ty => NPY_VARIANT, #[attr]...)`: a fixed `NPY_TYPES` variant.
// - `impl_element_scalar!(Ty1, Ty2, ...)`: integer types whose NumPy type is chosen by
//   `npy_int_type` from the type's signedness and width.
macro_rules! impl_element_scalar {
    (@impl: $ty:ty, $npy_type:expr $(,#[$meta:meta])*) => {
        // Forward any extra attributes (e.g. `#[doc = ...]`) onto the generated impl.
        $(#[$meta])*
        unsafe impl Element for $ty {
            const IS_COPY: bool = true;

            fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
                PyArrayDescr::from_npy_type(py, $npy_type)
            }

            clone_methods_impl!($ty);
        }
    };
    ($ty:ty => $npy_type:ident $(,#[$meta:meta])*) => {
        impl_element_scalar!(@impl: $ty, NPY_TYPES::$npy_type $(,#[$meta])*);
    };
    ($($tys:ty),+) => {
        $(impl_element_scalar!(@impl: $tys, npy_int_type::<$tys>());)+
    };
}
598
599impl_element_scalar!(bool => NPY_BOOL);
600
601impl_element_scalar!(i8, i16, i32, i64);
602impl_element_scalar!(u8, u16, u32, u64);
603
604impl_element_scalar!(f32 => NPY_FLOAT);
605impl_element_scalar!(f64 => NPY_DOUBLE);
606
607#[cfg(feature = "half")]
608impl_element_scalar!(f16 => NPY_HALF);
609
#[cfg(feature = "half")]
unsafe impl Element for bf16 {
    const IS_COPY: bool = true;

    /// NumPy has no built-in `bfloat16`, so the descriptor is looked up by name on first use
    /// and cached in a process-wide `PyOnceLock`.
    ///
    /// # Panics
    ///
    /// Panics if no `bfloat16` data type is registered with NumPy.
    fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
        static DTYPE: PyOnceLock<Py<PyArrayDescr>> = PyOnceLock::new();

        let cached: &Py<PyArrayDescr> = DTYPE.get_or_init(py, || {
            let descr = PyArrayDescr::new(py, "bfloat16").expect("A package which provides a `bfloat16` data type for NumPy is required to use the `half::bf16` element type.");
            descr.unbind()
        });
        cached.clone_ref(py).into_bound(py)
    }

    clone_methods_impl!(bf16);
}
627
// Complex numbers from `num_complex`; the `#[doc]` attributes are forwarded onto the impls.
impl_element_scalar!(Complex32 => NPY_CFLOAT,
    #[doc = "Complex type with `f32` components which maps to `numpy.csingle` (`numpy.complex64`)."]);
impl_element_scalar!(Complex64 => NPY_CDOUBLE,
    #[doc = "Complex type with `f64` components which maps to `numpy.cdouble` (`numpy.complex128`)."]);

// Pointer-sized integers reuse the width-based integer lookup.
// NOTE(review): restricted to 32- and 64-bit targets, presumably the only widths supported
// and tested — confirm before extending.
#[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))]
impl_element_scalar!(usize, isize);
635
636unsafe impl Element for Py<PyAny> {
637    const IS_COPY: bool = false;
638
639    fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
640        PyArrayDescr::object(py)
641    }
642
643    #[inline]
644    fn clone_ref(&self, py: Python<'_>) -> Self {
645        Py::clone_ref(self, py)
646    }
647}
648
#[cfg(test)]
mod tests {
    use super::*;

    use pyo3::types::PyString;
    use pyo3::{py_run, types::PyTypeMethods};

    use crate::npyffi::{is_numpy_2, NPY_NEEDS_PYAPI};

    #[test]
    fn test_dtype_new() {
        Python::attach(|py| {
            assert!(PyArrayDescr::new(py, "float64")
                .unwrap()
                .is(dtype::<f64>(py)));

            let dt = PyArrayDescr::new(py, [("a", "O"), ("b", "?")].as_ref()).unwrap();
            assert_eq!(dt.names(), Some(vec!["a".to_owned(), "b".to_owned()]));
            assert!(dt.has_object());
            assert!(dt.get_field("a").unwrap().0.is(dtype::<Py<PyAny>>(py)));
            assert!(dt.get_field("b").unwrap().0.is(dtype::<bool>(py)));

            assert!(PyArrayDescr::new(py, 123_usize).is_err());
        });
    }

    /// `None` must be rejected with an error rather than mapped to a default dtype.
    #[test]
    fn test_dtype_new_rejects_none() {
        Python::attach(|py| {
            assert!(PyArrayDescr::new(py, py.None()).is_err());
        });
    }

    /// `get_field` reports non-structured dtypes and unknown field names as errors.
    #[test]
    fn test_dtype_get_field_errors() {
        Python::attach(|py| {
            let scalar = dtype::<f64>(py);
            assert!(matches!(
                scalar.get_field("x"),
                Err(err) if err.is_instance_of::<PyValueError>(py)
            ));

            let record = PyArrayDescr::new(py, [("a", "O"), ("b", "?")].as_ref()).unwrap();
            assert!(matches!(
                record.get_field("missing"),
                Err(err) if err.is_instance_of::<PyIndexError>(py)
            ));
        });
    }

    #[test]
    fn test_dtype_names() {
        fn type_name<T: Element>(py: Python<'_>) -> Bound<'_, PyString> {
            dtype::<T>(py).typeobj().qualname().unwrap()
        }
        Python::attach(|py| {
            if is_numpy_2(py) {
                assert_eq!(type_name::<bool>(py), "bool");
            } else {
                assert_eq!(type_name::<bool>(py), "bool_");
            }

            assert_eq!(type_name::<i8>(py), "int8");
            assert_eq!(type_name::<i16>(py), "int16");
            assert_eq!(type_name::<i32>(py), "int32");
            assert_eq!(type_name::<i64>(py), "int64");
            assert_eq!(type_name::<u8>(py), "uint8");
            assert_eq!(type_name::<u16>(py), "uint16");
            assert_eq!(type_name::<u32>(py), "uint32");
            assert_eq!(type_name::<u64>(py), "uint64");
            assert_eq!(type_name::<f32>(py), "float32");
            assert_eq!(type_name::<f64>(py), "float64");

            assert_eq!(type_name::<Complex32>(py), "complex64");
            assert_eq!(type_name::<Complex64>(py), "complex128");

            #[cfg(target_pointer_width = "32")]
            {
                assert_eq!(type_name::<usize>(py), "uint32");
                assert_eq!(type_name::<isize>(py), "int32");
            }

            #[cfg(target_pointer_width = "64")]
            {
                assert_eq!(type_name::<usize>(py), "uint64");
                assert_eq!(type_name::<isize>(py), "int64");
            }
        });
    }

    #[test]
    fn test_dtype_methods_scalar() {
        Python::attach(|py| {
            let dt = dtype::<f64>(py);

            assert_eq!(dt.num(), NPY_TYPES::NPY_DOUBLE as c_int);
            assert_eq!(dt.flags(), 0);
            assert_eq!(dt.typeobj().qualname().unwrap(), "float64");
            assert_eq!(dt.char(), b'd');
            assert_eq!(dt.kind(), b'f');
            assert_eq!(dt.byteorder(), b'=');
            assert_eq!(dt.is_native_byteorder(), Some(true));
            assert_eq!(dt.itemsize(), 8);
            assert_eq!(dt.alignment(), 8);
            assert!(!dt.has_object());
            assert!(dt.names().is_none());
            assert!(!dt.has_fields());
            assert!(!dt.is_aligned_struct());
            assert!(!dt.has_subarray());
            assert!(dt.base().is_equiv_to(&dt));
            assert_eq!(dt.ndim(), 0);
            assert_eq!(dt.shape(), Vec::<usize>::new());
        });
    }

    #[test]
    fn test_dtype_methods_subarray() {
        Python::attach(|py| {
            let locals = PyDict::new(py);
            py_run!(
                py,
                *locals,
                "dtype = __import__('numpy').dtype(('f8', (2, 3)))"
            );
            let dt = locals
                .get_item("dtype")
                .unwrap()
                .unwrap()
                .cast_into::<PyArrayDescr>()
                .unwrap();

            assert_eq!(dt.num(), NPY_TYPES::NPY_VOID as c_int);
            assert_eq!(dt.flags(), 0);
            assert_eq!(dt.typeobj().qualname().unwrap(), "void");
            assert_eq!(dt.char(), b'V');
            assert_eq!(dt.kind(), b'V');
            assert_eq!(dt.byteorder(), b'|');
            assert_eq!(dt.is_native_byteorder(), None);
            assert_eq!(dt.itemsize(), 48);
            assert_eq!(dt.alignment(), 8);
            assert!(!dt.has_object());
            assert!(dt.names().is_none());
            assert!(!dt.has_fields());
            assert!(!dt.is_aligned_struct());
            assert!(dt.has_subarray());
            assert_eq!(dt.ndim(), 2);
            assert_eq!(dt.shape(), vec![2, 3]);
            assert!(dt.base().is_equiv_to(&dtype::<f64>(py)));
        });
    }

    #[test]
    fn test_dtype_methods_record() {
        Python::attach(|py| {
            let locals = PyDict::new(py);
            py_run!(
                py,
                *locals,
                "dtype = __import__('numpy').dtype([('x', 'u1'), ('y', 'f8'), ('z', 'O')], align=True)"
            );
            let dt = locals
                .get_item("dtype")
                .unwrap()
                .unwrap()
                .cast_into::<PyArrayDescr>()
                .unwrap();

            assert_eq!(dt.num(), NPY_TYPES::NPY_VOID as c_int);
            assert_ne!(dt.flags() & NPY_ITEM_HASOBJECT, 0);
            assert_ne!(dt.flags() & NPY_NEEDS_PYAPI, 0);
            assert_ne!(dt.flags() & NPY_ALIGNED_STRUCT, 0);
            assert_eq!(dt.typeobj().qualname().unwrap(), "void");
            assert_eq!(dt.char(), b'V');
            assert_eq!(dt.kind(), b'V');
            assert_eq!(dt.byteorder(), b'|');
            assert_eq!(dt.is_native_byteorder(), None);
            assert_eq!(dt.itemsize(), 24);
            assert_eq!(dt.alignment(), 8);
            assert!(dt.has_object());
            assert_eq!(
                dt.names(),
                Some(vec!["x".to_owned(), "y".to_owned(), "z".to_owned()])
            );
            assert!(dt.has_fields());
            assert!(dt.is_aligned_struct());
            assert!(!dt.has_subarray());
            assert_eq!(dt.ndim(), 0);
            assert_eq!(dt.shape(), Vec::<usize>::new());
            assert!(dt.base().is_equiv_to(&dt));
            let x = dt.get_field("x").unwrap();
            assert!(x.0.is_equiv_to(&dtype::<u8>(py)));
            assert_eq!(x.1, 0);
            let y = dt.get_field("y").unwrap();
            assert!(y.0.is_equiv_to(&dtype::<f64>(py)));
            assert_eq!(y.1, 8);
            let z = dt.get_field("z").unwrap();
            assert!(z.0.is_equiv_to(&dtype::<Py<PyAny>>(py)));
            assert_eq!(z.1, 16);
        });
    }
}