|
@@ -53,29 +53,115 @@
|
|
|
"## Setup"
|
|
|
]
|
|
|
},
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "### Activate Large Model Support for PyTorch\n",
|
|
|
+ "This will allow us to fit the model within a GPU with smaller memory capacity (e.g. GTX 1070 8Gb)."
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "Large Model Support (LMS) is a feature provided in IBM Watson Machine Learning Community Edition (WML-CE) PyTorch V1.1.0 that allows the successful training of deep learning models that would otherwise exhaust GPU memory and abort with “out-of-memory” errors. LMS manages this oversubscription of GPU memory by temporarily swapping tensors to host memory when they are not needed. One or more elements of a deep learning model can lead to GPU memory exhaustion.\n",
|
|
|
+ "\n",
|
|
|
+ "Requires the use of IBM WML-CE (Available here: https://www.ibm.com/support/knowledgecenter/en/SS5SF7_1.6.1/welcome/welcome.html)\n",
|
|
|
+ "\n",
|
|
|
+ "Further Reading on PyTorch with Large Model Support: https://www.ibm.com/support/knowledgecenter/en/SS5SF7_1.6.1/welcome/welcome.html"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "import shutil"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "# Set limit of GPU used before swapping to tensors to host memory\n",
|
|
|
+ "max_gpu_mem = 7\n",
|
|
|
+ "\n",
|
|
|
+ "def gb_to_bytes(gb):\n",
|
|
|
+ " return gb*1024*1024*1024\n",
|
|
|
+ "\n",
|
|
|
+ "# Enable PyTorch LMS\n",
|
|
|
+ "torch.cuda.set.enabled_lms(True)\n",
|
|
|
+ "# Set LMS limit\n",
|
|
|
+ "torch.cuda.set_limit_lms(gb_to_bytes(max_gpu_mem))"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "# Check LMS is enabled\n",
|
|
|
+ "torch.cuda.get_enabled_lms()"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "# Check LMS Limit has been set\n",
|
|
|
+ "torch.cuda.get_limit_lms()"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ " "
|
|
|
+ ]
|
|
|
+ },
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
"execution_count": null,
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
+ "# Path to Training Data\n",
|
|
|
"path = Path('data/imagenet/ILSVRC/Data/CLS-LOC')\n",
|
|
|
"path_hr = path\n",
|
|
|
- "path_lr = path/'bandw'\n",
|
|
|
"\n",
|
|
|
+ "# Path to Black and White images\n",
|
|
|
+ "path_bandw = Path('/training/DeOldify')\n",
|
|
|
+ "path_lr = path_bandw/'bandw'\n",
|
|
|
+ "\n",
|
|
|
+ "# Name of Model\n",
|
|
|
"proj_id = 'StableModel'\n",
|
|
|
"\n",
|
|
|
+ "# Name of Generator\n",
|
|
|
"gen_name = proj_id + '_gen'\n",
|
|
|
"pre_gen_name = gen_name + '_0'\n",
|
|
|
+ "\n",
|
|
|
+ "# Name of Critic\n",
|
|
|
"crit_name = proj_id + '_crit'\n",
|
|
|
"\n",
|
|
|
+ "# Name of Generated Images folder, located within the Black and White folder\n",
|
|
|
"name_gen = proj_id + '_image_gen'\n",
|
|
|
"path_gen = path/name_gen\n",
|
|
|
"\n",
|
|
|
+ "# Path to tensorboard data\n",
|
|
|
"TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)\n",
|
|
|
"\n",
|
|
|
"nf_factor = 2\n",
|
|
|
- "pct_start = 1e-8"
|
|
|
+ "pct_start = 1e-8\n",
|
|
|
+ "\n",
|
|
|
+ "# Number of workers for DataLoader\n",
|
|
|
+ "num_works = 2"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
@@ -86,7 +172,7 @@
|
|
|
"source": [
|
|
|
"def get_data(bs:int, sz:int, keep_pct:float):\n",
|
|
|
" return get_colorize_data(sz=sz, bs=bs, crappy_path=path_lr, good_path=path_hr, \n",
|
|
|
- " random_seed=None, keep_pct=keep_pct)\n",
|
|
|
+ " random_seed=None, keep_pct=keep_pct, num_workers=num_works)\n",
|
|
|
"\n",
|
|
|
"def get_crit_data(classes, bs, sz):\n",
|
|
|
" src = ImageList.from_folder(path, include=classes, recurse=True).random_split_by_pct(0.1, seed=42)\n",
|
|
@@ -172,7 +258,7 @@
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
- "bs=88\n",
|
|
|
+ "bs=88 # This can be increased if using PyTorch LMS, training could be slower.\n",
|
|
|
"sz=64\n",
|
|
|
"keep_pct=1.0"
|
|
|
]
|
|
@@ -262,7 +348,7 @@
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
- "bs=20\n",
|
|
|
+ "bs=40 # This can be increased if using PyTorch LMS, training could be slower.\n",
|
|
|
"sz=128\n",
|
|
|
"keep_pct=1.0"
|
|
|
]
|
|
@@ -316,7 +402,7 @@
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
- "bs=8\n",
|
|
|
+ "bs=16 # This can be increased if using PyTorch LMS, training could be slower.\n",
|
|
|
"sz=192\n",
|
|
|
"keep_pct=0.50"
|
|
|
]
|
|
@@ -357,6 +443,60 @@
|
|
|
"learn_gen.save(pre_gen_name)"
|
|
|
]
|
|
|
},
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "### 256px"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "bs=8 # This can be increased if using PyTorch LMS, training could be slower.\n",
|
|
|
+ "sz=256\n",
|
|
|
+ "keep_pct=0.50"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "learn_gen.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "learn_gen.unfreeze()"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "learn_gen.fit_one_cycle(1, pct_start=pct_start, max_lr=slice(5e-8,5e-5))"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "learn_gen.save(pre_gen_name)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
{
|
|
|
"cell_type": "markdown",
|
|
|
"metadata": {},
|
|
@@ -400,7 +540,7 @@
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
"bs=8\n",
|
|
|
- "sz=192"
|
|
|
+ "sz=256"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
@@ -460,8 +600,8 @@
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
- "bs=16\n",
|
|
|
- "sz=192"
|
|
|
+ "bs=8\n",
|
|
|
+ "sz=256"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
@@ -543,7 +683,7 @@
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
"lr=2e-5\n",
|
|
|
- "sz=192\n",
|
|
|
+ "sz=256\n",
|
|
|
"bs=5"
|
|
|
]
|
|
|
},
|
|
@@ -624,7 +764,7 @@
|
|
|
"name": "python",
|
|
|
"nbconvert_exporter": "python",
|
|
|
"pygments_lexer": "ipython3",
|
|
|
- "version": "3.7.0"
|
|
|
+ "version": "3.7.3"
|
|
|
}
|
|
|
},
|
|
|
"nbformat": 4,
|