numpy/npyffi/mod.rs

//! Low-level bindings to the NumPy C API.
//!
//! <https://numpy.org/doc/stable/reference/c-api>
#![allow(
    // Type names in this module tree presumably follow NumPy's C spelling, not Rust casing.
    non_camel_case_types,
    // The raw FFI surface is not documented item by item.
    missing_docs,
    // Raw FFI types are not required to implement `Debug`.
    missing_debug_implementations,
    // `impl_api!` forwards C signatures verbatim, so wrappers may take many parameters.
    clippy::too_many_arguments,
    // The `unsafe fn` wrappers generated by `impl_api!` carry no `# Safety` section.
    clippy::missing_safety_doc
)]
11
12use std::mem::forget;
13use std::os::raw::{c_uint, c_void};
14
15use pyo3::{
16    sync::GILOnceCell,
17    types::{PyAnyMethods, PyCapsule, PyCapsuleMethods, PyModule},
18    PyResult, Python,
19};
20
/// NumPy C API feature version that marks NumPy 2.0.
///
/// A runtime whose reported feature version is at least this value is treated as NumPy 2 by
/// [`is_numpy_2`].
pub const API_VERSION_2_0: c_uint = 0x0000_0012;
22
23static API_VERSION: GILOnceCell<c_uint> = GILOnceCell::new();
24
25fn get_numpy_api<'py>(
26    py: Python<'py>,
27    module: &str,
28    capsule: &str,
29) -> PyResult<*const *const c_void> {
30    let module = PyModule::import(py, module)?;
31    let capsule = module.getattr(capsule)?.downcast_into::<PyCapsule>()?;
32
33    let api = capsule.pointer() as *const *const c_void;
34
35    // Intentionally leak a reference to the capsule
36    // so we can safely cache a pointer into its interior.
37    forget(capsule);
38
39    Ok(api)
40}
41
42/// Returns whether the runtime `numpy` version is 2.0 or greater.
43pub fn is_numpy_2<'py>(py: Python<'py>) -> bool {
44    let api_version = *API_VERSION.get_or_init(py, || unsafe {
45        PY_ARRAY_API.PyArray_GetNDArrayCFeatureVersion(py)
46    });
47    api_version >= API_VERSION_2_0
48}
49
// Implements wrappers for NumPy's Array and UFunc API.
//
// Invocation: `impl_api![offset; name(arg: Type, ...) -> Ret]`, optionally with a `NumPy1;` or
// `NumPy2;` marker after the offset. Each invocation expands to
// `pub unsafe fn name(&self, py, args...)`. The method loads the function pointer returned by
// `self.get(py, offset)` and calls it with `args`. With a version marker, the method first
// asserts (panicking otherwise) that the runtime NumPy API version matches.
macro_rules! impl_api {
    // Internal rule shared by the public rules below; it emits the wrapper method itself.
    //
    // `$py` names the `Python` token parameter. The calling rule passes it explicitly (always
    // `py`) so that it shares a macro-hygiene context with the `$guard` statements, which are
    // also written in the calling rule and may refer to it. `$guard` runs before the call.
    (@wrapper $py:ident [$($guard:tt)*] $offset:expr; $fname:ident ($($arg:ident: $t:ty),*) $(-> $ret:ty)?) => {
        #[allow(non_snake_case)]
        pub unsafe fn $fname<'py>(&self, $py: Python<'py>, $($arg: $t),*) $(-> $ret)? {
            $($guard)*
            // NOTE(review): assumes the slot at `$offset` holds a C function with exactly this
            // signature.
            let fptr = self.get($py, $offset) as *const extern "C" fn($($arg: $t),*) $(-> $ret)?;
            (*fptr)($($arg),*)
        }
    };

    // API available on all versions
    [$offset:expr; $fname:ident ($($arg:ident: $t:ty),* $(,)?) $(-> $ret:ty)?] => {
        impl_api! { @wrapper py [] $offset; $fname ($($arg: $t),*) $(-> $ret)? }
    };

    // API with version constraints, checked at runtime: NumPy 1 only
    [$offset:expr; NumPy1; $fname:ident ($($arg:ident: $t:ty),* $(,)?) $(-> $ret:ty)?] => {
        impl_api! {
            @wrapper py [
                assert!(
                    !is_numpy_2(py),
                    "{} requires API < {:08X} (NumPy 1) but the runtime version is API {:08X}",
                    stringify!($fname),
                    API_VERSION_2_0,
                    *API_VERSION.get(py).expect("API_VERSION is initialized"),
                );
            ]
            $offset; $fname ($($arg: $t),*) $(-> $ret)?
        }
    };

    // API with version constraints, checked at runtime: NumPy 2 or newer
    [$offset:expr; NumPy2; $fname:ident ($($arg:ident: $t:ty),* $(,)?) $(-> $ret:ty)?] => {
        impl_api! {
            @wrapper py [
                assert!(
                    is_numpy_2(py),
                    "{} requires API {:08X} or greater (NumPy 2) but the runtime version is API {:08X}",
                    stringify!($fname),
                    API_VERSION_2_0,
                    *API_VERSION.get(py).expect("API_VERSION is initialized"),
                );
            ]
            $offset; $fname ($($arg: $t),*) $(-> $ret)?
        }
    };
}
93
94pub mod array;
95pub mod flags;
96pub mod objects;
97pub mod types;
98pub mod ufunc;
99
100pub use self::array::*;
101pub use self::flags::*;
102pub use self::objects::*;
103pub use self::types::*;
104pub use self::ufunc::*;