|
@@ -2,7 +2,7 @@
|
|
|
"cells": [
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 1,
|
|
|
+ "execution_count": null,
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
@@ -13,7 +13,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 2,
|
|
|
+ "execution_count": null,
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
@@ -37,25 +37,25 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 3,
|
|
|
+ "execution_count": null,
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
"IMAGENET = Path('data/imagenet/ILSVRC/Data/CLS-LOC')\n",
|
|
|
"BWIMAGENET = Path('data/imagenet/ILSVRC/Data/CLS-LOC/bandw')\n",
|
|
|
"\n",
|
|
|
- "proj_id = 'colorizeV2h'\n",
|
|
|
+ "proj_id = 'colorizeV5o'\n",
|
|
|
"TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)\n",
|
|
|
"\n",
|
|
|
- "gpath = IMAGENET.parent/(proj_id + '_gen_128.h5')\n",
|
|
|
- "dpath = IMAGENET.parent/(proj_id + '_critic_128.h5')\n",
|
|
|
+ "gpath = IMAGENET.parent/(proj_id + '_gen_64.h5')\n",
|
|
|
+ "dpath = IMAGENET.parent/(proj_id + '_critic_64.h5')\n",
|
|
|
"\n",
|
|
|
"torch.backends.cudnn.benchmark=True"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 4,
|
|
|
+ "execution_count": null,
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
@@ -75,7 +75,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 5,
|
|
|
+ "execution_count": null,
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
@@ -85,7 +85,7 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 6,
|
|
|
+ "execution_count": null,
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
@@ -95,18 +95,18 @@
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 7,
|
|
|
+ "execution_count": null,
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
"def get_data(sz:int, bs:int, keep_pct:float):\n",
|
|
|
" return get_colorize_data(sz=sz, bs=bs, crappy_path=BWIMAGENET, good_path=IMAGENET, \n",
|
|
|
- " random_seed=None, keep_pct=keep_pct,num_workers=32)"
|
|
|
+ " random_seed=None, keep_pct=keep_pct,num_workers=16)"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
- "execution_count": 8,
|
|
|
+ "execution_count": null,
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
@@ -126,6 +126,17 @@
|
|
|
" learn_crit.load(proj_id + '_crit_' + str(sz))"
|
|
|
]
|
|
|
},
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "def colorize_gen_learner_exp(data:ImageDataBunch, gen_loss=FeatureLoss(), arch=models.resnet34):\n",
|
|
|
+ " return unet_learner2(data, arch, wd=1e-3, blur=True, norm_type=NormType.Spectral,\n",
|
|
|
+ " self_attention=True, y_range=(-3.,3.), loss_func=gen_loss)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
{
|
|
|
"cell_type": "markdown",
|
|
|
"metadata": {},
|
|
@@ -141,17 +152,17 @@
|
|
|
"source": [
|
|
|
"#Needed to instantiate critic but not actually used\n",
|
|
|
"sz=64\n",
|
|
|
- "bs=128\n",
|
|
|
+ "bs=32\n",
|
|
|
"\n",
|
|
|
"data = get_data(sz=sz, bs=bs, keep_pct=1.0)\n",
|
|
|
- "learn_crit = colorize_crit_learner(data=data, nf=128)\n",
|
|
|
+ "learn_crit = colorize_crit_learner(data=data, nf=256)\n",
|
|
|
"learn_crit.unfreeze()\n",
|
|
|
"\n",
|
|
|
- "gen_loss = FeatureLoss()\n",
|
|
|
- "learn_gen = colorize_gen_learner(data=data, gen_loss=gen_loss, arch=models.resnet34)\n",
|
|
|
+ "gen_loss = FeatureLoss2(gram_wgt=5e3)\n",
|
|
|
+ "learn_gen = colorize_gen_learner_exp(data=data)\n",
|
|
|
"\n",
|
|
|
"switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)\n",
|
|
|
- "learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.,0.15), show_img=False, switcher=switcher,\n",
|
|
|
+ "learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,1.0), show_img=False, switcher=switcher,\n",
|
|
|
" opt_func=partial(optim.Adam, betas=(0.,0.99)), wd=1e-3)\n",
|
|
|
"\n",
|
|
|
"learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))\n",
|
|
@@ -172,64 +183,7 @@
|
|
|
"cell_type": "code",
|
|
|
"execution_count": null,
|
|
|
"metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "data": {
|
|
|
- "text/html": [
|
|
|
- "\n",
|
|
|
- " <div>\n",
|
|
|
- " <style>\n",
|
|
|
- " /* Turns off some styling */\n",
|
|
|
- " progress {\n",
|
|
|
- " /* gets rid of default border in Firefox and Opera. */\n",
|
|
|
- " border: none;\n",
|
|
|
- " /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
|
|
|
- " background-size: auto;\n",
|
|
|
- " }\n",
|
|
|
- " .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
|
|
|
- " background: #F44336;\n",
|
|
|
- " }\n",
|
|
|
- " </style>\n",
|
|
|
- " <progress value='0' class='' max='1', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
|
|
- " 0.00% [0/1 00:00<00:00]\n",
|
|
|
- " </div>\n",
|
|
|
- " \n",
|
|
|
- "<table style='width:300px; margin-bottom:10px'>\n",
|
|
|
- " <tr>\n",
|
|
|
- " <th>epoch</th>\n",
|
|
|
- " <th>train_loss</th>\n",
|
|
|
- " <th>gen_loss</th>\n",
|
|
|
- " <th>disc_loss</th>\n",
|
|
|
- " </tr>\n",
|
|
|
- "</table>\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- " <div>\n",
|
|
|
- " <style>\n",
|
|
|
- " /* Turns off some styling */\n",
|
|
|
- " progress {\n",
|
|
|
- " /* gets rid of default border in Firefox and Opera. */\n",
|
|
|
- " border: none;\n",
|
|
|
- " /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
|
|
|
- " background-size: auto;\n",
|
|
|
- " }\n",
|
|
|
- " .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
|
|
|
- " background: #F44336;\n",
|
|
|
- " }\n",
|
|
|
- " </style>\n",
|
|
|
- " <progress value='4938' class='' max='9684', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
|
|
- " 50.99% [4938/9684 1:32:47<1:29:10 1.0264]\n",
|
|
|
- " </div>\n",
|
|
|
- " "
|
|
|
- ],
|
|
|
- "text/plain": [
|
|
|
- "<IPython.core.display.HTML object>"
|
|
|
- ]
|
|
|
- },
|
|
|
- "metadata": {},
|
|
|
- "output_type": "display_data"
|
|
|
- }
|
|
|
- ],
|
|
|
+ "outputs": [],
|
|
|
"source": [
|
|
|
"learn_gen.freeze_to(-1)\n",
|
|
|
"learn.fit(1,lr)"
|
|
@@ -276,9 +230,10 @@
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
+ "load()\n",
|
|
|
"lr=lr/2\n",
|
|
|
"sz=96\n",
|
|
|
- "bs=bs//2"
|
|
|
+ "#bs=bs//2"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
@@ -448,7 +403,16 @@
|
|
|
"source": [
|
|
|
"lr=lr/1.5\n",
|
|
|
"sz=160\n",
|
|
|
- "bs=bs//1.5"
|
|
|
+ "bs=int(bs//1.5)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "bs=10"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
@@ -533,7 +497,7 @@
|
|
|
"source": [
|
|
|
"lr=lr/1.5\n",
|
|
|
"sz=192\n",
|
|
|
- "bs=bs//1.5"
|
|
|
+ "bs=int(bs//1.5)"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
@@ -618,7 +582,7 @@
|
|
|
"source": [
|
|
|
"lr=lr/1.5\n",
|
|
|
"sz=224\n",
|
|
|
- "bs=bs//1.5"
|
|
|
+ "bs=int(bs//1.5)"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
@@ -687,6 +651,13 @@
|
|
|
"source": [
|
|
|
"save()"
|
|
|
]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": []
|
|
|
}
|
|
|
],
|
|
|
"metadata": {
|