|
@@ -4,7 +4,16 @@
|
|
|
"cell_type": "markdown",
|
|
|
"metadata": {},
|
|
|
"source": [
|
|
|
- "## Pretrained GAN"
|
|
|
+ "## Artistic Model Training"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "#### NOTES: \n",
|
|
|
+ "* This is \"NoGAN\" based training, described in the DeOldify readme.\n",
|
|
|
+ "* This model prioritizes colorful renderings. It has higher variation in renderings at different resolutions compared to the \"stable\" model"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
@@ -14,7 +23,7 @@
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
"import os\n",
|
|
|
- "os.environ['CUDA_VISIBLE_DEVICES']='1' "
|
|
|
+ "os.environ['CUDA_VISIBLE_DEVICES']='0' "
|
|
|
]
|
|
|
},
|
|
|
{
|
|
@@ -32,6 +41,7 @@
|
|
|
"from fasterai.critics import *\n",
|
|
|
"from fasterai.dataset import *\n",
|
|
|
"from fasterai.loss import *\n",
|
|
|
+ "from fasterai.save import *\n",
|
|
|
"from PIL import Image, ImageDraw, ImageFont\n",
|
|
|
"from PIL import ImageFile"
|
|
|
]
|
|
@@ -53,8 +63,10 @@
|
|
|
"path_hr = path\n",
|
|
|
"path_lr = path/'bandw'\n",
|
|
|
"\n",
|
|
|
- "proj_id = 'Artistic2'\n",
|
|
|
+ "proj_id = 'ArtisticModel'\n",
|
|
|
+ "\n",
|
|
|
"gen_name = proj_id + '_gen'\n",
|
|
|
+ "pre_gen_name = gen_name + '_0'\n",
|
|
|
"crit_name = proj_id + '_crit'\n",
|
|
|
"\n",
|
|
|
"name_gen = proj_id + '_image_gen'\n",
|
|
@@ -62,30 +74,8 @@
|
|
|
"\n",
|
|
|
"TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)\n",
|
|
|
"\n",
|
|
|
- "nf_factor = 1.50\n",
|
|
|
- "pct_start = 1e-4"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "def save_all(suffix=''):\n",
|
|
|
- " learn_gen.save(gen_name + str(sz) + suffix)\n",
|
|
|
- " learn_crit.save(crit_name + str(sz) + suffix)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "def load_all(suffix=''):\n",
|
|
|
- " learn_gen.load(gen_name + str(sz) + suffix, with_opt=False)\n",
|
|
|
- " learn_crit.load(crit_name + str(sz) + suffix, with_opt=False)"
|
|
|
+ "nf_factor = 1.5\n",
|
|
|
+ "pct_start = 1e-8"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
@@ -96,42 +86,21 @@
|
|
|
"source": [
|
|
|
"def get_data(bs:int, sz:int, keep_pct:float):\n",
|
|
|
" return get_colorize_data(sz=sz, bs=bs, crappy_path=path_lr, good_path=path_hr, \n",
|
|
|
- " random_seed=None, keep_pct=keep_pct)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
+ " random_seed=None, keep_pct=keep_pct)\n",
|
|
|
+ "\n",
|
|
|
"def get_crit_data(classes, bs, sz):\n",
|
|
|
" src = ImageList.from_folder(path, include=classes, recurse=True).random_split_by_pct(0.1, seed=42)\n",
|
|
|
" ll = src.label_from_folder(classes=classes)\n",
|
|
|
" data = (ll.transform(get_transforms(max_zoom=2.), size=sz)\n",
|
|
|
" .databunch(bs=bs).normalize(imagenet_stats))\n",
|
|
|
- " return data"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "def crappify(fn,i):\n",
|
|
|
+ " return data\n",
|
|
|
+ "\n",
|
|
|
+ "def create_training_images(fn,i):\n",
|
|
|
" dest = path_lr/fn.relative_to(path_hr)\n",
|
|
|
" dest.parent.mkdir(parents=True, exist_ok=True)\n",
|
|
|
" img = PIL.Image.open(fn).convert('LA').convert('RGB')\n",
|
|
|
- " img.save(dest) "
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
+ " img.save(dest) \n",
|
|
|
+ " \n",
|
|
|
"def save_preds(dl):\n",
|
|
|
" i=0\n",
|
|
|
" names = dl.dataset.items\n",
|
|
@@ -140,15 +109,8 @@
|
|
|
" preds = learn_gen.pred_batch(batch=b, reconstruct=True)\n",
|
|
|
" for o in preds:\n",
|
|
|
" o.save(path_gen/names[i].name)\n",
|
|
|
- " i += 1"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
+ " i += 1\n",
|
|
|
+ " \n",
|
|
|
"def save_gen_images(learn_gen):\n",
|
|
|
" if path_gen.exists(): shutil.rmtree(path_gen)\n",
|
|
|
" path_gen.mkdir(exist_ok=True)\n",
|
|
@@ -161,21 +123,14 @@
|
|
|
"cell_type": "markdown",
|
|
|
"metadata": {},
|
|
|
"source": [
|
|
|
- "## Crappified data"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "markdown",
|
|
|
- "metadata": {},
|
|
|
- "source": [
|
|
|
- "Prepare the input data by crappifying images."
|
|
|
+ "## Create black and white training images"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "markdown",
|
|
|
"metadata": {},
|
|
|
"source": [
|
|
|
- "Uncomment the first time you run this notebook."
|
|
|
+ "Only runs if the directory isn't already created."
|
|
|
]
|
|
|
},
|
|
|
{
|
|
@@ -184,29 +139,31 @@
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
- "#il = ImageItemList.from_folder(path_hr)\n",
|
|
|
- "#parallel(crappify, il.items)"
|
|
|
+ "if not path_lr.exists():\n",
|
|
|
+ " il = ImageItemList.from_folder(path_hr)\n",
|
|
|
+ " parallel(create_training_images, il.items)"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "markdown",
|
|
|
"metadata": {},
|
|
|
"source": [
|
|
|
- "# Pre-training"
|
|
|
+ "## Pre-train generator"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "markdown",
|
|
|
"metadata": {},
|
|
|
"source": [
|
|
|
- "### Pre-train generator"
|
|
|
+ "#### NOTE\n",
|
|
|
+ "Most of the training takes place here in pretraining for NoGAN. The goal here is to take the generator as far as possible with conventional training, as that is much easier to control and obtain glitch-free results compared to GAN training."
|
|
|
]
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "markdown",
|
|
|
"metadata": {},
|
|
|
"source": [
|
|
|
- "Now let's pretrain the generator."
|
|
|
+ "### 64px"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
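A rough sketch of the 64px pretraining pass that follows this hunk as unchanged context (so it does not appear in the diff). The batch size, keep_pct, and learning rate below are illustrative guesses, not values from the notebook; get_data, gen_learner_deep, FeatureLoss, nf_factor, and pre_gen_name are the names defined in the cells above.

# Sketch only: bs, keep_pct, and max_lr are assumed values.
bs = 88
sz = 64

data_gen = get_data(bs=bs, sz=sz, keep_pct=1.0)
learn_gen = gen_learner_deep(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor)

# Conventional (non-GAN) training of the generator with a perceptual loss,
# saved as the '_0' checkpoint that the GAN cycles below start from.
learn_gen.fit_one_cycle(1, max_lr=1e-3)
learn_gen.save(pre_gen_name)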
@@ -262,16 +219,7 @@
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
- "learn_gen.save(gen_name)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "learn_gen.load(gen_name, with_opt=False)"
|
|
|
+ "learn_gen.save(pre_gen_name)"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
@@ -298,16 +246,14 @@
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
- "learn_gen.save(gen_name)"
|
|
|
+ "learn_gen.save(pre_gen_name)"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
+ "cell_type": "markdown",
|
|
|
"metadata": {},
|
|
|
- "outputs": [],
|
|
|
"source": [
|
|
|
- "learn_gen.load(gen_name)"
|
|
|
+ "### 128px"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
@@ -354,7 +300,14 @@
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
- "learn_gen.save(gen_name)"
|
|
|
+ "learn_gen.save(pre_gen_name)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "### 192px"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
@@ -401,84 +354,22 @@
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
- "learn_gen.save(gen_name)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "markdown",
|
|
|
- "metadata": {},
|
|
|
- "source": [
|
|
|
- "### Save generated images"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "save_gen_images(gen_name)"
|
|
|
+ "learn_gen.save(pre_gen_name)"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "markdown",
|
|
|
"metadata": {},
|
|
|
"source": [
|
|
|
- "### Train critic"
|
|
|
+ "## Repeatable GAN Cycle"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "markdown",
|
|
|
"metadata": {},
|
|
|
"source": [
|
|
|
- "Pretrain the critic on crappy vs not crappy."
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "bs=64\n",
|
|
|
- "sz=128"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "learn_gen=None\n",
|
|
|
- "gc.collect()"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "learn_critic = colorize_crit_learner(data=data_crit, nf=256)"
|
|
|
+ "#### NOTE\n",
|
|
|
+ "Best results so far have been based on repeating the cycle below a few times (about 5-8?), until diminishing returns are hit (no improvement in image quality). Each time you repeat the cycle, you want to increment that old_checkpoint_num by 1 so that new check points don't overwrite the old. "
|
|
|
]
|
|
|
},
|
|
|
{
|
|
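To make the note above concrete, here is a small standalone Python illustration (not part of the diff) of how the checkpoint names advance as the cycle is repeated, reusing the naming scheme from the proj_id cell earlier in this patch:

# Standalone illustration of the repeatable-cycle checkpoint naming.
proj_id = 'ArtisticModel'
gen_name = proj_id + '_gen'
crit_name = proj_id + '_crit'

for old_checkpoint_num in range(3):
    checkpoint_num = old_checkpoint_num + 1
    gen_old = gen_name + '_' + str(old_checkpoint_num)
    gen_new = gen_name + '_' + str(checkpoint_num)
    crit_old = crit_name + '_' + str(old_checkpoint_num)
    crit_new = crit_name + '_' + str(checkpoint_num)
    print(f'cycle {checkpoint_num}: load {gen_old}, {crit_old} -> save {gen_new}, {crit_new}')

# cycle 1: load ArtisticModel_gen_0, ArtisticModel_crit_0 -> save ArtisticModel_gen_1, ArtisticModel_crit_1
# cycle 2: load ArtisticModel_gen_1, ArtisticModel_crit_1 -> save ArtisticModel_gen_2, ArtisticModel_crit_2
# cycle 3: load ArtisticModel_gen_2, ArtisticModel_crit_2 -> save ArtisticModel_gen_3, ArtisticModel_crit_3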
@@ -487,25 +378,19 @@
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
- "learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))"
|
|
|
+ "old_checkpoint_num = 0\n",
|
|
|
+ "checkpoint_num = old_checkpoint_num + 1\n",
|
|
|
+ "gen_old_checkpoint_name = gen_name + '_' + str(old_checkpoint_num)\n",
|
|
|
+ "gen_new_checkpoint_name = gen_name + '_' + str(checkpoint_num)\n",
|
|
|
+ "crit_old_checkpoint_name = crit_name + '_' + str(old_checkpoint_num)\n",
|
|
|
+ "crit_new_checkpoint_name= crit_name + '_' + str(checkpoint_num)"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "learn_critic.fit_one_cycle(6, 1e-3)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
+ "cell_type": "markdown",
|
|
|
"metadata": {},
|
|
|
- "outputs": [],
|
|
|
"source": [
|
|
|
- "learn_critic.save(crit_name)"
|
|
|
+ "### Save Generated Images"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
@@ -514,7 +399,7 @@
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
- "bs=16\n",
|
|
|
+ "bs=8\n",
|
|
|
"sz=192"
|
|
|
]
|
|
|
},
|
|
@@ -524,7 +409,7 @@
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
- "learn_critic.data=get_crit_data([name_gen, 'test'], bs=bs, sz=sz)"
|
|
|
+ "learn_gen = gen_learner_deep(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
@@ -533,150 +418,21 @@
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
- "learn_critic.fit_one_cycle(4, 1e-4)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "learn_critic.save(crit_name)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "markdown",
|
|
|
- "metadata": {},
|
|
|
- "source": [
|
|
|
- "## GAN"
|
|
|
+ "save_gen_images(gen_name)"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "markdown",
|
|
|
"metadata": {},
|
|
|
"source": [
|
|
|
- "Now we'll combine those pretrained model in a GAN."
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "learn_crit=None\n",
|
|
|
- "learn_gen=None\n",
|
|
|
- "gc.collect()"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "lr=1e-5\n",
|
|
|
- "sz=192\n",
|
|
|
- "bs=9"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "#placeholder- not actually used\n",
|
|
|
- "data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_name, with_opt=False)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "learn_gen = gen_learner_deep(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_name, with_opt=False)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)\n",
|
|
|
- "learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,2.0), show_img=False, switcher=switcher,\n",
|
|
|
- " opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)\n",
|
|
|
- "learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))\n",
|
|
|
- "learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100))"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "for i in range(1,101):\n",
|
|
|
- " learn.data = get_data(sz=sz, bs=bs, keep_pct=0.001)\n",
|
|
|
- " learn_gen.freeze_to(-1)\n",
|
|
|
- " learn.fit(1,lr)\n",
|
|
|
- " save_all('_01_' + str(i))"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "save_all('_01')"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "learn.show_results(rows=bs)"
|
|
|
+ "### Pretrain Critic"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "markdown",
|
|
|
"metadata": {},
|
|
|
"source": [
|
|
|
- "### Save Generated Images Again"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "bs=12\n",
|
|
|
- "sz=192"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "learn_gen = gen_learner_deep(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load('ColorizeNew79_gen192_06_40', with_opt=False)"
|
|
|
+ "##### Only need full pretraining of critic when starting from scratch. Otherwise, just finetune!"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
@@ -685,14 +441,17 @@
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
- "save_gen_images(gen_name)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "markdown",
|
|
|
- "metadata": {},
|
|
|
- "source": [
|
|
|
- "### Train Critic Again"
|
|
|
+ "if old_checkpoint_num == 0:\n",
|
|
|
+ " bs=64\n",
|
|
|
+ " sz=128\n",
|
|
|
+ " learn_gen=None\n",
|
|
|
+ " gc.collect()\n",
|
|
|
+ " data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)\n",
|
|
|
+ " data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)\n",
|
|
|
+ " learn_critic = colorize_crit_learner(data=data_crit, nf=256)\n",
|
|
|
+ " learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))\n",
|
|
|
+ " learn_critic.fit_one_cycle(6, 1e-3)\n",
|
|
|
+ " learn_critic.save(crit_old_checkpoint_name)"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
@@ -705,16 +464,6 @@
|
|
|
"sz=192"
|
|
|
]
|
|
|
},
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "learn_gen=None\n",
|
|
|
- "gc.collect()"
|
|
|
- ]
|
|
|
- },
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
"execution_count": null,
|
|
@@ -739,7 +488,7 @@
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
- "learn_critic = colorize_crit_learner(data=data_crit, nf=256).load(crit_name + '6', with_opt=False)"
|
|
|
+ "learn_critic = colorize_crit_learner(data=data_crit, nf=256).load(crit_old_checkpoint_name, with_opt=False)"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
@@ -766,14 +515,14 @@
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
- "learn_critic.save(crit_name + '7')"
|
|
|
+ "learn_critic.save(crit_new_checkpoint_name)"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "markdown",
|
|
|
"metadata": {},
|
|
|
"source": [
|
|
|
- "### GAN Again"
|
|
|
+ "### GAN"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
@@ -813,8 +562,7 @@
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
- "#learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_name + '7', with_opt=False)\n",
|
|
|
- "learn_crit = colorize_crit_learner(data=data_crit, nf=256).load('ColorizeNew79_crit192_07_100', with_opt=False)"
|
|
|
+ "learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_new_checkpoint_name, with_opt=False)"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
@@ -823,7 +571,7 @@
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
- "learn_gen = gen_learner_deep(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load('ColorizeNew79_gen192_07_100', with_opt=False)"
|
|
|
+ "learn_gen = gen_learner_deep(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
@@ -836,27 +584,16 @@
|
|
|
"learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,2.0), show_img=False, switcher=switcher,\n",
|
|
|
" opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)\n",
|
|
|
"learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))\n",
|
|
|
- "learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100))"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": [
|
|
|
- "for i in range(1,101):\n",
|
|
|
- " learn.data = get_data(sz=sz, bs=bs, keep_pct=0.001)\n",
|
|
|
- " learn_gen.freeze_to(-1)\n",
|
|
|
- " learn.fit(1,lr)\n",
|
|
|
- " save_all('_07_' + str(i))"
|
|
|
+ "learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100))\n",
|
|
|
+ "learn.callback_fns.append(partial(GANSaveCallback, learn_gen=learn_gen, filename=gen_new_checkpoint_name, save_iters=100))"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "markdown",
|
|
|
"metadata": {},
|
|
|
"source": [
|
|
|
- "## fin"
|
|
|
+ "#### Instructions: \n",
|
|
|
+ "Find the checkpoint just before where glitches start to be introduced. This is all very new so you may need to play around with just how far you go here with keep_pct."
|
|
|
]
|
|
|
},
|
|
|
{
|
|
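One way to act on the instructions above, sketched here rather than taken from the diff: reload a checkpoint written during the cycle and inspect its output. This assumes GANSaveCallback saved the generator under gen_new_checkpoint_name and that data_gen is still defined.

# Hypothetical inspection step, not part of the patch.
learn_gen = gen_learner_deep(data=data_gen, gen_loss=FeatureLoss(),
                             nf_factor=nf_factor).load(gen_new_checkpoint_name, with_opt=False)
learn_gen.show_results(rows=4)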
@@ -864,14 +601,11 @@
|
|
|
"execution_count": null,
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
- "source": []
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "execution_count": null,
|
|
|
- "metadata": {},
|
|
|
- "outputs": [],
|
|
|
- "source": []
|
|
|
+ "source": [
|
|
|
+ "learn.data = get_data(sz=sz, bs=bs, keep_pct=0.03)\n",
|
|
|
+ "learn_gen.freeze_to(-1)\n",
|
|
|
+ "learn.fit(1,lr)"
|
|
|
+ ]
|
|
|
}
|
|
|
],
|
|
|
"metadata": {
|
|
@@ -890,7 +624,7 @@
|
|
|
"name": "python",
|
|
|
"nbconvert_exporter": "python",
|
|
|
"pygments_lexer": "ipython3",
|
|
|
- "version": "3.7.3"
|
|
|
+ "version": "3.7.0"
|
|
|
}
|
|
|
},
|
|
|
"nbformat": 4,
|