# Source code for nitypes.complex._conversion

from __future__ import annotations

from typing import Any, TypeVar, cast, overload

import numpy as np
import numpy.typing as npt

from nitypes._arguments import validate_dtype
from nitypes._exceptions import unsupported_dtype
from nitypes.complex._dtypes import ComplexInt32DType

# Type variables used by the annotations in this module.
# NOTE(review): _Item_co is not referenced anywhere in the visible source -- confirm it is needed.
_Item_co = TypeVar("_Item_co", bound=Any)
# Scalar type of the requested dtype; ties the ``requested_dtype`` argument to the dtype of the
# returned array in the typed overloads of convert_complex.
_ScalarType = TypeVar("_ScalarType", bound=np.generic)
# Shape type of an input array; the annotations carry it through unchanged to the return type.
_Shape = TypeVar("_Shape", bound=tuple[int, ...])

# The complex data types this module converts between. convert_complex passes this tuple to
# validate_dtype to check the requested dtype, and it is also handed to unsupported_dtype when
# building the error for a dtype that cannot be converted.
_COMPLEX_DTYPES = (
    np.complex64,
    np.complex128,
    ComplexInt32DType,
)

# Maps each supported complex dtype to the scalar type that _convert_complexint32_array uses when
# it reinterprets (``ndarray.view``) array data before and after the element-wise ``astype`` cast.
# Presumably this is the type of each real/imaginary component; ComplexInt32DType's layout is
# defined in nitypes.complex._dtypes -- confirm there.
_FIELD_DTYPE = {
    np.dtype(np.complex64): np.float32,
    np.dtype(np.complex128): np.float64,
    ComplexInt32DType: np.int16,
}


# Typing-only overloads for convert_complex. Each one has a ``...`` body; the implementation
# follows the last overload.

# Array input, requested dtype given as a scalar type or a typed np.dtype: the result is annotated
# with the input's shape and the requested scalar type.
@overload
def convert_complex(
    requested_dtype: type[_ScalarType] | np.dtype[_ScalarType],
    value: np.ndarray[_Shape, Any],
) -> np.ndarray[_Shape, np.dtype[_ScalarType]]: ...


# Array input, requested dtype given as any other DTypeLike: the shape is preserved in the
# annotation but the result dtype is left as Any.
@overload
def convert_complex(
    requested_dtype: npt.DTypeLike, value: np.ndarray[_Shape, Any]
) -> np.ndarray[_Shape, Any]: ...


# The two overloads below accept NumPy scalars (np.generic) and annotate the result as a
# 0-dimensional array, following the NumPy typing convention quoted here:
# https://numpy.org/doc/2.2/reference/typing.html#d-arrays
# "While thus not strictly correct, all operations are that can potentially perform a 0D-array ->
# scalar cast are currently annotated as exclusively returning an ndarray."
@overload
def convert_complex(
    requested_dtype: type[_ScalarType] | np.dtype[_ScalarType],
    value: np.generic[Any],
) -> np.ndarray[tuple[()], np.dtype[_ScalarType]]: ...


# Scalar input, requested dtype given as any other DTypeLike: the result dtype is left as Any.
@overload
def convert_complex(
    requested_dtype: npt.DTypeLike,
    value: np.generic[Any],
) -> np.ndarray[tuple[()], Any]: ...


def convert_complex(
    requested_dtype: npt.DTypeLike, value: np.ndarray[_Shape, Any] | np.generic[Any]
) -> np.ndarray[_Shape, Any]:
    """Convert a NumPy array or scalar of complex numbers to the specified dtype.

    When ``value`` already has the requested dtype it is returned as-is, without copying.

    Args:
        requested_dtype: The NumPy data type to convert to. Supported data types:
            :any:`numpy.complex64`, :any:`numpy.complex128`, :any:`ComplexInt32DType`.
        value: The NumPy array or scalar to convert.

    Returns:
        The value converted to the specified dtype.
    """
    validate_dtype(requested_dtype, _COMPLEX_DTYPES)

    # Nothing to convert: hand the input back untouched.
    if requested_dtype == value.dtype:
        return cast(np.ndarray[_Shape, Any], value)

    # Conversions that do not involve ComplexInt32DType on either side are a plain astype.
    involves_complexint32 = (
        requested_dtype == ComplexInt32DType or value.dtype == ComplexInt32DType
    )
    if not involves_complexint32:
        return value.astype(requested_dtype)

    if value.shape != ():
        return _convert_complexint32_array(requested_dtype, cast(np.ndarray[_Shape, Any], value))

    # shape == () means this is either a scalar (np.generic) or a 0-dimension array, but mypy
    # doesn't know that. ndarray.view on scalars requires the source and destination types to
    # have the same size, so wrap the value in a 1-element array, convert that, and then index
    # the single converted element back out.
    one_element = value.reshape(1)
    converted = _convert_complexint32_array(requested_dtype, one_element)
    return cast(np.ndarray[_Shape, Any], converted[0])


def _convert_complexint32_array(
    requested_dtype: npt.DTypeLike | type[_ScalarType] | np.dtype[_ScalarType],
    value: np.ndarray[_Shape, Any],
) -> np.ndarray[_Shape, np.dtype[_ScalarType]]:
    """Convert an array between complex dtypes by reinterpreting its component data.

    The array is viewed as the field type registered for its dtype in ``_FIELD_DTYPE``, cast
    with ``astype`` to the field type registered for ``requested_dtype``, and finally viewed as
    ``requested_dtype``.

    Args:
        requested_dtype: The data type to convert to. Values that are not already a
            :any:`numpy.dtype` instance are normalized with ``np.dtype(...)`` first.
        value: The array to convert. The intermediate views change the item size, which NumPy
            does not support for 0-dimension arrays, so callers pass an array with at least one
            dimension. An array whose last axis is not contiguous (for example a strided slice
            or a Fortran-ordered array) is copied into C order before it is viewed.

    Returns:
        A newly allocated array with dtype ``requested_dtype``.

    Raises:
        The exception built by ``unsupported_dtype`` when ``requested_dtype`` or
        ``value.dtype`` has no entry in ``_FIELD_DTYPE``.
    """
    if not isinstance(requested_dtype, np.dtype):
        requested_dtype = np.dtype(requested_dtype)

    requested_field_dtype = _FIELD_DTYPE.get(requested_dtype)
    if requested_field_dtype is None:
        raise unsupported_dtype("requested data type", requested_dtype, _COMPLEX_DTYPES)

    value_field_dtype = _FIELD_DTYPE.get(value.dtype)
    if value_field_dtype is None:
        raise unsupported_dtype("array data type", value.dtype, _COMPLEX_DTYPES)

    try:
        fields = value.view(value_field_dtype)
    except ValueError:
        # ndarray.view raises ValueError when the item size changes and the last axis is not
        # contiguous. Retry on a C-ordered copy so such inputs convert instead of failing.
        # Inputs that cannot be viewed for another reason (e.g. 0-dimension arrays) raise the
        # same ValueError again from the retry.
        fields = value.copy(order="C").view(value_field_dtype)

    return fields.astype(requested_field_dtype).view(requested_dtype)