DataFormatter

Operational class powering the format_data function.

class DataFormatter(ABCParse.ABCParse):
    """Format data to interface with numpy or torch, on a specified device.

    Wraps a single array-like object, read by every method as ``self.data``,
    and exposes helpers to inspect its container type and device as well as
    to convert it to ``np.ndarray`` (``to_numpy``) or ``torch.Tensor``
    (``to_torch``).
    """

    def __init__(self, data: Union[_torch.Tensor, np.ndarray], *args, **kwargs):
        """
        Parameters
        ----------
        data: Union[torch.Tensor, np.ndarray]
            The object to be inspected / converted.
        """
        # NOTE(review): ``__parse__`` comes from ABCParse; the rest of this
        # class relies on it exposing ``data`` as ``self.data`` - confirm.
        self.__parse__(locals())

    @property
    def device_type(self) -> str:
        """Returns the device type of the data, defaulting to "cpu"."""
        # Not every ``.device`` attribute has a ``.type``: NumPy >= 2.0 arrays
        # expose ``ndarray.device`` as the plain string "cpu". Fall back to
        # "cpu" whenever either attribute is missing.
        device = getattr(self.data, "device", None)
        return getattr(device, "type", "cpu")

    @property
    def is_ArrayView(self) -> bool:
        """Checks if data is of type anndata ArrayView"""
        return isinstance(self.data, anndata._core.views.ArrayView)

    @property
    def is_numpy_array(self) -> bool:
        """Checks if data is of type np.ndarray"""
        return isinstance(self.data, np.ndarray)

    @property
    def is_torch_Tensor(self) -> bool:
        """Checks if data is of type torch.Tensor"""
        return isinstance(self.data, _torch.Tensor)

    @property
    def on_cpu(self) -> bool:
        """Checks if data is on cpu"""
        return self.device_type == "cpu"

    @property
    def on_gpu(self) -> bool:
        """Checks if data is on cuda or mps"""
        return self.device_type in ("cuda", "mps")

    def to_numpy(self) -> np.ndarray:
        """Sends data to np.ndarray

        Returns
        -------
        np.ndarray
            Tensors are detached and moved to cpu before conversion;
            ArrayViews are converted with ``toarray()``; anything else is
            returned unchanged.
        """
        if self.is_torch_Tensor:
            # ``Tensor.numpy()`` raises for tensors that require grad or that
            # live off-cpu. detach() and cpu() return the tensor itself when
            # there is nothing to do, so the chain is safe for every tensor.
            return self.data.detach().cpu().numpy()
        if self.is_ArrayView:
            return self.data.toarray()
        return self.data

    # NOTE: the default ``device`` is evaluated once, when the class body is
    # executed, not on each call.
    def to_torch(self, device=autodevice.AutoDevice()) -> _torch.Tensor:
        """Sends data to torch.Tensor on ``device``

        Parameters
        ----------
        device: torch.device

        Returns
        -------
        torch.Tensor
        """
        # Must stay the first statement so that ``locals()`` only holds
        # ``self`` and ``device``.
        self.__update__(locals())

        if self.is_torch_Tensor:
            return self.data.to(device)

        # Convert an ArrayView through a local variable so that ``self.data``
        # is not silently replaced as a side effect of the conversion.
        data = self.data.toarray() if self.is_ArrayView else self.data
        return _torch.Tensor(data).to(device)

GitHub: GitHub.com/mvinyard/AnnDataQuery/adata_query/_core/_formatter.py

Last updated