Bladeren bron

feat(wandb): log training of the GAN

Boris Dayma 5 jaren geleden
bovenliggende
commit
487404cc41
1 gewijzigde bestanden met toevoegingen van 154 en 85 verwijderingen
  1. 154 85
      ColorizeTrainingWandb.ipynb

+ 154 - 85
ColorizeTrainingWandb.ipynb

@@ -14,7 +14,8 @@
     "#### NOTES:  \n",
     "* This is \"NoGAN\" based training, described in the DeOldify readme.\n",
     "* This model prioritizes stable and reliable renderings.  It does particularly well on portraits and landscapes.  It's not as colorful as the artistic model.\n",
-    "* Training is logged and monitored through [Weights & Biases](https://www.wandb.com/)"
+    "* Training with this notebook has been logged and monitored through [Weights & Biases](https://www.wandb.com/). Refer to [W&B Report](https://app.wandb.ai/borisd13/DeOldify/reports?view=borisd13%2FDeOldify).\n",
+    "* It is **highly** recommended to use an 11 GB GPU to run this notebook. Anything lower will require reducing the batch size (leading to more instability) or using \"Large Model Support\" from IBM WML-CE (not so easy to set up). An alternative is to rent resources online."
    ]
   },
   {
@@ -46,7 +47,6 @@
     "import fastai\n",
     "from fastai import *\n",
     "from fastai.vision import *\n",
-    "# TODELETE from fastai.callbacks.tensorboard import *\n",
     "from fastai.vision.gan import *\n",
     "from deoldify.generators import *\n",
     "from deoldify.critics import *\n",
@@ -56,6 +56,7 @@
     "from PIL import Image, ImageDraw, ImageFont\n",
     "from PIL import ImageFile\n",
     "from torch.utils.data.sampler import RandomSampler, SequentialSampler\n",
+    "from tqdm import tqdm\n",
     "import wandb\n",
     "from wandb.fastai import WandbCallback"
    ]
@@ -192,8 +193,17 @@
     "    data.c = 3\n",
     "    return data\n",
     "\n",
-    "def get_crit_data(classes, bs, sz):\n",
-    "    src = ImageList.from_folder(path, include=classes, recurse=True).random_split_by_pct(0.1, seed=42)\n",
+    "# Function to limit amount of data in critic\n",
+    "def filter_data(pct=1.0):\n",
+    "    def _f(fname):\n",
+    "        if 'test' in str(fname):\n",
+    "            if np.random.random_sample() > pct:\n",
+    "                return False\n",
+    "        return True\n",
+    "    return _f\n",
+    "\n",
+    "def get_crit_data(classes, bs, sz, pct=1.0):\n",
+    "    src = ImageList.from_folder(path, include=classes, recurse=True).filter_by_func(filter_data(pct)).random_split_by_pct(0.1)\n",
     "    ll = src.label_from_folder(classes=classes)\n",
     "    data = (ll.transform(get_transforms(max_zoom=2.), size=sz)\n",
     "           .databunch(bs=bs).normalize(imagenet_stats))\n",
@@ -207,18 +217,17 @@
     "    \n",
     "def save_preds(dl):\n",
     "    i=0\n",
-    "    names = dl.dataset.items\n",
-    "    \n",
-    "    for b in dl:\n",
+    "    names = dl.dataset.items    \n",
+    "    for b in tqdm(dl):\n",
     "        preds = learn_gen.pred_batch(batch=b, reconstruct=True)\n",
     "        for o in preds:\n",
     "            o.save(path_gen/names[i].name)\n",
     "            i += 1\n",
     "    \n",
-    "def save_gen_images():\n",
+    "def save_gen_images(keep_pct):\n",
     "    if path_gen.exists(): shutil.rmtree(path_gen)\n",
     "    path_gen.mkdir(exist_ok=True)\n",
-    "    data_gen = get_data(bs=bs, sz=sz, keep_pct=0.085)\n",
+    "    data_gen = get_data(bs=bs, sz=sz, keep_pct=keep_pct)\n",
     "    save_preds(data_gen.fix_dl)"
    ]
   },
@@ -316,7 +325,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "data_gen = get_data(bs=bs, sz=sz, random_seed=12345, valid_pct=valid_pct, epoch_size=100*bs)"
+    "data_gen = get_data(bs=bs, sz=sz, random_seed=123, valid_pct=valid_pct, epoch_size=100*bs)"
    ]
   },
   {
@@ -328,15 +337,6 @@
     "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor)"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# TODELETE learn_gen.callback_fns.append(partial(ImageGenTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GenPre'))"
-   ]
-  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -344,9 +344,7 @@
    "outputs": [],
    "source": [
     "learn_gen.callback_fns.append(partial(WandbCallback,\n",
-    "                                      input_type='images',  # log prediction samples\n",
-    "                                      save_model=False))    # bug in get_monitor_value in fastai v1.0.51 (needed for auto saving best model)\n",
-    "                                                            # save_model default can be used if using fastai v1.0.53"
+    "                                      input_type='images'))  # log prediction samples"
    ]
   },
   {
@@ -427,7 +425,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "learn_gen.data = get_data(bs=bs, sz=sz, random_seed=12345, valid_pct=valid_pct, epoch_size=100*bs)"
+    "learn_gen.data = get_data(bs=bs, sz=sz, random_seed=123, valid_pct=valid_pct, epoch_size=100*bs)"
    ]
   },
   {
@@ -490,7 +488,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "learn_gen.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)"
+    "learn_gen.data = get_data(bs=bs, sz=sz, random_seed=123, valid_pct=valid_pct, epoch_size=100*bs)"
    ]
   },
   {
@@ -574,7 +572,22 @@
    "outputs": [],
    "source": [
     "bs=8\n",
-    "sz=192"
+    "sz=192\n",
+    "\n",
+    "# Define target number of training/validation samples as well as number of epochs\n",
+    "epoch_train_size = 100 * bs\n",
+    "epoch_valid_size = 10 * bs\n",
+    "valid_pct = epoch_valid_size / data_size\n",
+    "number_epochs = (data_size - epoch_valid_size) // epoch_train_size"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "data_gen = get_data(bs=bs, sz=sz, random_seed=123, valid_pct=valid_pct, epoch_size=100*bs)"
    ]
   },
   {
@@ -592,7 +605,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "save_gen_images()"
+    "save_gen_images(0.1)"
    ]
   },
   {
@@ -616,16 +629,24 @@
    "outputs": [],
    "source": [
     "if old_checkpoint_num == 0:\n",
+    "    \n",
+    "    # Init logging of a new run\n",
+    "    wandb.init(tags=['Pre-train Crit'])  # tags are optional\n",
+    "    \n",
     "    bs=64\n",
     "    sz=128\n",
     "    learn_gen=None\n",
-    "    gc.collect()\n",
+    "    \n",
+    "    # Log hyper parameters\n",
+    "    wandb.config.update({\"Step 1 - batch size\": bs, \"Step 1 - image size\": sz})\n",
+    "\n",
+    "    gc.collect()    \n",
     "    data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)\n",
     "    data_crit.show_batch(rows=3, ds_type=DatasetType.Train, imgsize=3)\n",
-    "    learn_critic = colorize_crit_learner(data=data_crit, nf=256)\n",
-    "    # TODELETE learn_critic.callback_fns.append(partial(LearnerTensorboardWriter, base_dir=TENSORBOARD_PATH, name='CriticPre'))\n",
-    "    learn_critic.fit_one_cycle(6, 1e-3)\n",
-    "    learn_critic.save(crit_old_checkpoint_name)"
+    "    learn_crit = colorize_crit_learner(data=data_crit, nf=256)\n",
+    "    learn_crit.callback_fns.append(partial(WandbCallback))  # log prediction samples\n",
+    "    learn_crit.fit_one_cycle(6, 1e-3)\n",
+    "    learn_crit.save(crit_old_checkpoint_name)"
    ]
   },
   {
@@ -635,7 +656,10 @@
    "outputs": [],
    "source": [
     "bs=16\n",
-    "sz=192"
+    "sz=192\n",
+    "\n",
+    "# Log hyper parameters\n",
+    "wandb.config.update({\"Step 2 - batch size\": bs, \"Step 2 - image size\": sz})"
    ]
   },
   {
@@ -662,7 +686,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "learn_critic = colorize_crit_learner(data=data_crit, nf=256).load(crit_old_checkpoint_name, with_opt=False)"
+    "learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_old_checkpoint_name, with_opt=False)"
    ]
   },
   {
@@ -671,7 +695,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "learn_critic.fit_one_cycle(4, 1e-4)"
+    "learn_crit.fit_one_cycle(4, 1e-4)"
    ]
   },
   {
@@ -680,7 +704,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "learn_critic.save(crit_new_checkpoint_name)"
+    "learn_crit.save(crit_new_checkpoint_name)"
    ]
   },
   {
@@ -696,8 +720,10 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "# free up memory\n",
     "learn_crit=None\n",
     "learn_gen=None\n",
+    "learn=None\n",
     "gc.collect()"
    ]
   },
@@ -707,58 +733,105 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "lr=2e-5\n",
+    "# Set old_checkpoint_num to last iteration\n",
+    "old_checkpoint_num = 0\n",
+    "save_checkpoints = False\n",
+    "batch_per_epoch = 200\n",
+    "\n",
+    "checkpoint_num = old_checkpoint_num + 1\n",
+    "gen_old_checkpoint_name = gen_name + '_' + str(old_checkpoint_num)\n",
+    "gen_new_checkpoint_name = gen_name + '_' + str(checkpoint_num)\n",
+    "crit_old_checkpoint_name = crit_name + '_' + str(old_checkpoint_num)\n",
+    "crit_new_checkpoint_name= crit_name + '_' + str(checkpoint_num)   \n",
+    "\n",
+    "if False:   # need only to do it once\n",
+    "        \n",
+    "    # Generate data\n",
+    "    print('Generating data…')\n",
+    "    bs=8\n",
+    "    sz=192\n",
+    "    epoch_train_size = batch_per_epoch * bs\n",
+    "    epoch_valid_size = batch_per_epoch * bs // 10\n",
+    "    valid_pct = epoch_valid_size / data_size\n",
+    "    data_gen = get_data(bs=bs, sz=sz, epoch_size=epoch_train_size, valid_pct=valid_pct)\n",
+    "    learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)\n",
+    "    save_gen_images(0.02)\n",
+    "\n",
+    "    # Pre-train critic\n",
+    "    print('Pre-training critic…')\n",
+    "    bs=16\n",
+    "    sz=192\n",
+    "\n",
+    "    len_test = len(list((path / 'test').rglob('*.*')))\n",
+    "    len_gen = len(list((path / name_gen).rglob('*.*')))\n",
+    "    keep_test_pct = len_gen / len_test * 2\n",
+    "\n",
+    "    data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz, pct=keep_test_pct)\n",
+    "    learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_old_checkpoint_name, with_opt=False)\n",
+    "    learn_crit.fit_one_cycle(1, 1e-4)\n",
+    "    learn_crit.save(crit_new_checkpoint_name)\n",
+    "\n",
+    "# Creating GAN\n",
+    "print('Creating GAN…')\n",
     "sz=192\n",
-    "bs=5"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_new_checkpoint_name, with_opt=False)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
+    "bs=8\n",
+    "lr_GAN=2e-5\n",
+    "epoch_train_size = batch_per_epoch * bs\n",
+    "epoch_valid_size = batch_per_epoch * bs // 10\n",
+    "valid_pct = epoch_valid_size / data_size\n",
+    "len_test = len(list((path / 'test').rglob('*.*')))\n",
+    "len_gen = len(list((path / name_gen).rglob('*.*')))\n",
+    "keep_test_pct = len_gen / len_test * 2\n",
+    "\n",
+    "data_crit = get_crit_data([name_gen, 'test'], bs=bs, sz=sz, pct=keep_test_pct)\n",
+    "learn_crit = colorize_crit_learner(data=data_crit, nf=256).load(crit_new_checkpoint_name, with_opt=False)\n",
+    "data_gen = get_data(bs=bs, sz=sz, epoch_size=epoch_train_size, valid_pct=valid_pct)\n",
+    "learn_gen = gen_learner_wide(data=data_gen, gen_loss=FeatureLoss(), nf_factor=nf_factor).load(gen_old_checkpoint_name, with_opt=False)\n",
     "switcher = partial(AdaptiveGANSwitcher, critic_thresh=0.65)\n",
     "learn = GANLearner.from_learners(learn_gen, learn_crit, weights_gen=(1.0,1.5), show_img=False, switcher=switcher,\n",
     "                                 opt_func=partial(optim.Adam, betas=(0.,0.9)), wd=1e-3)\n",
     "learn.callback_fns.append(partial(GANDiscriminativeLR, mult_lr=5.))\n",
-    "# TODELETE learn.callback_fns.append(partial(GANTensorboardWriter, base_dir=TENSORBOARD_PATH, name='GanLearner', visual_iters=100))\n",
-    "learn.callback_fns.append(partial(GANSaveCallback, learn_gen=learn_gen, filename=gen_new_checkpoint_name, save_iters=100))"
+    "learn.callback_fns.append(partial(WandbCallback, input_type='images', seed=None, save_model=False))\n",
+    "learn.data = get_data(bs=bs, sz=sz, epoch_size=epoch_train_size, valid_pct=valid_pct)\n",
+    "\n",
+    "# Start logging to W&B\n",
+    "wandb.init(tags=['GAN'])\n",
+    "wandb.config.update({\"learning rate\": lr_GAN})  \n",
+    "\n",
+    "# Run the loop until satisfied with the results\n",
+    "while True:\n",
+    "\n",
+    "    # Current loop\n",
+    "    checkpoint_num = old_checkpoint_num + 1\n",
+    "    gen_old_checkpoint_name = gen_name + '_' + str(old_checkpoint_num)\n",
+    "    gen_new_checkpoint_name = gen_name + '_' + str(checkpoint_num)\n",
+    "    crit_old_checkpoint_name = crit_name + '_' + str(old_checkpoint_num)\n",
+    "    crit_new_checkpoint_name= crit_name + '_' + str(checkpoint_num)      \n",
+    "    \n",
+    "    \n",
+    "    # Train the GAN for 1 epoch between each checkpoint\n",
+    "    try:\n",
+    "        learn.fit(1, lr_GAN)\n",
+    "    except:\n",
+    "        # Sometimes we get an error for some unknown reason during callbacks\n",
+    "        learn.callback_fns[-1](learn).on_epoch_end(old_checkpoint_num, None, [])\n",
+    "        \n",
+    "    if save_checkpoints:\n",
+    "        learn_crit.save(crit_new_checkpoint_name)\n",
+    "        learn_gen.save(gen_new_checkpoint_name)\n",
+    "    \n",
+    "    old_checkpoint_num += 1"
    ]
   },
   {
-   "cell_type": "markdown",
+   "cell_type": "code",
+   "execution_count": null,
    "metadata": {},
+   "outputs": [],
    "source": [
-    "#### Instructions:  \n",
-    "Find the checkpoint just before where glitches start to be introduced.  This is all very new so you may need to play around with just how far you go here with keep_pct."
+    "# End logging of current session run\n",
+    "# Note: this is optional and would be automatically triggered when stopping the kernel\n",
+    "wandb.join()"
    ]
   },
   {
@@ -766,18 +839,14 @@
    "execution_count": null,
    "metadata": {},
    "outputs": [],
-   "source": [
-    "learn.data = get_data(sz=sz, bs=bs, keep_pct=0.03)\n",
-    "learn_gen.freeze_to(-1)\n",
-    "learn.fit(1,lr)"
-   ]
+   "source": []
   }
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3",
+   "display_name": "Python (deoldify)",
    "language": "python",
-   "name": "python3"
+   "name": "deoldify"
   },
   "language_info": {
    "codemirror_mode": {
@@ -794,4 +863,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 4
-}
+}