numpy/
sum_products.rs

1use std::borrow::Cow;
2use std::ffi::{CStr, CString};
3use std::ptr::null_mut;
4
5use ndarray::{Dimension, IxDyn};
6use pyo3::types::PyAnyMethods;
7use pyo3::{Borrowed, Bound, FromPyObject, PyResult};
8
9use crate::array::PyArray;
10use crate::dtype::Element;
11use crate::npyffi::{array::PY_ARRAY_API, NPY_CASTING, NPY_ORDER};
12
13/// Return value of a function that can yield either an array or a scalar.
14pub trait ArrayOrScalar<'py, T>: FromPyObject<'py> {}
15
16impl<'py, T, D> ArrayOrScalar<'py, T> for Bound<'py, PyArray<T, D>>
17where
18    T: Element,
19    D: Dimension,
20{
21}
22
23impl<'py, T> ArrayOrScalar<'py, T> for T where T: Element + FromPyObject<'py> {}
24
25/// Return the inner product of two arrays.
26///
27/// [NumPy's documentation][inner] has the details.
28///
29/// # Examples
30///
31/// Note that this function can either return a scalar...
32///
33/// ```
34/// use pyo3::Python;
35/// use numpy::{inner, pyarray, PyArray0};
36///
37/// Python::with_gil(|py| {
38///     let vector = pyarray![py, 1.0, 2.0, 3.0];
39///     let result: f64 = inner(&vector, &vector).unwrap();
40///     assert_eq!(result, 14.0);
41/// });
42/// ```
43///
44/// ...or an array depending on its arguments.
45///
46/// ```
47/// use pyo3::{Python, Bound};
48/// use numpy::prelude::*;
49/// use numpy::{inner, pyarray, PyArray0};
50///
51/// Python::with_gil(|py| {
52///     let vector = pyarray![py, 1, 2, 3];
53///     let result: Bound<'_, PyArray0<_>> = inner(&vector, &vector).unwrap();
54///     assert_eq!(result.item(), 14);
55/// });
56/// ```
57///
58/// [inner]: https://numpy.org/doc/stable/reference/generated/numpy.inner.html
59pub fn inner<'py, T, DIN1, DIN2, OUT>(
60    array1: &Bound<'py, PyArray<T, DIN1>>,
61    array2: &Bound<'py, PyArray<T, DIN2>>,
62) -> PyResult<OUT>
63where
64    T: Element,
65    DIN1: Dimension,
66    DIN2: Dimension,
67    OUT: ArrayOrScalar<'py, T>,
68{
69    let py = array1.py();
70    let obj = unsafe {
71        let result = PY_ARRAY_API.PyArray_InnerProduct(py, array1.as_ptr(), array2.as_ptr());
72        Bound::from_owned_ptr_or_err(py, result)?
73    };
74    obj.extract()
75}
76
77/// Return the dot product of two arrays.
78///
79/// [NumPy's documentation][dot] has the details.
80///
81/// # Examples
82///
83/// Note that this function can either return an array...
84///
85/// ```
86/// use pyo3::{Python, Bound};
87/// use ndarray::array;
88/// use numpy::{dot, pyarray, PyArray2, PyArrayMethods};
89///
90/// Python::with_gil(|py| {
91///     let matrix = pyarray![py, [1, 0], [0, 1]];
92///     let another_matrix = pyarray![py, [4, 1], [2, 2]];
93///
94///     let result: Bound<'_, PyArray2<_>> = dot(&matrix, &another_matrix).unwrap();
95///
96///     assert_eq!(
97///         result.readonly().as_array(),
98///         array![[4, 1], [2, 2]]
99///     );
100/// });
101/// ```
102///
103/// ...or a scalar depending on its arguments.
104///
105/// ```
106/// use pyo3::Python;
107/// use numpy::{dot, pyarray, PyArray0};
108///
109/// Python::with_gil(|py| {
110///     let vector = pyarray![py, 1.0, 2.0, 3.0];
111///     let result: f64 = dot(&vector, &vector).unwrap();
112///     assert_eq!(result, 14.0);
113/// });
114/// ```
115///
116/// [dot]: https://numpy.org/doc/stable/reference/generated/numpy.dot.html
117pub fn dot<'py, T, DIN1, DIN2, OUT>(
118    array1: &Bound<'py, PyArray<T, DIN1>>,
119    array2: &Bound<'py, PyArray<T, DIN2>>,
120) -> PyResult<OUT>
121where
122    T: Element,
123    DIN1: Dimension,
124    DIN2: Dimension,
125    OUT: ArrayOrScalar<'py, T>,
126{
127    let py = array1.py();
128    let obj = unsafe {
129        let result = PY_ARRAY_API.PyArray_MatrixProduct(py, array1.as_ptr(), array2.as_ptr());
130        Bound::from_owned_ptr_or_err(py, result)?
131    };
132    obj.extract()
133}
134
135/// Return the Einstein summation convention of given tensors.
136///
137/// This is usually invoked via the the [`einsum!`][crate::einsum!] macro.
138pub fn einsum<'py, T, OUT>(
139    subscripts: &str,
140    arrays: &[Borrowed<'_, 'py, PyArray<T, IxDyn>>],
141) -> PyResult<OUT>
142where
143    T: Element,
144    OUT: ArrayOrScalar<'py, T>,
145{
146    let subscripts = match CStr::from_bytes_with_nul(subscripts.as_bytes()) {
147        Ok(subscripts) => Cow::Borrowed(subscripts),
148        Err(_) => Cow::Owned(CString::new(subscripts).unwrap()),
149    };
150
151    let py = arrays[0].py();
152    let obj = unsafe {
153        let result = PY_ARRAY_API.PyArray_EinsteinSum(
154            py,
155            subscripts.as_ptr() as _,
156            arrays.len() as _,
157            arrays.as_ptr() as _,
158            null_mut(),
159            NPY_ORDER::NPY_KEEPORDER,
160            NPY_CASTING::NPY_NO_CASTING,
161            null_mut(),
162        );
163        Bound::from_owned_ptr_or_err(py, result)?
164    };
165    obj.extract()
166}
167
/// Evaluate the Einstein summation convention on the given tensors.
///
/// Every tensor argument must be an identifier naming a NumPy array of any dimensionality.
/// The macro converts each one to dynamic dimensionality and appends the terminating nul byte
/// to the subscripts literal at compile time. It then forwards everything to the
/// [`einsum`](fn@crate::einsum) function.
///
/// For more about the Einstein summation convention, please refer to
/// [NumPy's documentation][einsum].
///
/// # Example
///
/// ```
/// use pyo3::{Python, Bound};
/// use ndarray::array;
/// use numpy::{einsum, pyarray, PyArray, PyArray2, PyArrayMethods};
///
/// Python::with_gil(|py| {
///     let tensor = PyArray::arange(py, 0, 2 * 3 * 4, 1).reshape([2, 3, 4]).unwrap();
///     let another_tensor = pyarray![py, [20, 30], [40, 50], [60, 70]];
///
///     let result: Bound<'_, PyArray2<_>> = einsum!("ijk,ji->ik", tensor, another_tensor).unwrap();
///
///     assert_eq!(
///         result.readonly().as_array(),
///         array![[640,  760,  880, 1000], [2560, 2710, 2860, 3010]]
///     );
/// });
/// ```
///
/// [einsum]: https://numpy.org/doc/stable/reference/generated/numpy.einsum.html
#[macro_export]
macro_rules! einsum {
    ($subscripts:literal $(, $array:ident)+ $(,)*) => {
        $crate::einsum(
            concat!($subscripts, "\0"),
            &[$($array.to_dyn().as_borrowed()),+],
        )
    };
}