1use std::mem::size_of;
2use std::os::raw::{c_int, c_long, c_longlong, c_short, c_uint, c_ulong, c_ulonglong, c_ushort};
3use std::ptr;
4
5#[cfg(feature = "half")]
6use half::{bf16, f16};
7use num_traits::{Bounded, Zero};
8#[cfg(feature = "half")]
9use pyo3::sync::GILOnceCell;
10use pyo3::{
11 conversion::IntoPyObject,
12 exceptions::{PyIndexError, PyValueError},
13 ffi::{self, PyTuple_Size},
14 pyobject_native_type_named,
15 types::{PyAnyMethods, PyDict, PyDictMethods, PyTuple, PyType},
16 Borrowed, Bound, Py, PyAny, PyObject, PyResult, PyTypeInfo, Python,
17};
18
19use crate::npyffi::{
20 NpyTypes, PyArray_Descr, PyDataType_ALIGNMENT, PyDataType_ELSIZE, PyDataType_FIELDS,
21 PyDataType_FLAGS, PyDataType_NAMES, PyDataType_SUBARRAY, NPY_ALIGNED_STRUCT,
22 NPY_BYTEORDER_CHAR, NPY_ITEM_HASOBJECT, NPY_TYPES, PY_ARRAY_API,
23};
24
25pub use num_complex::{Complex32, Complex64};
26
/// Binding of NumPy's array description object, i.e. a `numpy.dtype` instance.
///
/// Values are obtained via [`PyArrayDescr::new`], [`PyArrayDescr::object`],
/// [`PyArrayDescr::of`] or the free function [`dtype`]. The accessors live on
/// [`PyArrayDescrMethods`], which is implemented for `Bound<'py, PyArrayDescr>`.
// `repr(transparent)` keeps the wrapper layout-identical to `PyAny`, as pyo3 native
// object wrappers expect (NOTE(review): confirm against pyo3's native-type docs).
#[repr(transparent)]
pub struct PyArrayDescr(PyAny);

// Generates pyo3's named-native-type boilerplate for `PyArrayDescr`
// (NOTE(review): the exact set of generated impls depends on the pyo3 version).
pyobject_native_type_named!(PyArrayDescr);
54
55unsafe impl PyTypeInfo for PyArrayDescr {
56 const NAME: &'static str = "PyArrayDescr";
57 const MODULE: Option<&'static str> = Some("numpy");
58
59 #[inline]
60 fn type_object_raw<'py>(py: Python<'py>) -> *mut ffi::PyTypeObject {
61 unsafe { PY_ARRAY_API.get_type_object(py, NpyTypes::PyArrayDescr_Type) }
62 }
63}
64
65#[inline]
67pub fn dtype<'py, T: Element>(py: Python<'py>) -> Bound<'py, PyArrayDescr> {
68 T::get_dtype(py)
69}
70
71impl PyArrayDescr {
72 #[inline]
78 pub fn new<'a, 'py, T>(py: Python<'py>, ob: T) -> PyResult<Bound<'py, Self>>
79 where
80 T: IntoPyObject<'py>,
81 {
82 fn inner<'py>(
83 py: Python<'py>,
84 obj: Borrowed<'_, 'py, PyAny>,
85 ) -> PyResult<Bound<'py, PyArrayDescr>> {
86 let mut descr: *mut PyArray_Descr = ptr::null_mut();
87 unsafe {
88 PY_ARRAY_API.PyArray_DescrConverter2(py, obj.as_ptr(), &mut descr);
90 Bound::from_owned_ptr_or_err(py, descr.cast())
91 .map(|any| any.downcast_into_unchecked())
92 }
93 }
94
95 inner(
96 py,
97 ob.into_pyobject(py)
98 .map_err(Into::into)?
99 .into_any()
100 .as_borrowed(),
101 )
102 }
103
104 #[inline]
106 pub fn object(py: Python<'_>) -> Bound<'_, Self> {
107 Self::from_npy_type(py, NPY_TYPES::NPY_OBJECT)
108 }
109
110 #[inline]
112 pub fn of<'py, T: Element>(py: Python<'py>) -> Bound<'py, Self> {
113 T::get_dtype(py)
114 }
115
116 fn from_npy_type<'py>(py: Python<'py>, npy_type: NPY_TYPES) -> Bound<'py, Self> {
117 unsafe {
118 let descr = PY_ARRAY_API.PyArray_DescrFromType(py, npy_type as _);
119 Bound::from_owned_ptr(py, descr.cast()).downcast_into_unchecked()
120 }
121 }
122
123 pub(crate) fn new_from_npy_type<'py>(py: Python<'py>, npy_type: NPY_TYPES) -> Bound<'py, Self> {
124 unsafe {
125 let descr = PY_ARRAY_API.PyArray_DescrNewFromType(py, npy_type as _);
126 Bound::from_owned_ptr(py, descr.cast()).downcast_into_unchecked()
127 }
128 }
129}
130
131#[doc(alias = "PyArrayDescr")]
133pub trait PyArrayDescrMethods<'py>: Sealed {
134 fn as_dtype_ptr(&self) -> *mut PyArray_Descr;
136
137 fn into_dtype_ptr(self) -> *mut PyArray_Descr;
141
142 fn is_equiv_to(&self, other: &Self) -> bool;
144
145 fn typeobj(&self) -> Bound<'py, PyType>;
152
153 fn num(&self) -> c_int {
163 unsafe { &*self.as_dtype_ptr() }.type_num
164 }
165
166 fn itemsize(&self) -> usize;
172
173 fn alignment(&self) -> usize;
179
180 fn byteorder(&self) -> u8 {
188 unsafe { &*self.as_dtype_ptr() }.byteorder.max(0) as _
189 }
190
191 fn char(&self) -> u8 {
199 unsafe { &*self.as_dtype_ptr() }.type_.max(0) as _
200 }
201
202 fn kind(&self) -> u8 {
210 unsafe { &*self.as_dtype_ptr() }.kind.max(0) as _
211 }
212
213 fn flags(&self) -> u64;
219
220 fn ndim(&self) -> usize;
226
227 fn base(&self) -> Bound<'py, PyArrayDescr>;
235
236 fn shape(&self) -> Vec<usize>;
244
245 fn has_object(&self) -> bool {
251 self.flags() & NPY_ITEM_HASOBJECT != 0
252 }
253
254 fn is_aligned_struct(&self) -> bool {
263 self.flags() & NPY_ALIGNED_STRUCT != 0
264 }
265
266 fn has_subarray(&self) -> bool;
270
271 fn has_fields(&self) -> bool;
275
276 fn is_native_byteorder(&self) -> Option<bool> {
278 match self.byteorder() {
280 b'=' => Some(true),
281 b'|' => None,
282 byteorder => Some(byteorder == NPY_BYTEORDER_CHAR::NPY_NATBYTE as u8),
283 }
284 }
285
286 fn names(&self) -> Option<Vec<String>>;
294
295 fn get_field(&self, name: &str) -> PyResult<(Bound<'py, PyArrayDescr>, usize)>;
306}
307
// Private module holding the `Sealed` marker trait. The module is not public, so code
// outside this module cannot name `Sealed`. That code therefore cannot satisfy the
// `Sealed` supertrait bound of `PyArrayDescrMethods`.
mod sealed {
    pub trait Sealed {}
}
311
312use sealed::Sealed;
313
314impl<'py> PyArrayDescrMethods<'py> for Bound<'py, PyArrayDescr> {
315 fn as_dtype_ptr(&self) -> *mut PyArray_Descr {
316 self.as_ptr() as _
317 }
318
319 fn into_dtype_ptr(self) -> *mut PyArray_Descr {
320 self.into_ptr() as _
321 }
322
323 fn is_equiv_to(&self, other: &Self) -> bool {
324 let self_ptr = self.as_dtype_ptr();
325 let other_ptr = other.as_dtype_ptr();
326
327 unsafe {
328 self_ptr == other_ptr
329 || PY_ARRAY_API.PyArray_EquivTypes(self.py(), self_ptr, other_ptr) != 0
330 }
331 }
332
333 fn typeobj(&self) -> Bound<'py, PyType> {
334 let dtype_type_ptr = unsafe { &*self.as_dtype_ptr() }.typeobj;
335 unsafe { PyType::from_borrowed_type_ptr(self.py(), dtype_type_ptr) }
336 }
337
338 fn itemsize(&self) -> usize {
339 unsafe { PyDataType_ELSIZE(self.py(), self.as_dtype_ptr()).max(0) as _ }
340 }
341
342 fn alignment(&self) -> usize {
343 unsafe { PyDataType_ALIGNMENT(self.py(), self.as_dtype_ptr()).max(0) as _ }
344 }
345
346 fn flags(&self) -> u64 {
347 unsafe { PyDataType_FLAGS(self.py(), self.as_dtype_ptr()) as _ }
348 }
349
350 fn ndim(&self) -> usize {
351 let subarray = unsafe { PyDataType_SUBARRAY(self.py(), self.as_dtype_ptr()).as_ref() };
352 match subarray {
353 None => 0,
354 Some(subarray) => unsafe { PyTuple_Size(subarray.shape) }.max(0) as _,
355 }
356 }
357
358 fn base(&self) -> Bound<'py, PyArrayDescr> {
359 let subarray = unsafe { PyDataType_SUBARRAY(self.py(), self.as_dtype_ptr()).as_ref() };
360 match subarray {
361 None => self.clone(),
362 Some(subarray) => unsafe {
363 Bound::from_borrowed_ptr(self.py(), subarray.base.cast()).downcast_into_unchecked()
364 },
365 }
366 }
367
368 fn shape(&self) -> Vec<usize> {
369 let subarray = unsafe { PyDataType_SUBARRAY(self.py(), self.as_dtype_ptr()).as_ref() };
370 match subarray {
371 None => Vec::new(),
372 Some(subarray) => {
373 let shape = unsafe { Borrowed::from_ptr(self.py(), subarray.shape) };
375 shape.extract().unwrap()
376 }
377 }
378 }
379
380 fn has_subarray(&self) -> bool {
381 unsafe { !PyDataType_SUBARRAY(self.py(), self.as_dtype_ptr()).is_null() }
382 }
383
384 fn has_fields(&self) -> bool {
385 unsafe { !PyDataType_NAMES(self.py(), self.as_dtype_ptr()).is_null() }
386 }
387
388 fn names(&self) -> Option<Vec<String>> {
389 if !self.has_fields() {
390 return None;
391 }
392 let names = unsafe {
393 Borrowed::from_ptr(self.py(), PyDataType_NAMES(self.py(), self.as_dtype_ptr()))
394 };
395 names.extract().ok()
396 }
397
398 fn get_field(&self, name: &str) -> PyResult<(Bound<'py, PyArrayDescr>, usize)> {
399 if !self.has_fields() {
400 return Err(PyValueError::new_err(
401 "cannot get field information: type descriptor has no fields",
402 ));
403 }
404 let dict = unsafe {
405 Borrowed::from_ptr(self.py(), PyDataType_FIELDS(self.py(), self.as_dtype_ptr()))
406 };
407 let dict = unsafe { dict.downcast_unchecked::<PyDict>() };
408 let tuple = dict
410 .get_item(name)?
411 .ok_or_else(|| PyIndexError::new_err(name.to_owned()))?
412 .downcast_into::<PyTuple>()
413 .unwrap();
414 let dtype = tuple
416 .get_item(0)
417 .unwrap()
418 .downcast_into::<PyArrayDescr>()
419 .unwrap();
420 let offset = tuple.get_item(1).unwrap().extract().unwrap();
421 Ok((dtype, offset))
422 }
423}
424
// Marks `Bound<PyArrayDescr>` as `Sealed`, satisfying the supertrait bound of
// `PyArrayDescrMethods` for the implementation above.
impl Sealed for Bound<'_, PyArrayDescr> {}
426
427pub unsafe trait Element: Sized + Send + Sync {
465 const IS_COPY: bool;
473
474 fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr>;
476
477 fn clone_ref(&self, py: Python<'_>) -> Self;
479
480 #[inline]
485 fn vec_from_slice(py: Python<'_>, slc: &[Self]) -> Vec<Self> {
486 slc.iter().map(|elem| elem.clone_ref(py)).collect()
487 }
488
489 #[inline]
494 fn array_from_view<D>(
495 py: Python<'_>,
496 view: ::ndarray::ArrayView<'_, Self, D>,
497 ) -> ::ndarray::Array<Self, D>
498 where
499 D: ::ndarray::Dimension,
500 {
501 view.map(|elem| elem.clone_ref(py))
502 }
503}
504
505fn npy_int_type_lookup<T, T0, T1, T2>(npy_types: [NPY_TYPES; 3]) -> NPY_TYPES {
506 match size_of::<T>() {
510 x if x == size_of::<T0>() => npy_types[0],
511 x if x == size_of::<T1>() => npy_types[1],
512 x if x == size_of::<T2>() => npy_types[2],
513 _ => panic!("Unable to match integer type descriptor: {:?}", npy_types),
514 }
515}
516
517fn npy_int_type<T: Bounded + Zero + Sized + PartialEq>() -> NPY_TYPES {
518 let is_unsigned = T::min_value() == T::zero();
519 let bit_width = 8 * size_of::<T>();
520
521 match (is_unsigned, bit_width) {
522 (false, 8) => NPY_TYPES::NPY_BYTE,
523 (false, 16) => NPY_TYPES::NPY_SHORT,
524 (false, 32) => npy_int_type_lookup::<i32, c_long, c_int, c_short>([
525 NPY_TYPES::NPY_LONG,
526 NPY_TYPES::NPY_INT,
527 NPY_TYPES::NPY_SHORT,
528 ]),
529 (false, 64) => npy_int_type_lookup::<i64, c_long, c_longlong, c_int>([
530 NPY_TYPES::NPY_LONG,
531 NPY_TYPES::NPY_LONGLONG,
532 NPY_TYPES::NPY_INT,
533 ]),
534 (true, 8) => NPY_TYPES::NPY_UBYTE,
535 (true, 16) => NPY_TYPES::NPY_USHORT,
536 (true, 32) => npy_int_type_lookup::<u32, c_ulong, c_uint, c_ushort>([
537 NPY_TYPES::NPY_ULONG,
538 NPY_TYPES::NPY_UINT,
539 NPY_TYPES::NPY_USHORT,
540 ]),
541 (true, 64) => npy_int_type_lookup::<u64, c_ulong, c_ulonglong, c_uint>([
542 NPY_TYPES::NPY_ULONG,
543 NPY_TYPES::NPY_ULONGLONG,
544 NPY_TYPES::NPY_UINT,
545 ]),
546 _ => unreachable!(),
547 }
548}
549
/// Expands to `clone_ref`, `vec_from_slice` and `array_from_view` implementations inside an
/// `Element` impl for a `Clone` type `$Self`.
///
/// The expansion replaces the trait's per-element `clone_ref` defaults with plain `Clone`,
/// `ToOwned` and `ArrayView::to_owned` copies; the GIL token is ignored. Paths are fully
/// qualified so the expansion does not depend on imports at the call site. The macro is
/// re-exported crate-wide by the `pub(crate) use` below.
macro_rules! clone_methods_impl {
    ($Self:ty) => {
        #[inline]
        fn clone_ref(&self, _py: ::pyo3::Python<'_>) -> $Self {
            ::std::clone::Clone::clone(self)
        }

        #[inline]
        fn vec_from_slice(_py: ::pyo3::Python<'_>, slc: &[$Self]) -> Vec<$Self> {
            // Copies the whole slice in one go instead of cloning element by element.
            ::std::borrow::ToOwned::to_owned(slc)
        }

        #[inline]
        fn array_from_view<D>(
            _py: ::pyo3::Python<'_>,
            view: ::ndarray::ArrayView<'_, $Self, D>,
        ) -> ::ndarray::Array<$Self, D>
        where
            D: ::ndarray::Dimension,
        {
            ::ndarray::ArrayView::to_owned(&view)
        }
    };
}
pub(crate) use clone_methods_impl;
577use pyo3::BoundObject;
578
/// Implements `Element` (with `IS_COPY = true` and `Clone`-based copy methods) for plain
/// scalar types.
///
/// Accepted forms:
/// * `impl_element_scalar!(Ty => NPY_VARIANT)`, optionally followed by `, #[meta]` attributes
///   that are attached to the generated impl. This uses the fixed type `NPY_TYPES::NPY_VARIANT`.
/// * `impl_element_scalar!(Ty1, Ty2, ...)` for integer types. This picks the NumPy type via
///   `npy_int_type::<Ty>()` whenever `get_dtype` is called.
/// * `@impl: ...` is the internal arm both forms delegate to.
macro_rules! impl_element_scalar {
    // Internal arm: emits the actual `unsafe impl Element` for `$ty`.
    (@impl: $ty:ty, $npy_type:expr $(,#[$meta:meta])*) => {
        $(#[$meta])*
        unsafe impl Element for $ty {
            const IS_COPY: bool = true;

            fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
                PyArrayDescr::from_npy_type(py, $npy_type)
            }

            clone_methods_impl!($ty);
        }
    };
    // `Ty => NPY_VARIANT`: a fixed NumPy type number, forwarding any attributes.
    ($ty:ty => $npy_type:ident $(,#[$meta:meta])*) => {
        impl_element_scalar!(@impl: $ty, NPY_TYPES::$npy_type $(,#[$meta])*);
    };
    // List of integer types: the NumPy type number is derived from signedness and width.
    ($($tys:ty),+) => {
        $(impl_element_scalar!(@impl: $tys, npy_int_type::<$tys>());)+
    };
}
599
600impl_element_scalar!(bool => NPY_BOOL);
601
602impl_element_scalar!(i8, i16, i32, i64);
603impl_element_scalar!(u8, u16, u32, u64);
604
605impl_element_scalar!(f32 => NPY_FLOAT);
606impl_element_scalar!(f64 => NPY_DOUBLE);
607
608#[cfg(feature = "half")]
609impl_element_scalar!(f16 => NPY_HALF);
610
#[cfg(feature = "half")]
unsafe impl Element for bf16 {
    const IS_COPY: bool = true;

    /// Resolves the `"bfloat16"` dtype by name once and caches it in a process-wide
    /// `GILOnceCell`.
    ///
    /// # Panics
    ///
    /// Panics if `PyArrayDescr::new(py, "bfloat16")` fails. Presumably this happens because no
    /// package providing that dtype is available.
    fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
        static DTYPE: GILOnceCell<Py<PyArrayDescr>> = GILOnceCell::new();

        let cached = DTYPE.get_or_init(py, || {
            let descr = PyArrayDescr::new(py, "bfloat16").expect("A package which provides a `bfloat16` data type for NumPy is required to use the `half::bf16` element type.");
            descr.unbind()
        });
        // Hand out a new strong reference to the cached descriptor.
        cached.bind(py).clone()
    }

    clone_methods_impl!(Self);
}
628
629impl_element_scalar!(Complex32 => NPY_CFLOAT,
630 #[doc = "Complex type with `f32` components which maps to `numpy.csingle` (`numpy.complex64`)."]);
631impl_element_scalar!(Complex64 => NPY_CDOUBLE,
632 #[doc = "Complex type with `f64` components which maps to `numpy.cdouble` (`numpy.complex128`)."]);
633
634#[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))]
635impl_element_scalar!(usize, isize);
636
637unsafe impl Element for PyObject {
638 const IS_COPY: bool = false;
639
640 fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
641 PyArrayDescr::object(py)
642 }
643
644 #[inline]
645 fn clone_ref(&self, py: Python<'_>) -> Self {
646 Py::clone_ref(self, py)
647 }
648}
649
#[cfg(test)]
mod tests {
    use super::*;

    use pyo3::types::PyString;
    use pyo3::{py_run, types::PyTypeMethods};

    use crate::npyffi::{is_numpy_2, NPY_NEEDS_PYAPI};

    #[test]
    fn test_dtype_new() {
        Python::with_gil(|py| {
            assert!(PyArrayDescr::new(py, "float64")
                .unwrap()
                .is(dtype::<f64>(py)));

            let dt = PyArrayDescr::new(py, [("a", "O"), ("b", "?")].as_ref()).unwrap();
            assert_eq!(dt.names(), Some(vec!["a".to_owned(), "b".to_owned()]));
            assert!(dt.has_object());
            assert!(dt.get_field("a").unwrap().0.is(dtype::<PyObject>(py)));
            assert!(dt.get_field("b").unwrap().0.is(dtype::<bool>(py)));

            assert!(PyArrayDescr::new(py, 123_usize).is_err());
        });
    }

    /// Error paths of `PyArrayDescr::new` and `get_field` that the other tests do not cover.
    #[test]
    fn test_dtype_error_paths() {
        Python::with_gil(|py| {
            // `None` must be rejected rather than yielding a descriptor.
            assert!(PyArrayDescr::new(py, py.None()).is_err());

            // Field lookup on a type without fields is an error...
            assert!(dtype::<f64>(py).get_field("a").is_err());

            // ...and so is looking up a field name that does not exist.
            let dt = PyArrayDescr::new(py, [("a", "O"), ("b", "?")].as_ref()).unwrap();
            assert!(dt.get_field("c").is_err());
        });
    }

    #[test]
    fn test_dtype_names() {
        fn type_name<T: Element>(py: Python<'_>) -> Bound<'_, PyString> {
            dtype::<T>(py).typeobj().qualname().unwrap()
        }
        Python::with_gil(|py| {
            // NumPy 2 renamed the boolean scalar type from `bool_` to `bool`.
            if is_numpy_2(py) {
                assert_eq!(type_name::<bool>(py), "bool");
            } else {
                assert_eq!(type_name::<bool>(py), "bool_");
            }

            assert_eq!(type_name::<i8>(py), "int8");
            assert_eq!(type_name::<i16>(py), "int16");
            assert_eq!(type_name::<i32>(py), "int32");
            assert_eq!(type_name::<i64>(py), "int64");
            assert_eq!(type_name::<u8>(py), "uint8");
            assert_eq!(type_name::<u16>(py), "uint16");
            assert_eq!(type_name::<u32>(py), "uint32");
            assert_eq!(type_name::<u64>(py), "uint64");
            assert_eq!(type_name::<f32>(py), "float32");
            assert_eq!(type_name::<f64>(py), "float64");

            assert_eq!(type_name::<Complex32>(py), "complex64");
            assert_eq!(type_name::<Complex64>(py), "complex128");

            #[cfg(target_pointer_width = "32")]
            {
                assert_eq!(type_name::<usize>(py), "uint32");
                assert_eq!(type_name::<isize>(py), "int32");
            }

            #[cfg(target_pointer_width = "64")]
            {
                assert_eq!(type_name::<usize>(py), "uint64");
                assert_eq!(type_name::<isize>(py), "int64");
            }
        });
    }

    #[test]
    fn test_dtype_methods_scalar() {
        Python::with_gil(|py| {
            let dt = dtype::<f64>(py);

            assert_eq!(dt.num(), NPY_TYPES::NPY_DOUBLE as c_int);
            assert_eq!(dt.flags(), 0);
            assert_eq!(dt.typeobj().qualname().unwrap(), "float64");
            assert_eq!(dt.char(), b'd');
            assert_eq!(dt.kind(), b'f');
            assert_eq!(dt.byteorder(), b'=');
            assert_eq!(dt.is_native_byteorder(), Some(true));
            assert_eq!(dt.itemsize(), 8);
            assert_eq!(dt.alignment(), 8);
            assert!(!dt.has_object());
            assert!(dt.names().is_none());
            assert!(!dt.has_fields());
            assert!(!dt.is_aligned_struct());
            assert!(!dt.has_subarray());
            assert!(dt.base().is_equiv_to(&dt));
            assert_eq!(dt.ndim(), 0);
            assert_eq!(dt.shape(), Vec::<usize>::new());
        });
    }

    #[test]
    fn test_dtype_methods_subarray() {
        Python::with_gil(|py| {
            let locals = PyDict::new(py);
            py_run!(
                py,
                *locals,
                "dtype = __import__('numpy').dtype(('f8', (2, 3)))"
            );
            let dt = locals
                .get_item("dtype")
                .unwrap()
                .unwrap()
                .downcast_into::<PyArrayDescr>()
                .unwrap();

            assert_eq!(dt.num(), NPY_TYPES::NPY_VOID as c_int);
            assert_eq!(dt.flags(), 0);
            assert_eq!(dt.typeobj().qualname().unwrap(), "void");
            assert_eq!(dt.char(), b'V');
            assert_eq!(dt.kind(), b'V');
            assert_eq!(dt.byteorder(), b'|');
            assert_eq!(dt.is_native_byteorder(), None);
            assert_eq!(dt.itemsize(), 48);
            assert_eq!(dt.alignment(), 8);
            assert!(!dt.has_object());
            assert!(dt.names().is_none());
            assert!(!dt.has_fields());
            assert!(!dt.is_aligned_struct());
            assert!(dt.has_subarray());
            assert_eq!(dt.ndim(), 2);
            assert_eq!(dt.shape(), vec![2, 3]);
            assert!(dt.base().is_equiv_to(&dtype::<f64>(py)));
        });
    }

    #[test]
    fn test_dtype_methods_record() {
        Python::with_gil(|py| {
            let locals = PyDict::new(py);
            py_run!(
                py,
                *locals,
                "dtype = __import__('numpy').dtype([('x', 'u1'), ('y', 'f8'), ('z', 'O')], align=True)"
            );
            let dt = locals
                .get_item("dtype")
                .unwrap()
                .unwrap()
                .downcast_into::<PyArrayDescr>()
                .unwrap();

            assert_eq!(dt.num(), NPY_TYPES::NPY_VOID as c_int);
            assert_ne!(dt.flags() & NPY_ITEM_HASOBJECT, 0);
            assert_ne!(dt.flags() & NPY_NEEDS_PYAPI, 0);
            assert_ne!(dt.flags() & NPY_ALIGNED_STRUCT, 0);
            assert_eq!(dt.typeobj().qualname().unwrap(), "void");
            assert_eq!(dt.char(), b'V');
            assert_eq!(dt.kind(), b'V');
            assert_eq!(dt.byteorder(), b'|');
            assert_eq!(dt.is_native_byteorder(), None);
            assert_eq!(dt.itemsize(), 24);
            assert_eq!(dt.alignment(), 8);
            assert!(dt.has_object());
            assert_eq!(
                dt.names(),
                Some(vec!["x".to_owned(), "y".to_owned(), "z".to_owned()])
            );
            assert!(dt.has_fields());
            assert!(dt.is_aligned_struct());
            assert!(!dt.has_subarray());
            assert_eq!(dt.ndim(), 0);
            assert_eq!(dt.shape(), Vec::<usize>::new());
            assert!(dt.base().is_equiv_to(&dt));
            // With `align=True`, `y` (f8) is padded to offset 8 and `z` follows at 16.
            let x = dt.get_field("x").unwrap();
            assert!(x.0.is_equiv_to(&dtype::<u8>(py)));
            assert_eq!(x.1, 0);
            let y = dt.get_field("y").unwrap();
            assert!(y.0.is_equiv_to(&dtype::<f64>(py)));
            assert_eq!(y.1, 8);
            let z = dt.get_field("z").unwrap();
            assert!(z.0.is_equiv_to(&dtype::<PyObject>(py)));
            assert_eq!(z.1, 16);
        });
    }
}