Przeglądaj źródła

Modifying colorize model to work with best pretrained weights; Various related fixes

Jason Antic 6 lat temu
rodzic
commit
7c29da2b9e

+ 1 - 0
.gitignore

@@ -26,3 +26,4 @@ fasterai/__pycache__/callbacks.cpython-36.pyc
 fasterai/SymbolicLinks.sh
 SymbolicLinks.sh
 .ipynb_checkpoints/README-checkpoint.md
+.ipynb_checkpoints/ComboVisualization-checkpoint.ipynb

+ 13 - 43
ColorizeTraining.ipynb

@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 1,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -13,7 +13,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -41,16 +41,16 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 3,
    "metadata": {},
    "outputs": [],
    "source": [
     "IMAGENET = Path('data/imagenet/ILSVRC/Data/CLS-LOC/train')\n",
-    "proj_id = 'bwc_rc'\n",
-    "TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)\n",
+    "proj_id = 'colorize'\n",
+    "TENSORBOARD_PATH = Path('data/tensorboard/')\n",
     "\n",
-    "#gpath = IMAGENET.parent/(proj_id + '_gen_64.h5')\n",
-    "#dpath = IMAGENET.parent/(proj_id + '_critic_64.h5')\n",
+    "gpath = IMAGENET.parent/(proj_id + '_gen_192.h5')\n",
+    "dpath = IMAGENET.parent/(proj_id + '_critic_192.h5')\n",
     "\n",
     "c_lr=5e-4\n",
     "c_lrs = np.array([c_lr,c_lr,c_lr])\n",
@@ -75,7 +75,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -90,7 +90,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -100,7 +100,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -118,10 +118,10 @@
     "\n",
     "\n",
     "#unshock\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[128], bss=[12], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[128], bss=[8], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
     "    save_base_name=proj_id, c_lrs=c_lrs/30, g_lrs=g_lrs/30, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
     "\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[128,128], bss=[12,12], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=keep_pcts, \n",
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[128,128], bss=[8,8], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=keep_pcts, \n",
     "    save_base_name=proj_id, c_lrs=c_lrs/3, g_lrs=g_lrs/3, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=gen_freeze_tos))\n",
     "\n",
     "\n",
@@ -152,37 +152,7 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      " 24%|██▎       | 2280/9637 [21:43<1:01:13,  2.00it/s]\n",
-      "HingeLoss 1.920006275177002; RScore 0.9806321263313293; FScore 0.9393741488456726; GAddlLoss [2.86979]; Iters: 1140; GCost: 0.033145513385534286;\n",
-      " 24%|██▍       | 2300/9637 [21:53<1:00:31,  2.02it/s]\n",
-      "HingeLoss 1.7997559309005737; RScore 1.1856446266174316; FScore 0.6141113042831421; GAddlLoss [2.85213]; Iters: 1150; GCost: 0.47137489914894104;\n",
-      " 24%|██▍       | 2320/9637 [22:04<1:00:26,  2.02it/s]\n",
-      "HingeLoss 1.593624472618103; RScore 1.3018876314163208; FScore 0.29173681139945984; GAddlLoss [3.09584]; Iters: 1160; GCost: 0.35161638259887695;\n",
-      " 24%|██▍       | 2340/9637 [22:14<1:00:34,  2.01it/s]\n",
-      "HingeLoss 1.902028203010559; RScore 0.934773325920105; FScore 0.9672548770904541; GAddlLoss [3.21826]; Iters: 1170; GCost: 0.6368169188499451;\n",
-      " 24%|██▍       | 2360/9637 [22:25<1:00:15,  2.01it/s]\n",
-      "HingeLoss 1.8702163696289062; RScore 1.2792446613311768; FScore 0.5909717679023743; GAddlLoss [2.82861]; Iters: 1180; GCost: 0.5358245968818665;\n",
-      " 25%|██▍       | 2380/9637 [22:35<1:01:08,  1.98it/s]\n",
-      "HingeLoss 1.8615854978561401; RScore 1.105886459350586; FScore 0.7556990385055542; GAddlLoss [3.24053]; Iters: 1190; GCost: 0.2869108319282532;\n",
-      " 25%|██▍       | 2400/9637 [22:46<1:01:03,  1.98it/s]\n",
-      "HingeLoss 1.534956932067871; RScore 0.9493094086647034; FScore 0.585647463798523; GAddlLoss [2.96027]; Iters: 1200; GCost: 0.24809451401233673;\n",
-      " 25%|██▌       | 2420/9637 [23:00<1:04:34,  1.86it/s]\n",
-      "HingeLoss 2.0252113342285156; RScore 1.3209781646728516; FScore 0.7042331099510193; GAddlLoss [3.30014]; Iters: 1210; GCost: 0.35852766036987305;\n",
-      " 25%|██▌       | 2440/9637 [23:11<1:00:03,  2.00it/s]\n",
-      "HingeLoss 1.677668571472168; RScore 1.265999436378479; FScore 0.41166916489601135; GAddlLoss [3.20796]; Iters: 1220; GCost: -0.5658525228500366;\n",
-      " 26%|██▌       | 2460/9637 [23:22<59:15,  2.02it/s]  \n",
-      "HingeLoss 1.7294858694076538; RScore 1.066759467124939; FScore 0.6627264022827148; GAddlLoss [3.28718]; Iters: 1230; GCost: 0.17855745553970337;\n",
-      " 26%|██▌       | 2480/9637 [23:32<58:57,  2.02it/s]  \n",
-      "HingeLoss 1.7936912775039673; RScore 0.3893392086029053; FScore 1.404352068901062; GAddlLoss [3.24405]; Iters: 1240; GCost: 0.19427792727947235;\n",
-      " 26%|██▌       | 2489/9637 [23:37<1:06:16,  1.80it/s]"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "trainer.train(scheds=scheds)"
    ]

Plik diff jest za duży
+ 11 - 9
ColorizeVisualization.ipynb


+ 4 - 4
FinalVisualization.ipynb → ComboVisualization.ipynb

@@ -50,8 +50,8 @@
     "IMAGENET = Path('data/imagenet/ILSVRC/Data/CLS-LOC/train')\n",
     "IMAGENET_SMALL = IMAGENET/'n01440764'\n",
     "\n",
-    "colorizer_path = IMAGENET.parent/('bwc_rc_gen_128.h5')\n",
-    "defader_path = IMAGENET.parent/('defade_rc_gen_128.h5')\n",
+    "colorizer_path = IMAGENET.parent/('colorize_gen_192.h5')\n",
+    "defader_path = IMAGENET.parent/('colorize_gen_192.h5')\n",
     "\n",
     "default_sz=400\n",
     "torch.backends.cudnn.benchmark=True"
@@ -1847,7 +1847,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/GasPrices1939.jpg\", netG, md.val_ds, tfms=x_tfms)"
+    "vis.plot_transformed_image(\"test_images/GasPrices1939.jpg\", netG, md.val_ds, tfms=x_tfms, sz=520)"
    ]
   },
   {
@@ -1874,7 +1874,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/SchoolDance1956.jpg\", netG, md.val_ds, tfms=x_tfms)"
+    "vis.plot_transformed_image(\"test_images/SchoolDance1956.jpg\", netG, md.val_ds, tfms=x_tfms, sz=520)"
    ]
   },
   {

+ 8 - 56
DeFadeTraining.ipynb

@@ -46,14 +46,12 @@
    "outputs": [],
    "source": [
     "IMAGENET = Path('data/imagenet/ILSVRC/Data/CLS-LOC/train')\n",
-    "OPENIMAGES = Path('data/openimages')\n",
-    "CIFAR10 = Path('data/cifar10/train')\n",
     "\n",
-    "proj_id = 'defade_rc'\n",
-    "TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)\n",
+    "proj_id = 'defade'\n",
+    "TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id + '_cont2')\n",
     "\n",
-    "#gpath = IMAGENET.parent/(proj_id + '_gen_64.h5')\n",
-    "#dpath = IMAGENET.parent/(proj_id + '_critic_64.h5')\n",
+    "gpath = IMAGENET.parent/(proj_id + '_gen_192.h5')\n",
+    "dpath = IMAGENET.parent/(proj_id + '_critic_192.h5')\n",
     "\n",
     "c_lr=5e-4\n",
     "c_lrs = np.array([c_lr,c_lr,c_lr])\n",
@@ -123,10 +121,10 @@
     "\n",
     "\n",
     "#unshock\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[128], bss=[12], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[128], bss=[8], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
     "    save_base_name=proj_id, c_lrs=c_lrs/30, g_lrs=g_lrs/30, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
     "\n",
-    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[128,128], bss=[12,12], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=keep_pcts, \n",
+    "scheds.extend(GANTrainSchedule.generate_schedules(szs=[128,128], bss=[8,8], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=keep_pcts, \n",
     "    save_base_name=proj_id, c_lrs=c_lrs/3, g_lrs=g_lrs/3, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=gen_freeze_tos))\n",
     "\n",
     "\n",
@@ -150,60 +148,14 @@
     "    save_base_name=proj_id, c_lrs=c_lrs/16, g_lrs=g_lrs/16, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
     "\n",
     "scheds.extend(GANTrainSchedule.generate_schedules(szs=[256], bss=[2], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
-    "    save_base_name=proj_id, c_lrs=c_lrs/32, g_lrs=g_lrs/32, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[0]))"
+    "    save_base_name=proj_id, c_lrs=c_lrs/32, g_lrs=g_lrs/32, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[0]))\n"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      " 19%|█▉        | 1820/9603 [16:41<1:02:45,  2.07it/s]\n",
-      "HingeLoss 1.9387891292572021; RScore 1.3907694816589355; FScore 0.5480196475982666; GAddlLoss [2.99658]; Iters: 910; GCost: 0.04714561998844147;\n",
-      " 19%|█▉        | 1840/9603 [16:52<1:01:58,  2.09it/s]\n",
-      "HingeLoss 1.8926112651824951; RScore 1.554268479347229; FScore 0.3383428156375885; GAddlLoss [2.81035]; Iters: 920; GCost: -0.04119761660695076;\n",
-      " 19%|█▉        | 1860/9603 [17:02<1:02:21,  2.07it/s]\n",
-      "HingeLoss 1.9511566162109375; RScore 1.0059045553207397; FScore 0.9452521204948425; GAddlLoss [2.69479]; Iters: 930; GCost: 0.26951920986175537;\n",
-      " 20%|█▉        | 1880/9603 [17:12<1:02:15,  2.07it/s]\n",
-      "HingeLoss 2.095939874649048; RScore 1.069827914237976; FScore 1.0261119604110718; GAddlLoss [2.68969]; Iters: 940; GCost: 0.023471105843782425;\n",
-      " 20%|█▉        | 1900/9603 [17:22<1:01:42,  2.08it/s]\n",
-      "HingeLoss 1.884484887123108; RScore 0.8096139430999756; FScore 1.0748709440231323; GAddlLoss [2.65773]; Iters: 950; GCost: 0.20085659623146057;\n",
-      " 20%|█▉        | 1920/9603 [17:33<1:01:44,  2.07it/s]\n",
-      "HingeLoss 2.0832958221435547; RScore 1.148973822593689; FScore 0.9343219995498657; GAddlLoss [2.43384]; Iters: 960; GCost: -0.11352153867483139;\n",
-      " 20%|██        | 1940/9603 [17:47<1:20:45,  1.58it/s]\n",
-      "HingeLoss 1.8702123165130615; RScore 1.2256433963775635; FScore 0.6445689797401428; GAddlLoss [3.01765]; Iters: 970; GCost: -0.16625043749809265;\n",
-      " 20%|██        | 1960/9603 [17:58<1:02:36,  2.03it/s]\n",
-      "HingeLoss 1.84499990940094; RScore 0.8106263875961304; FScore 1.0343735218048096; GAddlLoss [2.60933]; Iters: 980; GCost: 0.18509630858898163;\n",
-      " 21%|██        | 1980/9603 [18:08<1:02:28,  2.03it/s]\n",
-      "HingeLoss 1.9964604377746582; RScore 0.8363078236579895; FScore 1.1601526737213135; GAddlLoss [2.49506]; Iters: 990; GCost: 0.26431044936180115;\n",
-      " 21%|██        | 2000/9603 [18:18<1:01:02,  2.08it/s]\n",
-      "HingeLoss 1.8580374717712402; RScore 1.0154190063476562; FScore 0.8426185250282288; GAddlLoss [3.00272]; Iters: 1000; GCost: 0.24961289763450623;\n",
-      " 21%|██        | 2020/9603 [19:09<1:02:34,  2.02it/s]\n",
-      "HingeLoss 2.065361261367798; RScore 0.9498858451843262; FScore 1.1154754161834717; GAddlLoss [2.75769]; Iters: 1010; GCost: -0.031428515911102295;\n",
-      " 21%|██        | 2040/9603 [19:19<1:01:24,  2.05it/s]\n",
-      "HingeLoss 1.7888972759246826; RScore 0.7048905491828918; FScore 1.0840067863464355; GAddlLoss [2.7787]; Iters: 1020; GCost: -0.028373755514621735;\n",
-      " 21%|██▏       | 2060/9603 [19:30<1:02:16,  2.02it/s]\n",
-      "HingeLoss 2.095205545425415; RScore 1.1978288888931274; FScore 0.8973767161369324; GAddlLoss [2.75315]; Iters: 1030; GCost: 0.18989555537700653;\n",
-      " 22%|██▏       | 2080/9603 [19:40<59:53,  2.09it/s]  \n",
-      "HingeLoss 2.0072576999664307; RScore 1.4106954336166382; FScore 0.5965622663497925; GAddlLoss [2.80513]; Iters: 1040; GCost: 0.1219414696097374;\n",
-      " 22%|██▏       | 2100/9603 [19:54<1:10:29,  1.77it/s]\n",
-      "HingeLoss 2.008185863494873; RScore 0.7701600193977356; FScore 1.2380259037017822; GAddlLoss [2.75081]; Iters: 1050; GCost: 0.28288665413856506;\n",
-      " 22%|██▏       | 2120/9603 [20:05<1:00:21,  2.07it/s]\n",
-      "HingeLoss 1.9713478088378906; RScore 0.8099216222763062; FScore 1.1614261865615845; GAddlLoss [2.60936]; Iters: 1060; GCost: -0.3364897668361664;\n",
-      " 22%|██▏       | 2140/9603 [20:15<59:40,  2.08it/s]  \n",
-      "HingeLoss 1.8121495246887207; RScore 1.5296125411987305; FScore 0.282537043094635; GAddlLoss [2.48548]; Iters: 1070; GCost: -0.668886125087738;\n",
-      " 22%|██▏       | 2160/9603 [20:26<1:00:03,  2.07it/s]\n",
-      "HingeLoss 2.076537609100342; RScore 1.172768235206604; FScore 0.9037694334983826; GAddlLoss [2.59395]; Iters: 1080; GCost: -0.06519972532987595;\n",
-      " 23%|██▎       | 2180/9603 [20:36<59:42,  2.07it/s]  \n",
-      "HingeLoss 2.165712356567383; RScore 1.2539243698120117; FScore 0.9117878675460815; GAddlLoss [2.74424]; Iters: 1090; GCost: 0.268320769071579;\n",
-      " 23%|██▎       | 2182/9603 [20:37<59:31,  2.08it/s]  "
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "trainer.train(scheds=scheds)"
    ]

+ 1 - 1
DeFadeVisualization.ipynb

@@ -48,7 +48,7 @@
    "source": [
     "IMAGENET = Path('data/imagenet/ILSVRC/Data/CLS-LOC/train')\n",
     "IMAGENET_SMALL = IMAGENET/'n01440764'\n",
-    "gpath = IMAGENET.parent/('defade_rc_gen_128.h5')\n",
+    "gpath = IMAGENET.parent/('defade_gen_192.h5')\n",
     "default_sz=400\n",
     "torch.backends.cudnn.benchmark=True"
    ]

+ 12 - 7
fasterai/generators.py

@@ -19,6 +19,9 @@ class GeneratorModule(ABC, nn.Module):
         for l in c:     set_trainable(l, False)
         for l in c[n:]: set_trainable(l, True)
 
+    def get_device(self):
+        next(self.parameters()).device
+
  
 class Unet34(GeneratorModule): 
     @staticmethod
@@ -37,6 +40,7 @@ class Unet34(GeneratorModule):
         bn=True
         sn=True
         self.rn, self.lr_cut = Unet34.get_pretrained_resnet_base()
+        self.relu = nn.ReLU()
         self.sfs = [SaveFeatures(self.rn[i]) for i in [2,4,5,6]]
 
         self.up1 = UnetBlock(512,256,512*nf_factor, sn=sn, leakyReLu=leakyReLu, bn=bn)
@@ -46,8 +50,8 @@ class Unet34(GeneratorModule):
         self.up5 = UpSampleBlock(256*nf_factor, 32*nf_factor, 2*scale, sn=sn, leakyReLu=leakyReLu, bn=bn) 
         self.out= nn.Sequential(ConvBlock(32*nf_factor, 3, ks=3, actn=False, bn=False, sn=sn), nn.Tanh())
 
-    #Gets around irritating inconsistent halving come from resnet
-    def _pad_xtensor(self, x, target):
+    #Gets around irritating inconsistent halving coming from resnet
+    def _pad(self, x, target):
         h = x.shape[2] 
         w = x.shape[3]
 
@@ -62,11 +66,12 @@ class Unet34(GeneratorModule):
         return x
            
     def forward(self, x_in: torch.Tensor):
-        x = F.relu(self.rn(x_in))
-        x = self.up1(x, self._pad_xtensor(self.sfs[3].features, x))
-        x = self.up2(x, self._pad_xtensor(self.sfs[2].features, x))
-        x = self.up3(x, self._pad_xtensor(self.sfs[1].features, x))
-        x = self.up4(x, self._pad_xtensor(self.sfs[0].features, x))
+        x = self.rn(x_in)
+        x = self.relu(x)
+        x = self.up1(x, self._pad(self.sfs[3].features, x))
+        x = self.up2(x, self._pad(self.sfs[2].features, x))
+        x = self.up3(x, self._pad(self.sfs[1].features, x))
+        x = self.up4(x, self._pad(self.sfs[0].features, x))
         x = self.up5(x)
         x = self.out(x)
         return x

Niektóre pliki nie zostały wyświetlone z powodu dużej ilości zmienionych plików