numpy/npyffi/
mod.rs

//! Low-level bindings for the NumPy C API.
//!
//! This module provides FFI bindings to the [NumPy C API], giving access to NumPy's array and
//! ufunc functionality. The bindings are compatible with ABI v2 and target API v1.15 so that
//! they remain compatible with older NumPy versions. See the official NumPy documentation
//! for more details about [API compatibility].
//!
//! [NumPy C API]: https://numpy.org/doc/stable/reference/c-api
//! [API compatibility]: https://numpy.org/doc/stable/dev/depending_on_numpy.html
//!
11#![allow(
12    non_camel_case_types,
13    missing_docs,
14    missing_debug_implementations,
15    clippy::too_many_arguments,
16    clippy::missing_safety_doc
17)]
18
19use std::ffi::{c_uint, c_void};
20use std::mem::forget;
21use std::ptr::NonNull;
22
23use pyo3::{
24    ffi::PyTypeObject,
25    sync::PyOnceLock,
26    types::{PyAnyMethods, PyCapsule, PyCapsuleMethods, PyModule},
27    PyResult, Python,
28};
29
/// Memoized result of `PyArray_GetNDArrayCFeatureVersion`.
///
/// The value is filled in lazily on the first call to `is_numpy_2`, which then compares it
/// against `NPY_2_0_API_VERSION` on every call.
static API_VERSION: PyOnceLock<c_uint> = PyOnceLock::new();
31
32fn get_numpy_api<'py>(
33    py: Python<'py>,
34    module: &str,
35    capsule: &str,
36) -> PyResult<NonNull<*const c_void>> {
37    let module = PyModule::import(py, module)?;
38    let capsule = module.getattr(capsule)?.cast_into::<PyCapsule>()?;
39
40    let api = capsule.pointer_checked(None)?;
41
42    // Intentionally leak a reference to the capsule
43    // so we can safely cache a pointer into its interior.
44    forget(capsule);
45
46    Ok(api.cast())
47}
48
49/// Returns whether the runtime `numpy` version is 2.0 or greater.
50pub fn is_numpy_2<'py>(py: Python<'py>) -> bool {
51    let api_version = *API_VERSION.get_or_init(py, || unsafe {
52        PY_ARRAY_API.PyArray_GetNDArrayCFeatureVersion(py)
53    });
54    api_version >= NPY_2_0_API_VERSION
55}
56
57// Implements wrappers for NumPy's Array and UFunc API
58macro_rules! impl_api {
59    // API available on all versions
60    [$offset: expr; $fname: ident ($($arg: ident: $t: ty),* $(,)?) $(-> $ret: ty)?] => {
61        #[allow(non_snake_case)]
62        pub unsafe fn $fname<'py>(&self, py: Python<'py>, $($arg : $t), *) $(-> $ret)* {
63            let f: extern "C" fn ($($arg : $t), *) $(-> $ret)* = self.get(py, $offset).cast().read();
64            f($($arg), *)
65        }
66    };
67}
68
// Defines the type objects associated with the NumPy API.
//
// Each entry has the form `(API_TABLE[slot], VariantName)`. The macro generates:
// - `NpyTypes`, with one variant per entry, in entry order;
// - `get_type_object`, which maps each variant to the pointer stored in its table slot.
macro_rules! impl_array_type {
    ($(($api:ident [ $offset:expr ] , $tname:ident)),* $(,)?) => {
        /// All type objects exported by the NumPy API.
        ///
        /// This is a field-less tag, so it is `Copy`. The same value can be reused after
        /// being passed to `get_type_object`, and it can be compared and debug-printed.
        #[allow(non_camel_case_types)]
        #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
        pub enum NpyTypes { $($tname),* }

        /// Returns a pointer to the type object associated with `ty`.
        ///
        /// # Safety
        ///
        /// The pointer is read directly out of the API table slot registered for `ty`, via
        /// that table's `get` method. The caller must ensure the table is valid for the
        /// running interpreter. NOTE(review): presumably this means the owning NumPy module
        /// imported successfully; confirm against the table's `get` implementation.
        pub unsafe fn get_type_object<'py>(py: Python<'py>, ty: NpyTypes) -> *mut PyTypeObject {
            match ty {
                $( NpyTypes::$tname => $api.get(py, $offset).read() as _ ),*
            }
        }
    }
}
84
// Registers every type object exposed through NumPy's C API tables as `(TABLE[slot], Variant)`.
// Slot numbers must match NumPy's table layout exactly. Entry order also fixes the order (and
// therefore the discriminants) of the generated `NpyTypes` variants.
impl_array_type! {
    // Multiarray API
    // Slot 1 was never meaningfully used by NumPy
    (PY_ARRAY_API[2], PyArray_Type),
    (PY_ARRAY_API[3], PyArrayDescr_Type),
    // Unused slot 4, was `PyArrayFlags_Type`
    (PY_ARRAY_API[5], PyArrayIter_Type),
    (PY_ARRAY_API[6], PyArrayMultiIter_Type),
    // Slot 7 holds `NPY_NUMUSERTYPES` (a `c_int`), not a type object, so it is not listed.
    (PY_ARRAY_API[8], PyBoolArrType_Type),
    // Slot 9 holds `_PyArrayScalar_BoolValues` (`*mut PyBoolScalarObject`), not a type object.
    (PY_ARRAY_API[10], PyGenericArrType_Type),
    (PY_ARRAY_API[11], PyNumberArrType_Type),
    (PY_ARRAY_API[12], PyIntegerArrType_Type),
    (PY_ARRAY_API[13], PySignedIntegerArrType_Type),
    (PY_ARRAY_API[14], PyUnsignedIntegerArrType_Type),
    (PY_ARRAY_API[15], PyInexactArrType_Type),
    (PY_ARRAY_API[16], PyFloatingArrType_Type),
    (PY_ARRAY_API[17], PyComplexFloatingArrType_Type),
    (PY_ARRAY_API[18], PyFlexibleArrType_Type),
    (PY_ARRAY_API[19], PyCharacterArrType_Type),
    (PY_ARRAY_API[20], PyByteArrType_Type),
    (PY_ARRAY_API[21], PyShortArrType_Type),
    (PY_ARRAY_API[22], PyIntArrType_Type),
    (PY_ARRAY_API[23], PyLongArrType_Type),
    (PY_ARRAY_API[24], PyLongLongArrType_Type),
    (PY_ARRAY_API[25], PyUByteArrType_Type),
    (PY_ARRAY_API[26], PyUShortArrType_Type),
    (PY_ARRAY_API[27], PyUIntArrType_Type),
    (PY_ARRAY_API[28], PyULongArrType_Type),
    (PY_ARRAY_API[29], PyULongLongArrType_Type),
    (PY_ARRAY_API[30], PyFloatArrType_Type),
    (PY_ARRAY_API[31], PyDoubleArrType_Type),
    (PY_ARRAY_API[32], PyLongDoubleArrType_Type),
    (PY_ARRAY_API[33], PyCFloatArrType_Type),
    (PY_ARRAY_API[34], PyCDoubleArrType_Type),
    (PY_ARRAY_API[35], PyCLongDoubleArrType_Type),
    (PY_ARRAY_API[36], PyObjectArrType_Type),
    (PY_ARRAY_API[37], PyStringArrType_Type),
    (PY_ARRAY_API[38], PyUnicodeArrType_Type),
    (PY_ARRAY_API[39], PyVoidArrType_Type),
    // Slots 40-213 are not listed here. NOTE(review): presumably they are non-type entries
    // (functions/data); confirm against NumPy's multiarray API table.
    (PY_ARRAY_API[214], PyTimeIntegerArrType_Type),
    (PY_ARRAY_API[215], PyDatetimeArrType_Type),
    (PY_ARRAY_API[216], PyTimedeltaArrType_Type),
    (PY_ARRAY_API[217], PyHalfArrType_Type),
    (PY_ARRAY_API[218], NpyIter_Type),
    // UFunc API
    (PY_UFUNC_API[0], PyUFunc_Type),
}
134
135pub mod array;
136pub mod flags;
137mod npy_common;
138mod numpyconfig;
139pub mod objects;
140pub mod types;
141pub mod ufunc;
142
143pub use self::array::*;
144pub use self::flags::*;
145pub use self::npy_common::*;
146pub use self::numpyconfig::*;
147pub use self::objects::*;
148pub use self::types::*;
149pub use self::ufunc::*;