diff --git a/thinc/backends/cupy_ops.py b/thinc/backends/cupy_ops.py index 472b6c542..8a3928c17 100644 --- a/thinc/backends/cupy_ops.py +++ b/thinc/backends/cupy_ops.py @@ -1,7 +1,8 @@ import numpy +from typing import Any from .. import registry -from ..compat import cublas, cupy, cupyx +from ..compat import cublas, cupy, cupyx, has_cupy_gpu, has_torch_cuda_gpu, torch from ..types import DeviceTypes from ..util import ( is_cupy_array, @@ -346,6 +347,24 @@ def position_encode(self, N, D, period=10000, out=None): positions = NumpyOps().position_encode(N, D, period=period, out=out) return self.asarray(positions) + def has_gpu_support(self): + return has_cupy_gpu + + def set_active_gpu(self, gpu_id: int) -> Any: + if not self.has_gpu_support(): + raise ValueError("No CUDA GPU devices detected") + + device = cupy.cuda.device.Device(gpu_id) + device.use() + if has_torch_cuda_gpu: + torch.cuda.set_device(gpu_id) + + return device + + def get_default_torch_device(self): + device_id = torch.cuda.current_device() + return torch.device(f"cuda:{device_id}") + if cupy is not None: adam_kernel = cupy.ElementwiseKernel( diff --git a/thinc/backends/mps_ops.py b/thinc/backends/mps_ops.py index c6ba71f11..b8e3c96ee 100644 --- a/thinc/backends/mps_ops.py +++ b/thinc/backends/mps_ops.py @@ -3,6 +3,7 @@ import numpy from .. 
import registry +from ..compat import has_torch_mps_gpu, torch from .numpy_ops import NumpyOps from .ops import Ops @@ -26,3 +27,9 @@ class MPSOps(_Ops): name = "mps" xp = numpy + + def has_gpu_support(self): + return has_torch_mps_gpu + + def get_default_torch_device(self): + return torch.device('mps') diff --git a/thinc/backends/ops.py b/thinc/backends/ops.py index e3fec5c86..5c8a3f19e 100644 --- a/thinc/backends/ops.py +++ b/thinc/backends/ops.py @@ -16,6 +16,7 @@ import numpy +from ..compat import torch from ..types import ( Array1d, Array2d, @@ -1390,6 +1391,15 @@ def insert_into(self, shape, Xs): output[i, : x.shape[0]] = x return output + def has_gpu_support(self): + return False + + def set_active_gpu(self, gpu_id: int) -> Any: + return None + + def get_default_torch_device(self): + return torch.device("cpu") + """ LSTM Notation (kind of involved, but made it a lot easier to write) diff --git a/thinc/util.py b/thinc/util.py index f32f10344..3179e0d94 100644 --- a/thinc/util.py +++ b/thinc/util.py @@ -42,6 +42,7 @@ from .compat import mxnet as mx from .compat import tensorflow as tf from .compat import torch +from .config import registry DATA_VALIDATION: ContextVar[bool] = ContextVar("DATA_VALIDATION", default=False) @@ -59,17 +60,9 @@ def get_torch_default_device() -> "torch.device": raise ValueError("Cannot get default Torch device when Torch is not available.") from .backends import get_current_ops - from .backends.cupy_ops import CupyOps - from .backends.mps_ops import MPSOps ops = get_current_ops() - if isinstance(ops, CupyOps): - device_id = torch.cuda.current_device() - return torch.device(f"cuda:{device_id}") - elif isinstance(ops, MPSOps): - return torch.device("mps") - - return torch.device("cpu") + return ops.get_default_torch_device() def get_array_module(arr): # pragma: no cover @@ -87,7 +80,7 @@ def get_array_module(arr): # pragma: no cover def gpu_is_available(): - return has_gpu + return has_gpu or _find_oot_gpu_ops() is not None def 
fix_random_seed(seed: int = 0) -> None: # pragma: no cover @@ -141,7 +134,7 @@ def is_torch_cuda_array(obj: Any) -> bool: # pragma: no cover def is_torch_gpu_array(obj: Any) -> bool: # pragma: no cover - return is_torch_cuda_array(obj) or is_torch_mps_array(obj) + return is_torch_array(obj) and not obj.is_cpu def is_torch_mps_array(obj: Any) -> bool: # pragma: no cover @@ -184,17 +177,11 @@ def to_numpy(data): # pragma: no cover def set_active_gpu(gpu_id: int) -> "cupy.cuda.Device": # pragma: no cover - """Set the current GPU device for cupy and torch (if available).""" - if not has_cupy_gpu: - raise ValueError("No CUDA GPU devices detected") - - device = cupy.cuda.device.Device(gpu_id) - device.use() - - if has_torch_cuda_gpu: - torch.cuda.set_device(gpu_id) + """Set the current GPU device for the backends and torch (if available).""" + from .backends import get_current_ops - return device + ops = get_current_ops() + return ops.set_active_gpu(gpu_id) def require_cpu() -> bool: # pragma: no cover @@ -207,14 +194,47 @@ def require_cpu() -> bool: # pragma: no cover return True -def prefer_gpu(gpu_id: int = 0) -> bool: # pragma: no cover +def prefer_gpu(gpu_id: int = 0, *, name: Optional[str] = None) -> bool: # pragma: no cover """Use GPU if it's available. 
Returns True if so, False otherwise.""" - if has_gpu: - require_gpu(gpu_id=gpu_id) - return has_gpu + try: + return require_gpu(gpu_id=gpu_id, name=name) + except Exception: + pass + + return False + + +def require_gpu(gpu_id: int = 0, *, name: Optional[str] = None) -> bool: # pragma: no cover + """If name is not provided, fall back to old autodetect method.""" + if name is None: + # First look for any out of tree ops. + ops = _find_oot_gpu_ops() + if ops is None: + return _require_gpu_legacy(gpu_id) + else: + cls = registry.ops.get(name) + if cls is None: + raise ValueError(f"Could not find ops for device {name}") + ops = cls() + if not ops.has_gpu_support(): + raise ValueError(f"Requested ops {name} does not have GPU support") + + from .backends import set_current_ops + set_current_ops(ops) + set_active_gpu(gpu_id) + return True + + +def _find_oot_gpu_ops() -> Optional["Ops"]: + for cls in registry.ops.get_entry_points().values(): + ops = cls() + if ops.has_gpu_support(): + return ops + + return None -def require_gpu(gpu_id: int = 0) -> bool: # pragma: no cover +def _require_gpu_legacy(gpu_id: int = 0) -> bool: # pragma: no cover from .backends import CupyOps, MPSOps, set_current_ops if platform.system() == "Darwin" and not has_torch_mps: