1#![allow(
5 non_camel_case_types,
6 missing_docs,
7 missing_debug_implementations,
8 clippy::too_many_arguments,
9 clippy::missing_safety_doc
10)]
11
12use std::mem::forget;
13use std::os::raw::{c_uint, c_void};
14
15use pyo3::{
16 sync::GILOnceCell,
17 types::{PyAnyMethods, PyCapsule, PyCapsuleMethods, PyModule},
18 PyResult, Python,
19};
20
/// Lowest NumPy C-API version that [`is_numpy_2`] classifies as NumPy 2.
///
/// Presumably mirrors NumPy's `NPY_2_0_API_VERSION`. It is printed as `00000012` in the
/// `{:08X}`-formatted version-mismatch panics.
pub const API_VERSION_2_0: c_uint = 0x12;
22
23static API_VERSION: GILOnceCell<c_uint> = GILOnceCell::new();
24
25fn get_numpy_api<'py>(
26 py: Python<'py>,
27 module: &str,
28 capsule: &str,
29) -> PyResult<*const *const c_void> {
30 let module = PyModule::import(py, module)?;
31 let capsule = module.getattr(capsule)?.downcast_into::<PyCapsule>()?;
32
33 let api = capsule.pointer() as *const *const c_void;
34
35 forget(capsule);
38
39 Ok(api)
40}
41
42pub fn is_numpy_2<'py>(py: Python<'py>) -> bool {
44 let api_version = *API_VERSION.get_or_init(py, || unsafe {
45 PY_ARRAY_API.PyArray_GetNDArrayCFeatureVersion(py)
46 });
47 api_version >= API_VERSION_2_0
48}
49
/// Generates `unsafe` methods that call functions out of a NumPy C-API function table.
///
/// Invocation forms, used inside an `impl` block:
///
/// - `impl_api![offset; name(arg: Ty, ...) -> Ret];` is callable on any runtime version.
/// - `impl_api![offset; NumPy1; name(...) -> Ret];` panics unless [`is_numpy_2`] is `false`.
/// - `impl_api![offset; NumPy2; name(...) -> Ret];` panics unless [`is_numpy_2`] is `true`.
///
/// The `-> Ret` part is optional and a trailing comma in the argument list is accepted. The
/// generated method `name(&self, py, args...)` reads the function pointer stored in table slot
/// `offset` via `self.get(py, offset)`, which the implementing type must provide, and calls it
/// with `args`.
///
/// The expansion names `Python`, `is_numpy_2`, `API_VERSION` and `API_VERSION_2_0` unqualified,
/// so they must be in scope wherever the macro is invoked.
///
/// # Safety (of the generated methods)
///
/// The slot's pointer is cast to the declared `extern "C"` signature and called without any
/// check. Callers must ensure slot `offset` really holds a function with exactly that signature.
macro_rules! impl_api {
    // Internal helper shared by every public arm: load the pointer in slot `$offset`, cast it to
    // the declared `extern "C"` signature and call it. `self` and `py` are passed in explicitly
    // because macro hygiene would otherwise stop them from naming the generated method's
    // receiver and parameter.
    (@call $self: ident, $py: ident, $offset: expr; ($($arg: ident: $t: ty),*) $(-> $ret: ty)?) => {{
        let fptr = $self.get($py, $offset) as *const extern "C" fn($($arg: $t),*) $(-> $ret)?;
        (*fptr)($($arg),*)
    }};

    // Available regardless of the runtime NumPy version.
    [$offset: expr; $fname: ident ($($arg: ident: $t: ty),* $(,)?) $(-> $ret: ty)?] => {
        #[allow(non_snake_case)]
        pub unsafe fn $fname<'py>(&self, py: Python<'py>, $($arg: $t),*) $(-> $ret)? {
            impl_api!(@call self, py, $offset; ($($arg: $t),*) $(-> $ret)?)
        }
    };

    // Only valid while the runtime API version is below `API_VERSION_2_0`.
    [$offset: expr; NumPy1; $fname: ident ($($arg: ident: $t: ty),* $(,)?) $(-> $ret: ty)?] => {
        #[allow(non_snake_case)]
        pub unsafe fn $fname<'py>(&self, py: Python<'py>, $($arg: $t),*) $(-> $ret)? {
            assert!(
                !is_numpy_2(py),
                "{} requires API < {:08X} (NumPy 1) but the runtime version is API {:08X}",
                stringify!($fname),
                API_VERSION_2_0,
                // `is_numpy_2` in the condition has already populated the cache.
                *API_VERSION.get(py).expect("API_VERSION is initialized"),
            );
            impl_api!(@call self, py, $offset; ($($arg: $t),*) $(-> $ret)?)
        }
    };

    // Only valid once the runtime API version is at least `API_VERSION_2_0`.
    [$offset: expr; NumPy2; $fname: ident ($($arg: ident: $t: ty),* $(,)?) $(-> $ret: ty)?] => {
        #[allow(non_snake_case)]
        pub unsafe fn $fname<'py>(&self, py: Python<'py>, $($arg: $t),*) $(-> $ret)? {
            assert!(
                is_numpy_2(py),
                "{} requires API {:08X} or greater (NumPy 2) but the runtime version is API {:08X}",
                stringify!($fname),
                API_VERSION_2_0,
                // `is_numpy_2` in the condition has already populated the cache.
                *API_VERSION.get(py).expect("API_VERSION is initialized"),
            );
            impl_api!(@call self, py, $offset; ($($arg: $t),*) $(-> $ret)?)
        }
    };
}
93
94pub mod array;
95pub mod flags;
96pub mod objects;
97pub mod types;
98pub mod ufunc;
99
100pub use self::array::*;
101pub use self::flags::*;
102pub use self::objects::*;
103pub use self::types::*;
104pub use self::ufunc::*;