numpy/sum_products.rs
1use std::borrow::Cow;
2use std::ffi::{CStr, CString};
3use std::ptr::null_mut;
4
5use ndarray::{Dimension, IxDyn};
6use pyo3::types::PyAnyMethods;
7use pyo3::{Borrowed, Bound, FromPyObject, PyResult};
8
9use crate::array::PyArray;
10use crate::dtype::Element;
11use crate::npyffi::{array::PY_ARRAY_API, NPY_CASTING, NPY_ORDER};
12
13/// Return value of a function that can yield either an array or a scalar.
14pub trait ArrayOrScalar<'py, T>: FromPyObject<'py> {}
15
16impl<'py, T, D> ArrayOrScalar<'py, T> for Bound<'py, PyArray<T, D>>
17where
18 T: Element,
19 D: Dimension,
20{
21}
22
23impl<'py, T> ArrayOrScalar<'py, T> for T where T: Element + FromPyObject<'py> {}
24
25/// Return the inner product of two arrays.
26///
27/// [NumPy's documentation][inner] has the details.
28///
29/// # Examples
30///
31/// Note that this function can either return a scalar...
32///
33/// ```
34/// use pyo3::Python;
35/// use numpy::{inner, pyarray, PyArray0};
36///
37/// Python::with_gil(|py| {
38/// let vector = pyarray![py, 1.0, 2.0, 3.0];
39/// let result: f64 = inner(&vector, &vector).unwrap();
40/// assert_eq!(result, 14.0);
41/// });
42/// ```
43///
44/// ...or an array depending on its arguments.
45///
46/// ```
47/// use pyo3::{Python, Bound};
48/// use numpy::prelude::*;
49/// use numpy::{inner, pyarray, PyArray0};
50///
51/// Python::with_gil(|py| {
52/// let vector = pyarray![py, 1, 2, 3];
53/// let result: Bound<'_, PyArray0<_>> = inner(&vector, &vector).unwrap();
54/// assert_eq!(result.item(), 14);
55/// });
56/// ```
57///
58/// [inner]: https://numpy.org/doc/stable/reference/generated/numpy.inner.html
59pub fn inner<'py, T, DIN1, DIN2, OUT>(
60 array1: &Bound<'py, PyArray<T, DIN1>>,
61 array2: &Bound<'py, PyArray<T, DIN2>>,
62) -> PyResult<OUT>
63where
64 T: Element,
65 DIN1: Dimension,
66 DIN2: Dimension,
67 OUT: ArrayOrScalar<'py, T>,
68{
69 let py = array1.py();
70 let obj = unsafe {
71 let result = PY_ARRAY_API.PyArray_InnerProduct(py, array1.as_ptr(), array2.as_ptr());
72 Bound::from_owned_ptr_or_err(py, result)?
73 };
74 obj.extract()
75}
76
77/// Return the dot product of two arrays.
78///
79/// [NumPy's documentation][dot] has the details.
80///
81/// # Examples
82///
83/// Note that this function can either return an array...
84///
85/// ```
86/// use pyo3::{Python, Bound};
87/// use ndarray::array;
88/// use numpy::{dot, pyarray, PyArray2, PyArrayMethods};
89///
90/// Python::with_gil(|py| {
91/// let matrix = pyarray![py, [1, 0], [0, 1]];
92/// let another_matrix = pyarray![py, [4, 1], [2, 2]];
93///
94/// let result: Bound<'_, PyArray2<_>> = dot(&matrix, &another_matrix).unwrap();
95///
96/// assert_eq!(
97/// result.readonly().as_array(),
98/// array![[4, 1], [2, 2]]
99/// );
100/// });
101/// ```
102///
103/// ...or a scalar depending on its arguments.
104///
105/// ```
106/// use pyo3::Python;
107/// use numpy::{dot, pyarray, PyArray0};
108///
109/// Python::with_gil(|py| {
110/// let vector = pyarray![py, 1.0, 2.0, 3.0];
111/// let result: f64 = dot(&vector, &vector).unwrap();
112/// assert_eq!(result, 14.0);
113/// });
114/// ```
115///
116/// [dot]: https://numpy.org/doc/stable/reference/generated/numpy.dot.html
117pub fn dot<'py, T, DIN1, DIN2, OUT>(
118 array1: &Bound<'py, PyArray<T, DIN1>>,
119 array2: &Bound<'py, PyArray<T, DIN2>>,
120) -> PyResult<OUT>
121where
122 T: Element,
123 DIN1: Dimension,
124 DIN2: Dimension,
125 OUT: ArrayOrScalar<'py, T>,
126{
127 let py = array1.py();
128 let obj = unsafe {
129 let result = PY_ARRAY_API.PyArray_MatrixProduct(py, array1.as_ptr(), array2.as_ptr());
130 Bound::from_owned_ptr_or_err(py, result)?
131 };
132 obj.extract()
133}
134
135/// Return the Einstein summation convention of given tensors.
136///
137/// This is usually invoked via the the [`einsum!`][crate::einsum!] macro.
138pub fn einsum<'py, T, OUT>(
139 subscripts: &str,
140 arrays: &[Borrowed<'_, 'py, PyArray<T, IxDyn>>],
141) -> PyResult<OUT>
142where
143 T: Element,
144 OUT: ArrayOrScalar<'py, T>,
145{
146 let subscripts = match CStr::from_bytes_with_nul(subscripts.as_bytes()) {
147 Ok(subscripts) => Cow::Borrowed(subscripts),
148 Err(_) => Cow::Owned(CString::new(subscripts).unwrap()),
149 };
150
151 let py = arrays[0].py();
152 let obj = unsafe {
153 let result = PY_ARRAY_API.PyArray_EinsteinSum(
154 py,
155 subscripts.as_ptr() as _,
156 arrays.len() as _,
157 arrays.as_ptr() as _,
158 null_mut(),
159 NPY_ORDER::NPY_KEEPORDER,
160 NPY_CASTING::NPY_NO_CASTING,
161 null_mut(),
162 );
163 Bound::from_owned_ptr_or_err(py, result)?
164 };
165 obj.extract()
166}
167
/// Evaluates the Einstein summation convention on the given tensors.
///
/// For more about the Einstein summation convention, please refer to
/// [NumPy's documentation][einsum].
///
/// The subscripts must be a literal; a NUL terminator is appended at compile
/// time. Each operand must be an identifier naming a `PyArray` of any
/// dimensionality, which is converted with `to_dyn` before the call. The
/// expansion brings `PyArrayMethods` into scope itself, so the caller does
/// not need to import it just to use this macro.
///
/// # Example
///
/// ```
/// use pyo3::{Python, Bound};
/// use ndarray::array;
/// use numpy::{einsum, pyarray, PyArray, PyArray2, PyArrayMethods};
///
/// Python::with_gil(|py| {
///     let tensor = PyArray::arange(py, 0, 2 * 3 * 4, 1).reshape([2, 3, 4]).unwrap();
///     let another_tensor = pyarray![py, [20, 30], [40, 50], [60, 70]];
///
///     let result: Bound<'_, PyArray2<_>> = einsum!("ijk,ji->ik", tensor, another_tensor).unwrap();
///
///     assert_eq!(
///         result.readonly().as_array(),
///         array![[640, 760, 880, 1000], [2560, 2710, 2860, 3010]]
///     );
/// });
/// ```
///
/// [einsum]: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html
#[macro_export]
macro_rules! einsum {
    ($subscripts:literal $(,$array:ident)+ $(,)*) => {{
        // Anonymous import so the `to_dyn` method call resolves even when the
        // caller has not imported the extension trait.
        use $crate::PyArrayMethods as _;
        let arrays = [$($array.to_dyn().as_borrowed(),)+];
        $crate::einsum(concat!($subscripts, "\0"), &arrays)
    }};
}