_device.py 875 B

12345678910111213141516171819202122232425262728293031
  1. import os
  2. from enum import Enum
  3. from .device_id import DeviceId
# NOTE: This must run before anything else imports torch; otherwise the environment variables set below (CUDA_VISIBLE_DEVICES, OMP_NUM_THREADS) may be ignored.
  5. class DeviceException(Exception):
  6. pass
  7. class _Device:
  8. def __init__(self):
  9. self.set(DeviceId.CPU)
  10. def is_gpu(self):
  11. ''' Returns `True` if the current device is GPU, `False` otherwise. '''
  12. return self.current() is not DeviceID.CPU
  13. def current(self):
  14. return self._current_device
  15. def set(self, device:DeviceId):
  16. if device == DeviceId.CPU:
  17. os.environ['CUDA_VISIBLE_DEVICES']=''
  18. else:
  19. os.environ['CUDA_VISIBLE_DEVICES']=str(device.value)
  20. import torch
  21. torch.backends.cudnn.benchmark=False
  22. os.environ['OMP_NUM_THREADS']='1'
  23. self._current_device = device
  24. return device