Source code for qurry.process.classical_shadow.matrix_calcution

"""Post Processing - Classical Shadow - Matrix Calculation
(:mod:`qurry.process.classical_shadow.matrix_calcution`)

"""

from typing import Iterable, Literal, Callable, Union
import warnings
import functools as ft
import numpy as np

from .unitary_set import PRECOMPUTED_RHO_M_K_I, PRECOMPUTED_RHO_M_K_I_2
from ..availability import availablility
from ..exceptions import (
    PostProcessingThirdPartyImportError,
    PostProcessingThirdPartyUnavailableWarning,
)

try:
    import jax
    import jax.numpy as jnp

    jax.config.update("jax_enable_x64", True)
    # =========================================================
    # This is required to handle the complex128 dtype in JAX.
    # Without it JAX works in 32-bit precision and the results
    # would not match the Numpy implementations.
    # =========================================================

    # trace summation calculation
    def all_trace_rho_by_einsum_aij_bji_to_ab_jax(
        rho_m_array: np.ndarray[tuple[int, int, int], np.dtype[np.complex128]],
    ) -> np.complex128:
        r"""The trace of Rho by einsum_aij_bji_to_ab by JAX.

        Computes the average over all ordered pairs of *distinct* matrices

        .. math::

            \frac{1}{N(N-1)} \sum_{a \neq b} \operatorname{tr}(\rho_a \rho_b)

        via the exact identity
        :math:`\sum_{a \neq b} \operatorname{tr}(\rho_a \rho_b)
        = \operatorname{tr}(S^2) - \sum_a \operatorname{tr}(\rho_a^2)`
        with :math:`S = \sum_a \rho_a`. This needs :math:`O(N d^2)` work and
        never materializes the :math:`N \times N` matrix of pairwise traces.

        Args:
            rho_m_array (np.ndarray[tuple[int, int, int], np.dtype[np.complex128]]):
                The Rho M array, a stack of ``N`` square matrices of equal size.

        Returns:
            np.complex128: The trace of Rho. With fewer than two matrices the
            average is undefined and ``nan+nanj`` is returned.
        """
        rho_stack = jnp.asarray(rho_m_array)
        len_rho_m_array = len(rho_stack)
        if len_rho_m_array < 2:
            # No distinct pairs exist; the original mask-based version also
            # produced 0/0 = nan here.
            return np.complex128(complex(np.nan, np.nan))

        rho_sum = jnp.sum(rho_stack, axis=0)
        # tr(S @ S) = sum over all ordered pairs (a, b), including a == b.
        sum_all_pairs = jnp.einsum("ij,ji->", rho_sum, rho_sum)
        # sum_a tr(rho_a @ rho_a): the diagonal (a == b) terms to remove.
        sum_diagonal = jnp.einsum("aij,aji->", rho_stack, rho_stack)
        sum_off_diagonal = sum_all_pairs - sum_diagonal
        return np.complex128(sum_off_diagonal / (len_rho_m_array * (len_rho_m_array - 1)))

    def prediction_einsum_aij_bji_to_ab_jax(
        given_operators: np.ndarray[tuple[int, int, int], np.dtype[np.complex128]],
        estimators: np.ndarray[tuple[int, int, int], np.dtype[np.complex128]],
    ) -> tuple[list[np.complex128], list[np.ndarray[tuple[int, int], np.dtype[np.complex128]]]]:
        """Calculate the prediction of given operators by einsum_aij_bji_to_ab_jax.

        For every given operator ``O_a`` the candidates ``tr(O_a @ rho_b)`` are
        computed for all estimators ``rho_b``. The median of the candidates is
        reported, together with the candidate closest to that median.

        Args:
            given_operators (np.ndarray[tuple[int, int, int], np.dtype[np.complex128]]):
                The given operators.
            estimators (np.ndarray[tuple[int, int, int], np.dtype[np.complex128]]):
                The estimators.

        Returns:
            tuple[list[np.complex128], list[np.ndarray[tuple[int, int], np.dtype[np.complex128]]]]:
                A tuple containing:

                - A list of median values for each given operator.
                - A list of the corresponding median estimators for each given operator,
                  each wrapped as a 0-d ``complex128`` Numpy array.
        """
        # Bring the (A, B) candidate matrix back as a Numpy array once. Everything
        # after this point is Numpy. Keeping a JAX array would cross the JAX/Numpy
        # boundary on every step and dispatch a JAX indexing op per picked element.
        candidates = np.asarray(jnp.einsum("aij,bji->ab", given_operators, estimators))
        median_foreach_given_operator = np.median(candidates, axis=1)
        median_location_given_operator = np.argmin(
            np.abs(candidates - median_foreach_given_operator[:, None]),
            axis=1,
        )

        return list(median_foreach_given_operator), [
            np.array(candidates[i, j], dtype=np.complex128)
            for i, j in enumerate(median_location_given_operator)
        ]  # type: ignore

    def set_cpu_only():
        """Set JAX to use CPU only.

        Only applies when ``jax_platforms`` has not been configured yet; an
        existing platform setting is left untouched.
        """
        if not jax.config.values["jax_platforms"]:
            jax.config.update("jax_platforms", "cpu")

    JAX_AVAILABLE = True
    FAILED_JAX_IMPORT = None
except ImportError as err:
    JAX_AVAILABLE = False
    # Keep the exception: ``err`` itself is unbound when the except block ends.
    FAILED_JAX_IMPORT = err

    # trace summation calculation
    def all_trace_rho_by_einsum_aij_bji_to_ab_jax(
        rho_m_array: np.ndarray[tuple[int, int, int], np.dtype[np.complex128]],
    ) -> np.complex128:
        """The trace of Rho by einsum_aij_bji_to_ab by JAX.

        Placeholder used when JAX failed to import; it always raises.

        Args:
            rho_m_array (np.ndarray[tuple[int, int, int], np.dtype[np.complex128]]):
                The Rho M array.

        Returns:
            np.complex128: Never returns.

        Raises:
            PostProcessingThirdPartyImportError: Always, chained from the import error.
        """
        raise PostProcessingThirdPartyImportError(
            "JAX is not available, using numpy to calculate einsum_aij_bji_to_ab. "
            + "error: "
            + str(FAILED_JAX_IMPORT)
        ) from FAILED_JAX_IMPORT

    def prediction_einsum_aij_bji_to_ab_jax(
        given_operators: np.ndarray[tuple[int, int, int], np.dtype[np.complex128]],
        estimators: np.ndarray[tuple[int, int, int], np.dtype[np.complex128]],
    ) -> tuple[list[np.complex128], list[np.ndarray[tuple[int, int], np.dtype[np.complex128]]]]:
        """Calculate the prediction of given operators by einsum_aij_bji_to_ab_jax.

        Placeholder used when JAX failed to import; it always raises.

        Args:
            given_operators (np.ndarray[tuple[int, int, int], np.dtype[np.complex128]]):
                The given operators.
            estimators (np.ndarray[tuple[int, int, int], np.dtype[np.complex128]]):
                The estimators.

        Returns:
            tuple[list[np.complex128], list[np.ndarray[tuple[int, int], np.dtype[np.complex128]]]]:
                Never returns.

        Raises:
            PostProcessingThirdPartyImportError: Always, chained from the import error.
        """
        raise PostProcessingThirdPartyImportError(
            "JAX is not available, using numpy to calculate prediction_einsum_aij_bji_to_ab. "
            + "error: "
            + str(FAILED_JAX_IMPORT)
        ) from FAILED_JAX_IMPORT

    def set_cpu_only():
        """Set JAX to use CPU only.

        Placeholder used when JAX failed to import: only emits a warning.
        """
        warnings.warn(
            "JAX is not available, nothing to set. " + "error: " + str(FAILED_JAX_IMPORT),
            PostProcessingThirdPartyUnavailableWarning,
        )
BACKEND_AVAILABLE = availablility(
    "classical_shadow.array_process",
    [
        ("jax", JAX_AVAILABLE, FAILED_JAX_IMPORT),
    ],
)

ClassicalShadowPythonMethod = Literal["jax", "numpy"]
"""Which backend to use for the classical shadow calculation: "jax" or "numpy".

- "jax": Use JAX to calculate the Kronecker product.
- "numpy": Use Numpy to calculate the Kronecker product.
"""

DEFAULT_PYTHON_METHOD: ClassicalShadowPythonMethod = "numpy" if not JAX_AVAILABLE else "jax"
"""Backend chosen by default: "jax" when JAX imported successfully, otherwise "numpy".

- "jax": Use JAX to calculate the Kronecker product.
- "numpy": Use Numpy to calculate the Kronecker product.
"""

# kronecker product calculation
def rho_mki_kronecker_product_numpy(
    key_list_of_precomputed: list[tuple[int, str]],
) -> np.ndarray[tuple[int, int], np.dtype[np.complex128]]:
    r"""Kronecker product for :math:`\rho_{mki}` by Numpy.

    Each key is looked up in ``PRECOMPUTED_RHO_M_K_I``. The matrices are folded
    left to right with :func:`numpy.kron`, so the first key is the leftmost
    tensor factor.

    Args:
        key_list_of_precomputed (list[tuple[int, str]]):
            The list of the keys of the precomputed :math:`\rho_{mki}`.

    Returns:
        NDArray[np.complex128]: The Kronecker product of the :math:`\rho_{mki}`.
        Note: for a single key, :func:`functools.reduce` returns the stored
        matrix object itself (no copy). An empty key list raises ``TypeError``.
    """
    factors = (PRECOMPUTED_RHO_M_K_I[key] for key in key_list_of_precomputed)
    return ft.reduce(np.kron, factors)  # type: ignore
def rho_mki_kronecker_product_numpy_2(
    key_list_of_precomputed: Iterable[int],
) -> np.ndarray[tuple[int, int], np.dtype[np.complex128]]:
    r"""Kronecker product for :math:`\rho_{mki}` by Numpy.

    Each integer key is looked up in ``PRECOMPUTED_RHO_M_K_I_2``. The matrices
    are folded left to right with :func:`numpy.kron`, so the first key is the
    leftmost tensor factor.

    Args:
        key_list_of_precomputed (Iterable[int]):
            The list of the keys of the precomputed :math:`\rho_{mki}`.

    Returns:
        NDArray[np.complex128]: The Kronecker product of the :math:`\rho_{mki}`.
        Note: for a single key, :func:`functools.reduce` returns the stored
        matrix object itself (no copy). An empty input raises ``TypeError``.
    """
    factors = (PRECOMPUTED_RHO_M_K_I_2[key] for key in key_list_of_precomputed)
    return ft.reduce(np.kron, factors)  # type: ignore
# single trace calculation
def single_trace_rho_by_trace_of_matmul(
    rho_m1_and_rho_m2: tuple[
        np.ndarray[tuple[int, int], np.dtype[np.complex128]],
        np.ndarray[tuple[int, int], np.dtype[np.complex128]],
    ],
) -> np.complex128:
    """The single trace of Rho by trace of matmul.

    Returns ``tr(rho_m1 @ rho_m2) + tr(rho_m2 @ rho_m1)``, with both products
    formed explicitly with matrix multiplication.

    Args:
        rho_m1_and_rho_m2 (tuple[
            np.ndarray[tuple[int, int], np.dtype[np.complex128]],
            np.ndarray[tuple[int, int], np.dtype[np.complex128]],
        ]): The tuple of rho_m1 and rho_m2.

    Returns:
        np.complex128: The trace of Rho.
    """
    first, second = rho_m1_and_rho_m2
    forward = np.trace(np.matmul(first, second))
    backward = np.trace(np.matmul(second, first))
    return forward + backward
def single_trace_rho_by_einsum_ij_ji(
    rho_m1_and_rho_m2: tuple[
        np.ndarray[tuple[int, int], np.dtype[np.complex128]],
        np.ndarray[tuple[int, int], np.dtype[np.complex128]],
    ],
) -> np.complex128:
    """The single trace of Rho by einsum_ij_ji by Numpy.

    Returns ``tr(rho_m1 @ rho_m2) + tr(rho_m2 @ rho_m1)``. Each trace uses
    ``np.einsum("ij,ji", ...)``, so the full product matrices are never
    formed.

    Args:
        rho_m1_and_rho_m2 (tuple[
            np.ndarray[tuple[int, int], np.dtype[np.complex128]],
            np.ndarray[tuple[int, int], np.dtype[np.complex128]],
        ]): The tuple of rho_m1 and rho_m2.

    Returns:
        np.complex128: The trace of Rho.
    """
    first, second = rho_m1_and_rho_m2
    forward = np.einsum("ij,ji", first, second)
    backward = np.einsum("ij,ji", second, first)
    return forward + backward
SingleTraceRhoMethod = Union[
    Literal["trace_of_matmul", "quick_trace_of_matmul", "einsum_ij_ji"],
    str,
]
"""The method to calculate the trace of a single Rho square.

- "trace_of_matmul": Use ``np.trace(np.matmul(rho_m1, rho_m2))`` to calculate the trace.
- "quick_trace_of_matmul" or "einsum_ij_ji": Use ``np.einsum("ij,ji", rho_m1, rho_m2)``
  to calculate the trace.
"""
def select_single_trace_rho_method(
    method: SingleTraceRhoMethod = "quick_trace_of_matmul",
) -> Callable[
    [
        tuple[
            np.ndarray[tuple[int, int], np.dtype[np.complex128]],
            np.ndarray[tuple[int, int], np.dtype[np.complex128]],
        ],
    ],
    np.complex128,
]:
    """Select the method to calculate the trace of Rho square.

    Args:
        method (str): The method to use for the calculation.
            "trace_of_matmul" selects ``single_trace_rho_by_trace_of_matmul``.
            "quick_trace_of_matmul" and "einsum_ij_ji" both select
            ``single_trace_rho_by_einsum_ij_ji``.

    Returns:
        Callable[[tuple[
            np.ndarray[tuple[int, int], np.dtype[np.complex128]],
            np.ndarray[tuple[int, int], np.dtype[np.complex128]],
        ]], np.complex128]: The function to calculate the trace of Rho.

    Raises:
        ValueError: If ``method`` is not one of the names above.
    """
    einsum_aliases = ("quick_trace_of_matmul", "einsum_ij_ji")
    if method == "trace_of_matmul":
        selected = single_trace_rho_by_trace_of_matmul
    elif method in einsum_aliases:
        selected = single_trace_rho_by_einsum_ij_ji
    else:
        raise ValueError(f"Invalid method: {method}")
    return selected
# trace summation calculation
def all_trace_rho_by_einsum_aij_bji_to_ab_numpy(
    rho_m_array: np.ndarray[tuple[int, int, int], np.dtype[np.complex128]],
) -> np.complex128:
    r"""The trace of Rho by einsum_aij_bji_to_ab.

    Computes the average over all ordered pairs of *distinct* matrices

    .. math::

        \frac{1}{N(N-1)} \sum_{a \neq b} \operatorname{tr}(\rho_a \rho_b)

    via the exact identity
    :math:`\sum_{a \neq b} \operatorname{tr}(\rho_a \rho_b)
    = \operatorname{tr}(S^2) - \sum_a \operatorname{tr}(\rho_a^2)`
    with :math:`S = \sum_a \rho_a`. This costs :math:`O(N d^2)` time and
    memory. It does not build the :math:`N \times N` pairwise-trace matrix,
    plus a same-sized mask and copy, which grows quadratically with the number
    of matrices.

    Args:
        rho_m_array (np.ndarray[tuple[int, int, int], np.dtype[np.complex128]]):
            The Rho M array, a stack (or sequence) of ``N`` square matrices of
            equal size.

    Returns:
        np.complex128: The trace of Rho. With fewer than two matrices the
        average is undefined and ``nan+nanj`` is returned.
    """
    rho_stack = np.asarray(rho_m_array)
    len_rho_m_array = len(rho_stack)
    if len_rho_m_array < 2:
        # No distinct pairs exist. The previous implementation also ended in
        # 0/0 = nan here, but with a RuntimeWarning.
        return np.complex128(complex(np.nan, np.nan))

    rho_sum = rho_stack.sum(axis=0)
    # tr(S @ S) = sum over all ordered pairs (a, b), including a == b.
    sum_all_pairs = np.einsum("ij,ji->", rho_sum, rho_sum)
    # sum_a tr(rho_a @ rho_a): the diagonal (a == b) terms to remove.
    sum_diagonal = np.einsum("aij,aji->", rho_stack, rho_stack)
    sum_off_diagonal = sum_all_pairs - sum_diagonal
    return sum_off_diagonal / (len_rho_m_array * (len_rho_m_array - 1))
AllTraceRhoMethod = Union[
    Literal["einsum_aij_bji_to_ab_numpy", "einsum_aij_bji_to_ab_jax"],
    str,
]
"""The method to calculate the all trace of Rho square.

- "einsum_aij_bji_to_ab_numpy": The Numpy implementation; the one to use when
  JAX is not available.
- "einsum_aij_bji_to_ab_jax": The JAX implementation; the preferred choice when
  JAX is available.
"""

DEFAULT_ALL_TRACE_RHO_METHOD: AllTraceRhoMethod = (
    "einsum_aij_bji_to_ab_numpy" if not JAX_AVAILABLE else "einsum_aij_bji_to_ab_jax"
)
def select_all_trace_rho_by_einsum_aij_bji_to_ab(
    method: AllTraceRhoMethod = DEFAULT_ALL_TRACE_RHO_METHOD,
) -> Callable[
    [np.ndarray[tuple[int, int, int], np.dtype[np.complex128]]],
    np.complex128,
]:
    """Select the method to calculate the trace of Rho square.

    Args:
        method (AllTraceRhoMethod, optional):
            The method to use for the calculation.
            Defaults to DEFAULT_ALL_TRACE_RHO_METHOD.

            - "einsum_aij_bji_to_ab_numpy": The Numpy implementation.
            - "einsum_aij_bji_to_ab_jax": The JAX implementation. If JAX is not
              available, a ``PostProcessingThirdPartyUnavailableWarning`` is
              emitted and the Numpy implementation is returned instead.

    Returns:
        Callable[[np.ndarray[tuple[int, int, int], np.dtype[np.complex128]]], np.complex128]:
            The function to calculate the trace of Rho.

    Raises:
        ValueError: If ``method`` is not one of the names above.
    """
    if method == "einsum_aij_bji_to_ab_jax":
        if JAX_AVAILABLE:
            return all_trace_rho_by_einsum_aij_bji_to_ab_jax
        warnings.warn(
            "JAX is not available, using numpy to calculate all trace.",
            PostProcessingThirdPartyUnavailableWarning,
        )
        # Fall back as the warning promises. Previously control fell through to
        # the ValueError below, because the method name is not the numpy one.
        return all_trace_rho_by_einsum_aij_bji_to_ab_numpy
    if method != "einsum_aij_bji_to_ab_numpy":
        raise ValueError(f"Invalid backend: {method}")
    return all_trace_rho_by_einsum_aij_bji_to_ab_numpy
def prediction_einsum_aij_bji_to_ab_numpy(
    given_operators: np.ndarray[tuple[int, int, int], np.dtype[np.complex128]],
    estimators: np.ndarray[tuple[int, int, int], np.dtype[np.complex128]],
) -> tuple[list[np.complex128], list[np.ndarray[tuple[int, int], np.dtype[np.complex128]]]]:
    """Calculate the prediction of given operators by einsum_aij_bji_to_ab_numpy.

    For every given operator ``O_a`` the candidates ``tr(O_a @ rho_b)`` are
    computed for all estimators ``rho_b``. The median of the candidates is
    reported, together with the candidate closest to that median (the first
    one, on ties).

    Args:
        given_operators (np.ndarray[tuple[int, int, int], np.dtype[np.complex128]]):
            The given operators.
        estimators (np.ndarray[tuple[int, int, int], np.dtype[np.complex128]]):
            The estimators.

    Returns:
        tuple[list[np.complex128], list[np.ndarray[tuple[int, int], np.dtype[np.complex128]]]]:
            A tuple containing:

            - A list of median values for each given operator.
            - A list of the corresponding median estimators for each given operator.
    """
    # candidates[a, b] == tr(given_operators[a] @ estimators[b])
    candidates = np.einsum("aij,bji->ab", given_operators, estimators)
    medians = np.median(candidates, axis=1)
    distance_to_median = np.abs(candidates - medians[:, np.newaxis])
    closest_columns = distance_to_median.argmin(axis=1)

    closest_candidates = [candidates[row, col] for row, col in enumerate(closest_columns)]
    return list(medians), closest_candidates
def select_prediction_einsum_aij_bji_to_ab(
    method: AllTraceRhoMethod = DEFAULT_ALL_TRACE_RHO_METHOD,
) -> Callable[
    [
        np.ndarray[tuple[int, int, int], np.dtype[np.complex128]],
        np.ndarray[tuple[int, int, int], np.dtype[np.complex128]],
    ],
    tuple[list[np.complex128], list[np.ndarray[tuple[int, int], np.dtype[np.complex128]]]],
]:
    """Select the method to calculate the prediction of given operators.

    Args:
        method (AllTraceRhoMethod, optional):
            The method to use for the calculation.
            Defaults to DEFAULT_ALL_TRACE_RHO_METHOD.
            It can be either "einsum_aij_bji_to_ab_jax" or "einsum_aij_bji_to_ab_numpy".
            If "einsum_aij_bji_to_ab_jax" is requested while JAX is not available,
            a ``PostProcessingThirdPartyUnavailableWarning`` is emitted and the
            Numpy implementation is returned instead.

    Returns:
        Callable[[
            np.ndarray[tuple[int, int, int], np.dtype[np.complex128]],
            np.ndarray[tuple[int, int, int], np.dtype[np.complex128]]
        ], tuple[list[np.complex128], list[np.ndarray[tuple[int, int], np.dtype[np.complex128]]]]]:
            The function to calculate the prediction of given operators.

    Raises:
        ValueError: If ``method`` is not one of the names above.
    """
    if method == "einsum_aij_bji_to_ab_jax":
        if JAX_AVAILABLE:
            return prediction_einsum_aij_bji_to_ab_jax
        warnings.warn(
            "JAX is not available, using numpy to calculate prediction.",
            PostProcessingThirdPartyUnavailableWarning,
        )
        # Fall back as the warning promises. Previously control fell through to
        # the ValueError below, because the method name is not the numpy one.
        return prediction_einsum_aij_bji_to_ab_numpy
    if method != "einsum_aij_bji_to_ab_numpy":
        raise ValueError(f"Invalid backend: {method}")
    return prediction_einsum_aij_bji_to_ab_numpy