
Added PyTorch Large Model Support and increased training to 256px

Ben Swinney committed 6 years ago
commit
1b457ce22d
1 file changed, 151 additions and 11 deletions
      ColorizeTrainingStable.ipynb

+ 151 - 11
ColorizeTrainingStable.ipynb

@@ -53,29 +53,115 @@
     "## Setup"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Activate Large Model Support for PyTorch\n",
+    "This will allow us to fit the model within a GPU with smaller memory capacity (e.g. a GTX 1070 with 8 GB)."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Large Model Support (LMS) is a feature provided in IBM Watson Machine Learning Community Edition (WML-CE) PyTorch V1.1.0 that allows the successful training of deep learning models that would otherwise exhaust GPU memory and abort with “out-of-memory” errors. LMS manages this oversubscription of GPU memory by temporarily swapping tensors to host memory when they are not needed. Several elements of a deep learning model (such as its depth, input resolution, or batch size) can lead to GPU memory exhaustion.\n",
+    "\n",
+    "Requires the use of IBM WML-CE (Available here: https://www.ibm.com/support/knowledgecenter/en/SS5SF7_1.6.1/welcome/welcome.html)\n",
+    "\n",
+    "Further Reading on PyTorch with Large Model Support: https://www.ibm.com/support/knowledgecenter/en/SS5SF7_1.6.1/welcome/welcome.html"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import shutil"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Set limit of GPU memory used before swapping tensors to host memory\n",
+    "max_gpu_mem = 7\n",
+    "\n",
+    "def gb_to_bytes(gb):\n",
+    "    return gb*1024*1024*1024\n",
+    "\n",
+    "# Enable PyTorch LMS\n",
+    "torch.cuda.set_enabled_lms(True)\n",
+    "# Set LMS limit\n",
+    "torch.cuda.set_limit_lms(gb_to_bytes(max_gpu_mem))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Check LMS is enabled\n",
+    "torch.cuda.get_enabled_lms()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Check LMS Limit has been set\n",
+    "torch.cuda.get_limit_lms()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    " "
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
+    "# Path to Training Data\n",
     "path = Path('data/imagenet/ILSVRC/Data/CLS-LOC')\n",
     "path_hr = path\n",
-    "path_lr = path/'bandw'\n",
     "\n",
+    "# Path to Black and White images\n",
+    "path_bandw = Path('/training/DeOldify')\n",
+    "path_lr = path_bandw/'bandw'\n",
+    "\n",
+    "# Name of Model\n",
     "proj_id = 'StableModel'\n",
     "\n",
+    "# Name of Generator\n",
     "gen_name = proj_id + '_gen'\n",
     "pre_gen_name = gen_name + '_0'\n",
+    "\n",
+    "# Name of Critic\n",
     "crit_name = proj_id + '_crit'\n",
     "\n",
+    "# Name of Generated Images folder, located within the Black and White folder\n",
     "name_gen = proj_id + '_image_gen'\n",
     "path_gen = path/name_gen\n",
     "\n",
+    "# Path to tensorboard data\n",
     "TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)\n",
     "\n",
     "nf_factor = 2\n",
-    "pct_start = 1e-8"
+    "pct_start = 1e-8\n",
+    "\n",
+    "# Number of workers for DataLoader\n",
+    "num_works = 2"
    ]
   },
   {
@@ -86,7 +172,7 @@
    "source": [
     "def get_data(bs:int, sz:int, keep_pct:float):\n",
     "    return get_colorize_data(sz=sz, bs=bs, crappy_path=path_lr, good_path=path_hr, \n",
-    "                             random_seed=None, keep_pct=keep_pct)\n",
+    "                             random_seed=None, keep_pct=keep_pct, num_workers=num_works)\n",
     "\n",
     "def get_crit_data(classes, bs, sz):\n",
     "    src = ImageList.from_folder(path, include=classes, recurse=True).random_split_by_pct(0.1, seed=42)\n",
@@ -172,7 +258,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "bs=88\n",
+    "bs=88 # This can be increased if using PyTorch LMS, though training may be slower.\n",
     "sz=64\n",
     "keep_pct=1.0"
    ]
@@ -262,7 +348,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "bs=20\n",
+    "bs=40 # This can be increased if using PyTorch LMS, though training may be slower.\n",
     "sz=128\n",
     "keep_pct=1.0"
    ]
@@ -316,7 +402,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "bs=8\n",
+    "bs=16 # This can be increased if using PyTorch LMS, though training may be slower.\n",
     "sz=192\n",
     "keep_pct=0.50"
    ]
@@ -357,6 +443,60 @@
     "learn_gen.save(pre_gen_name)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 256px"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "bs=8 # This can be increased if using PyTorch LMS, though training may be slower.\n",
+    "sz=256\n",
+    "keep_pct=0.50"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen.unfreeze()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen.fit_one_cycle(1, pct_start=pct_start, max_lr=slice(5e-8,5e-5))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "learn_gen.save(pre_gen_name)"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},
@@ -400,7 +540,7 @@
    "outputs": [],
    "source": [
     "bs=8\n",
-    "sz=192"
+    "sz=256"
    ]
   },
   {
@@ -460,8 +600,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "bs=16\n",
-    "sz=192"
+    "bs=8\n",
+    "sz=256"
    ]
   },
   {
@@ -543,7 +683,7 @@
    "outputs": [],
    "source": [
     "lr=2e-5\n",
-    "sz=192\n",
+    "sz=256\n",
     "bs=5"
    ]
   },
@@ -624,7 +764,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.7.0"
+   "version": "3.7.3"
   }
  },
  "nbformat": 4,
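The LMS setup cells added in this commit can be condensed into a standalone sketch. Note that the LMS API (`torch.cuda.set_enabled_lms`, `torch.cuda.set_limit_lms`, `torch.cuda.get_enabled_lms`) exists only in IBM WML-CE PyTorch 1.1.0, not in mainline PyTorch, so the calls are guarded here rather than assumed to be present:

```python
# Sketch of the LMS configuration from this commit, with the WML-CE-only
# calls guarded so the snippet also runs on mainline PyTorch (or no PyTorch).

def gb_to_bytes(gb: int) -> int:
    # 1 GiB = 1024**3 bytes
    return gb * 1024 * 1024 * 1024

max_gpu_mem = 7  # swap tensors to host memory once GPU use exceeds 7 GiB
limit_bytes = gb_to_bytes(max_gpu_mem)

try:
    import torch
    if hasattr(torch.cuda, "set_enabled_lms"):
        torch.cuda.set_enabled_lms(True)       # enable Large Model Support
        torch.cuda.set_limit_lms(limit_bytes)  # set the swap threshold
except ImportError:
    pass  # PyTorch not installed; nothing to configure
```

With LMS active, the batch sizes in the notebook (88 at 64px, 40 at 128px, 16 at 192px, 8 at 256px) can be raised further at the cost of swap overhead.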