1use std::mem::size_of;
2use std::os::raw::{c_int, c_long, c_longlong, c_short, c_uint, c_ulong, c_ulonglong, c_ushort};
3use std::ptr;
4
5#[cfg(feature = "half")]
6use half::{bf16, f16};
7use num_traits::{Bounded, Zero};
8#[cfg(feature = "half")]
9use pyo3::sync::PyOnceLock;
10use pyo3::{
11 conversion::IntoPyObject,
12 exceptions::{PyIndexError, PyValueError},
13 ffi::{self, PyTuple_Size},
14 pyobject_native_type_named,
15 types::{PyAnyMethods, PyDict, PyDictMethods, PyTuple, PyType},
16 Borrowed, Bound, Py, PyAny, PyResult, PyTypeInfo, Python,
17};
18
19use crate::npyffi::{
20 NpyTypes, PyArray_Descr, PyDataType_ALIGNMENT, PyDataType_ELSIZE, PyDataType_FIELDS,
21 PyDataType_FLAGS, PyDataType_NAMES, PyDataType_SUBARRAY, NPY_ALIGNED_STRUCT,
22 NPY_BYTEORDER_CHAR, NPY_ITEM_HASOBJECT, NPY_TYPES, PY_ARRAY_API,
23};
24
25pub use num_complex::{Complex32, Complex64};
26
/// Binding to a NumPy type descriptor object (`numpy.dtype`).
///
/// Instances are only ever handled behind PyO3 smart pointers such as `Bound<'py, PyArrayDescr>`;
/// the methods live on [`PyArrayDescrMethods`].
///
/// `#[repr(transparent)]` gives this wrapper the exact layout of the wrapped `PyAny`, so
/// pointers to it can be reinterpreted as pointers to the underlying Python object.
#[repr(transparent)]
pub struct PyArrayDescr(PyAny);

// Generates PyO3's standard "named native type" boilerplate for the wrapper.
pyobject_native_type_named!(PyArrayDescr);
54
55unsafe impl PyTypeInfo for PyArrayDescr {
56 const NAME: &'static str = "PyArrayDescr";
57 const MODULE: Option<&'static str> = Some("numpy");
58
59 #[inline]
60 fn type_object_raw<'py>(py: Python<'py>) -> *mut ffi::PyTypeObject {
61 unsafe { PY_ARRAY_API.get_type_object(py, NpyTypes::PyArrayDescr_Type) }
62 }
63}
64
65#[inline]
67pub fn dtype<'py, T: Element>(py: Python<'py>) -> Bound<'py, PyArrayDescr> {
68 T::get_dtype(py)
69}
70
71impl PyArrayDescr {
72 #[inline]
78 pub fn new<'a, 'py, T>(py: Python<'py>, ob: T) -> PyResult<Bound<'py, Self>>
79 where
80 T: IntoPyObject<'py>,
81 {
82 fn inner<'py>(
83 py: Python<'py>,
84 obj: Borrowed<'_, 'py, PyAny>,
85 ) -> PyResult<Bound<'py, PyArrayDescr>> {
86 let mut descr: *mut PyArray_Descr = ptr::null_mut();
87 unsafe {
88 PY_ARRAY_API.PyArray_DescrConverter2(py, obj.as_ptr(), &mut descr);
90 Bound::from_owned_ptr_or_err(py, descr.cast()).map(|any| any.cast_into_unchecked())
91 }
92 }
93
94 inner(
95 py,
96 ob.into_pyobject(py)
97 .map_err(Into::into)?
98 .into_any()
99 .as_borrowed(),
100 )
101 }
102
103 #[inline]
105 pub fn object(py: Python<'_>) -> Bound<'_, Self> {
106 Self::from_npy_type(py, NPY_TYPES::NPY_OBJECT)
107 }
108
109 #[inline]
111 pub fn of<'py, T: Element>(py: Python<'py>) -> Bound<'py, Self> {
112 T::get_dtype(py)
113 }
114
115 fn from_npy_type<'py>(py: Python<'py>, npy_type: NPY_TYPES) -> Bound<'py, Self> {
116 unsafe {
117 let descr = PY_ARRAY_API.PyArray_DescrFromType(py, npy_type as _);
118 Bound::from_owned_ptr(py, descr.cast()).cast_into_unchecked()
119 }
120 }
121
122 pub(crate) fn new_from_npy_type<'py>(py: Python<'py>, npy_type: NPY_TYPES) -> Bound<'py, Self> {
123 unsafe {
124 let descr = PY_ARRAY_API.PyArray_DescrNewFromType(py, npy_type as _);
125 Bound::from_owned_ptr(py, descr.cast()).cast_into_unchecked()
126 }
127 }
128}
129
130#[doc(alias = "PyArrayDescr")]
132pub trait PyArrayDescrMethods<'py>: Sealed {
133 fn as_dtype_ptr(&self) -> *mut PyArray_Descr;
135
136 fn into_dtype_ptr(self) -> *mut PyArray_Descr;
140
141 fn is_equiv_to(&self, other: &Self) -> bool;
143
144 fn typeobj(&self) -> Bound<'py, PyType>;
151
152 fn num(&self) -> c_int {
162 unsafe { &*self.as_dtype_ptr() }.type_num
163 }
164
165 fn itemsize(&self) -> usize;
171
172 fn alignment(&self) -> usize;
178
179 fn byteorder(&self) -> u8 {
187 unsafe { &*self.as_dtype_ptr() }.byteorder.max(0) as _
188 }
189
190 fn char(&self) -> u8 {
198 unsafe { &*self.as_dtype_ptr() }.type_.max(0) as _
199 }
200
201 fn kind(&self) -> u8 {
209 unsafe { &*self.as_dtype_ptr() }.kind.max(0) as _
210 }
211
212 fn flags(&self) -> u64;
218
219 fn ndim(&self) -> usize;
225
226 fn base(&self) -> Bound<'py, PyArrayDescr>;
234
235 fn shape(&self) -> Vec<usize>;
243
244 fn has_object(&self) -> bool {
250 self.flags() & NPY_ITEM_HASOBJECT != 0
251 }
252
253 fn is_aligned_struct(&self) -> bool {
262 self.flags() & NPY_ALIGNED_STRUCT != 0
263 }
264
265 fn has_subarray(&self) -> bool;
269
270 fn has_fields(&self) -> bool;
274
275 fn is_native_byteorder(&self) -> Option<bool> {
277 match self.byteorder() {
279 b'=' => Some(true),
280 b'|' => None,
281 byteorder => Some(byteorder == NPY_BYTEORDER_CHAR::NPY_NATBYTE as u8),
282 }
283 }
284
285 fn names(&self) -> Option<Vec<String>>;
293
294 fn get_field(&self, name: &str) -> PyResult<(Bound<'py, PyArrayDescr>, usize)>;
305}
306
// Private module holding the sealing trait. Code outside this module cannot name
// `sealed::Sealed`, so it cannot implement `PyArrayDescrMethods`, which requires it.
mod sealed {
    /// Marker supertrait of `PyArrayDescrMethods`.
    pub trait Sealed {}
}

use sealed::Sealed;
312
313impl<'py> PyArrayDescrMethods<'py> for Bound<'py, PyArrayDescr> {
314 fn as_dtype_ptr(&self) -> *mut PyArray_Descr {
315 self.as_ptr() as _
316 }
317
318 fn into_dtype_ptr(self) -> *mut PyArray_Descr {
319 self.into_ptr() as _
320 }
321
322 fn is_equiv_to(&self, other: &Self) -> bool {
323 let self_ptr = self.as_dtype_ptr();
324 let other_ptr = other.as_dtype_ptr();
325
326 unsafe {
327 self_ptr == other_ptr
328 || PY_ARRAY_API.PyArray_EquivTypes(self.py(), self_ptr, other_ptr) != 0
329 }
330 }
331
332 fn typeobj(&self) -> Bound<'py, PyType> {
333 let dtype_type_ptr = unsafe { &*self.as_dtype_ptr() }.typeobj;
334 unsafe { PyType::from_borrowed_type_ptr(self.py(), dtype_type_ptr) }
335 }
336
337 fn itemsize(&self) -> usize {
338 unsafe { PyDataType_ELSIZE(self.py(), self.as_dtype_ptr()).max(0) as _ }
339 }
340
341 fn alignment(&self) -> usize {
342 unsafe { PyDataType_ALIGNMENT(self.py(), self.as_dtype_ptr()).max(0) as _ }
343 }
344
345 fn flags(&self) -> u64 {
346 unsafe { PyDataType_FLAGS(self.py(), self.as_dtype_ptr()) as _ }
347 }
348
349 fn ndim(&self) -> usize {
350 let subarray = unsafe { PyDataType_SUBARRAY(self.py(), self.as_dtype_ptr()).as_ref() };
351 match subarray {
352 None => 0,
353 Some(subarray) => unsafe { PyTuple_Size(subarray.shape) }.max(0) as _,
354 }
355 }
356
357 fn base(&self) -> Bound<'py, PyArrayDescr> {
358 let subarray = unsafe { PyDataType_SUBARRAY(self.py(), self.as_dtype_ptr()).as_ref() };
359 match subarray {
360 None => self.clone(),
361 Some(subarray) => unsafe {
362 Bound::from_borrowed_ptr(self.py(), subarray.base.cast()).cast_into_unchecked()
363 },
364 }
365 }
366
367 fn shape(&self) -> Vec<usize> {
368 let subarray = unsafe { PyDataType_SUBARRAY(self.py(), self.as_dtype_ptr()).as_ref() };
369 match subarray {
370 None => Vec::new(),
371 Some(subarray) => {
372 let shape = unsafe { Borrowed::from_ptr(self.py(), subarray.shape) };
374 shape.extract().unwrap()
375 }
376 }
377 }
378
379 fn has_subarray(&self) -> bool {
380 unsafe { !PyDataType_SUBARRAY(self.py(), self.as_dtype_ptr()).is_null() }
381 }
382
383 fn has_fields(&self) -> bool {
384 unsafe { !PyDataType_NAMES(self.py(), self.as_dtype_ptr()).is_null() }
385 }
386
387 fn names(&self) -> Option<Vec<String>> {
388 if !self.has_fields() {
389 return None;
390 }
391 let names = unsafe {
392 Borrowed::from_ptr(self.py(), PyDataType_NAMES(self.py(), self.as_dtype_ptr()))
393 };
394 names.extract().ok()
395 }
396
397 fn get_field(&self, name: &str) -> PyResult<(Bound<'py, PyArrayDescr>, usize)> {
398 if !self.has_fields() {
399 return Err(PyValueError::new_err(
400 "cannot get field information: type descriptor has no fields",
401 ));
402 }
403 let dict = unsafe {
404 Borrowed::from_ptr(self.py(), PyDataType_FIELDS(self.py(), self.as_dtype_ptr()))
405 };
406 let dict = unsafe { dict.cast_unchecked::<PyDict>() };
407 let tuple = dict
409 .get_item(name)?
410 .ok_or_else(|| PyIndexError::new_err(name.to_owned()))?
411 .cast_into::<PyTuple>()
412 .unwrap();
413 let dtype = tuple
415 .get_item(0)
416 .unwrap()
417 .cast_into::<PyArrayDescr>()
418 .unwrap();
419 let offset = tuple.get_item(1).unwrap().extract().unwrap();
420 Ok((dtype, offset))
421 }
422}
423
// Opts `Bound<'_, PyArrayDescr>` into the sealing trait so it may implement
// `PyArrayDescrMethods`; the private `sealed` module prevents implementations elsewhere.
impl Sealed for Bound<'_, PyArrayDescr> {}
425
426pub unsafe trait Element: Sized + Send + Sync {
464 const IS_COPY: bool;
472
473 fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr>;
475
476 fn clone_ref(&self, py: Python<'_>) -> Self;
478
479 #[inline]
484 fn vec_from_slice(py: Python<'_>, slc: &[Self]) -> Vec<Self> {
485 slc.iter().map(|elem| elem.clone_ref(py)).collect()
486 }
487
488 #[inline]
493 fn array_from_view<D>(
494 py: Python<'_>,
495 view: ::ndarray::ArrayView<'_, Self, D>,
496 ) -> ::ndarray::Array<Self, D>
497 where
498 D: ::ndarray::Dimension,
499 {
500 view.map(|elem| elem.clone_ref(py))
501 }
502}
503
504fn npy_int_type_lookup<T, T0, T1, T2>(npy_types: [NPY_TYPES; 3]) -> NPY_TYPES {
505 match size_of::<T>() {
509 x if x == size_of::<T0>() => npy_types[0],
510 x if x == size_of::<T1>() => npy_types[1],
511 x if x == size_of::<T2>() => npy_types[2],
512 _ => panic!("Unable to match integer type descriptor: {npy_types:?}"),
513 }
514}
515
516fn npy_int_type<T: Bounded + Zero + Sized + PartialEq>() -> NPY_TYPES {
517 let is_unsigned = T::min_value() == T::zero();
518 let bit_width = 8 * size_of::<T>();
519
520 match (is_unsigned, bit_width) {
521 (false, 8) => NPY_TYPES::NPY_BYTE,
522 (false, 16) => NPY_TYPES::NPY_SHORT,
523 (false, 32) => npy_int_type_lookup::<i32, c_long, c_int, c_short>([
524 NPY_TYPES::NPY_LONG,
525 NPY_TYPES::NPY_INT,
526 NPY_TYPES::NPY_SHORT,
527 ]),
528 (false, 64) => npy_int_type_lookup::<i64, c_long, c_longlong, c_int>([
529 NPY_TYPES::NPY_LONG,
530 NPY_TYPES::NPY_LONGLONG,
531 NPY_TYPES::NPY_INT,
532 ]),
533 (true, 8) => NPY_TYPES::NPY_UBYTE,
534 (true, 16) => NPY_TYPES::NPY_USHORT,
535 (true, 32) => npy_int_type_lookup::<u32, c_ulong, c_uint, c_ushort>([
536 NPY_TYPES::NPY_ULONG,
537 NPY_TYPES::NPY_UINT,
538 NPY_TYPES::NPY_USHORT,
539 ]),
540 (true, 64) => npy_int_type_lookup::<u64, c_ulong, c_ulonglong, c_uint>([
541 NPY_TYPES::NPY_ULONG,
542 NPY_TYPES::NPY_ULONGLONG,
543 NPY_TYPES::NPY_UINT,
544 ]),
545 _ => unreachable!(),
546 }
547}
548
/// Expands to `Element::clone_ref`, `Element::vec_from_slice` and `Element::array_from_view`
/// implementations for the `Clone` type `$Self`.
///
/// These use plain `Clone` instead of going through `clone_ref` element by element, and ignore
/// the Python token. Every path except the prelude `Vec` is fully qualified, so the expansion
/// does not depend on the invoking module's imports.
macro_rules! clone_methods_impl {
    ($Self:ty) => {
        #[inline]
        fn clone_ref(&self, _py: ::pyo3::Python<'_>) -> $Self {
            <$Self as ::std::clone::Clone>::clone(self)
        }

        #[inline]
        fn vec_from_slice(_py: ::pyo3::Python<'_>, slc: &[$Self]) -> Vec<$Self> {
            <[$Self]>::to_vec(slc)
        }

        #[inline]
        fn array_from_view<D>(
            _py: ::pyo3::Python<'_>,
            view: ::ndarray::ArrayView<'_, $Self, D>,
        ) -> ::ndarray::Array<$Self, D>
        where
            D: ::ndarray::Dimension,
        {
            ::ndarray::ArrayBase::to_owned(&view)
        }
    };
}
pub(crate) use clone_methods_impl;
576use pyo3::BoundObject;
577
/// Implements [`Element`] for scalar types.
///
/// Invocation forms:
/// - `impl_element_scalar!(T => NPY_X)` implements `Element` for `T` with the fixed type number
///   `NPY_TYPES::NPY_X`. Any `#[...]` attributes listed after a comma (e.g. `#[doc = ...]`) are
///   attached to the generated `impl`.
/// - `impl_element_scalar!(T1, T2, ...)` implements `Element` for each integer type, whose type
///   number is obtained by calling `npy_int_type::<T>()` inside `get_dtype`.
/// - `@impl:` is the internal arm both public forms expand to.
///
/// Every generated impl sets `IS_COPY = true` and takes its cloning methods from
/// `clone_methods_impl!`.
// The internal `@impl` arm is listed first; do not reorder the arms without re-checking which
// arm each invocation matches.
macro_rules! impl_element_scalar {
    (@impl: $ty:ty, $npy_type:expr $(,#[$meta:meta])*) => {
        $(#[$meta])*
        unsafe impl Element for $ty {
            const IS_COPY: bool = true;

            fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
                PyArrayDescr::from_npy_type(py, $npy_type)
            }

            clone_methods_impl!($ty);
        }
    };
    ($ty:ty => $npy_type:ident $(,#[$meta:meta])*) => {
        impl_element_scalar!(@impl: $ty, NPY_TYPES::$npy_type $(,#[$meta])*);
    };
    ($($tys:ty),+) => {
        $(impl_element_scalar!(@impl: $tys, npy_int_type::<$tys>());)+
    };
}
598
599impl_element_scalar!(bool => NPY_BOOL);
600
601impl_element_scalar!(i8, i16, i32, i64);
602impl_element_scalar!(u8, u16, u32, u64);
603
604impl_element_scalar!(f32 => NPY_FLOAT);
605impl_element_scalar!(f64 => NPY_DOUBLE);
606
607#[cfg(feature = "half")]
608impl_element_scalar!(f16 => NPY_HALF);
609
/// `half::bf16` elements use the descriptor NumPy resolves for the name `"bfloat16"`.
///
/// Per the error message below, that data type must be provided by a separate package.
#[cfg(feature = "half")]
unsafe impl Element for bf16 {
    const IS_COPY: bool = true;

    fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
        // The looked-up descriptor is cached in `DTYPE`; each call returns a new reference to it.
        static DTYPE: PyOnceLock<Py<PyArrayDescr>> = PyOnceLock::new();

        let cached = DTYPE.get_or_init(py, || {
            let descr = PyArrayDescr::new(py, "bfloat16").expect("A package which provides a `bfloat16` data type for NumPy is required to use the `half::bf16` element type.");
            descr.unbind()
        });
        cached.bind(py).clone()
    }

    clone_methods_impl!(Self);
}
627
628impl_element_scalar!(Complex32 => NPY_CFLOAT,
629 #[doc = "Complex type with `f32` components which maps to `numpy.csingle` (`numpy.complex64`)."]);
630impl_element_scalar!(Complex64 => NPY_CDOUBLE,
631 #[doc = "Complex type with `f64` components which maps to `numpy.cdouble` (`numpy.complex128`)."]);
632
633#[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))]
634impl_element_scalar!(usize, isize);
635
636unsafe impl Element for Py<PyAny> {
637 const IS_COPY: bool = false;
638
639 fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
640 PyArrayDescr::object(py)
641 }
642
643 #[inline]
644 fn clone_ref(&self, py: Python<'_>) -> Self {
645 Py::clone_ref(self, py)
646 }
647}
648
#[cfg(test)]
mod tests {
    use super::*;

    use pyo3::types::PyString;
    use pyo3::{py_run, types::PyTypeMethods};

    use crate::npyffi::{is_numpy_2, NPY_NEEDS_PYAPI};

    /// Fetches the `dtype` variable that a `py_run!` snippet stored in `locals`.
    fn dtype_from_locals<'py>(locals: &Bound<'py, PyDict>) -> Bound<'py, PyArrayDescr> {
        let value = locals.get_item("dtype").unwrap().unwrap();
        value.cast_into::<PyArrayDescr>().unwrap()
    }

    /// Qualified name of the scalar type backing `T`'s descriptor.
    fn type_name<T: Element>(py: Python<'_>) -> Bound<'_, PyString> {
        dtype::<T>(py).typeobj().qualname().unwrap()
    }

    #[test]
    fn test_dtype_new() {
        Python::attach(|py| {
            let float64 = PyArrayDescr::new(py, "float64").unwrap();
            assert!(float64.is(dtype::<f64>(py)));

            let record = PyArrayDescr::new(py, [("a", "O"), ("b", "?")].as_ref()).unwrap();
            assert_eq!(record.names(), Some(vec!["a".to_owned(), "b".to_owned()]));
            assert!(record.has_object());
            let (field_a, _) = record.get_field("a").unwrap();
            assert!(field_a.is(dtype::<Py<PyAny>>(py)));
            let (field_b, _) = record.get_field("b").unwrap();
            assert!(field_b.is(dtype::<bool>(py)));

            assert!(PyArrayDescr::new(py, 123_usize).is_err());
        });
    }

    #[test]
    fn test_dtype_names() {
        Python::attach(|py| {
            let bool_name = if is_numpy_2(py) { "bool" } else { "bool_" };
            assert_eq!(type_name::<bool>(py), bool_name);

            assert_eq!(type_name::<i8>(py), "int8");
            assert_eq!(type_name::<i16>(py), "int16");
            assert_eq!(type_name::<i32>(py), "int32");
            assert_eq!(type_name::<i64>(py), "int64");

            assert_eq!(type_name::<u8>(py), "uint8");
            assert_eq!(type_name::<u16>(py), "uint16");
            assert_eq!(type_name::<u32>(py), "uint32");
            assert_eq!(type_name::<u64>(py), "uint64");

            assert_eq!(type_name::<f32>(py), "float32");
            assert_eq!(type_name::<f64>(py), "float64");

            assert_eq!(type_name::<Complex32>(py), "complex64");
            assert_eq!(type_name::<Complex64>(py), "complex128");

            #[cfg(target_pointer_width = "32")]
            {
                assert_eq!(type_name::<usize>(py), "uint32");
                assert_eq!(type_name::<isize>(py), "int32");
            }

            #[cfg(target_pointer_width = "64")]
            {
                assert_eq!(type_name::<usize>(py), "uint64");
                assert_eq!(type_name::<isize>(py), "int64");
            }
        });
    }

    #[test]
    fn test_dtype_methods_scalar() {
        Python::attach(|py| {
            let dt = dtype::<f64>(py);

            // Identity of the type.
            assert_eq!(dt.num(), NPY_TYPES::NPY_DOUBLE as c_int);
            assert_eq!(dt.typeobj().qualname().unwrap(), "float64");
            assert_eq!(dt.char(), b'd');
            assert_eq!(dt.kind(), b'f');

            // Layout.
            assert_eq!(dt.flags(), 0);
            assert_eq!(dt.byteorder(), b'=');
            assert_eq!(dt.is_native_byteorder(), Some(true));
            assert_eq!(dt.itemsize(), 8);
            assert_eq!(dt.alignment(), 8);

            // Structure.
            assert!(!dt.has_object());
            assert!(dt.names().is_none());
            assert!(!dt.has_fields());
            assert!(!dt.is_aligned_struct());
            assert!(!dt.has_subarray());
            assert!(dt.base().is_equiv_to(&dt));
            assert_eq!(dt.ndim(), 0);
            assert_eq!(dt.shape(), Vec::<usize>::new());
        });
    }

    #[test]
    fn test_dtype_methods_subarray() {
        Python::attach(|py| {
            let locals = PyDict::new(py);
            py_run!(
                py,
                *locals,
                "dtype = __import__('numpy').dtype(('f8', (2, 3)))"
            );
            let dt = dtype_from_locals(&locals);

            // Identity of the type.
            assert_eq!(dt.num(), NPY_TYPES::NPY_VOID as c_int);
            assert_eq!(dt.typeobj().qualname().unwrap(), "void");
            assert_eq!(dt.char(), b'V');
            assert_eq!(dt.kind(), b'V');

            // Layout.
            assert_eq!(dt.flags(), 0);
            assert_eq!(dt.byteorder(), b'|');
            assert_eq!(dt.is_native_byteorder(), None);
            assert_eq!(dt.itemsize(), 48);
            assert_eq!(dt.alignment(), 8);

            // Structure.
            assert!(!dt.has_object());
            assert!(dt.names().is_none());
            assert!(!dt.has_fields());
            assert!(!dt.is_aligned_struct());
            assert!(dt.has_subarray());
            assert_eq!(dt.ndim(), 2);
            assert_eq!(dt.shape(), vec![2, 3]);
            assert!(dt.base().is_equiv_to(&dtype::<f64>(py)));
        });
    }

    #[test]
    fn test_dtype_methods_record() {
        Python::attach(|py| {
            let locals = PyDict::new(py);
            py_run!(
                py,
                *locals,
                "dtype = __import__('numpy').dtype([('x', 'u1'), ('y', 'f8'), ('z', 'O')], align=True)"
            );
            let dt = dtype_from_locals(&locals);

            assert_eq!(dt.num(), NPY_TYPES::NPY_VOID as c_int);
            for flag in [NPY_ITEM_HASOBJECT, NPY_NEEDS_PYAPI, NPY_ALIGNED_STRUCT] {
                assert_ne!(dt.flags() & flag, 0);
            }

            // Identity of the type.
            assert_eq!(dt.typeobj().qualname().unwrap(), "void");
            assert_eq!(dt.char(), b'V');
            assert_eq!(dt.kind(), b'V');

            // Layout.
            assert_eq!(dt.byteorder(), b'|');
            assert_eq!(dt.is_native_byteorder(), None);
            assert_eq!(dt.itemsize(), 24);
            assert_eq!(dt.alignment(), 8);

            // Structure.
            assert!(dt.has_object());
            assert_eq!(
                dt.names(),
                Some(vec!["x".to_owned(), "y".to_owned(), "z".to_owned()])
            );
            assert!(dt.has_fields());
            assert!(dt.is_aligned_struct());
            assert!(!dt.has_subarray());
            assert_eq!(dt.ndim(), 0);
            assert_eq!(dt.shape(), Vec::<usize>::new());
            assert!(dt.base().is_equiv_to(&dt));

            // Per-field descriptor and byte offset.
            let expected_fields = [
                ("x", dtype::<u8>(py), 0),
                ("y", dtype::<f64>(py), 8),
                ("z", dtype::<Py<PyAny>>(py), 16),
            ];
            for (name, expected_dtype, expected_offset) in &expected_fields {
                let (field_dtype, offset) = dt.get_field(name).unwrap();
                assert!(field_dtype.is_equiv_to(expected_dtype));
                assert_eq!(offset, *expected_offset);
            }
        });
    }
}