|
@@ -34,7 +34,7 @@
|
|
|
"from pathlib import Path\n",
|
|
|
"from itertools import repeat\n",
|
|
|
"import tensorboardX\n",
|
|
|
- "torch.cuda.set_device(3)\n",
|
|
|
+ "torch.cuda.set_device(0)\n",
|
|
|
"plt.style.use('dark_background')\n",
|
|
|
"torch.backends.cudnn.benchmark=True\n"
|
|
|
]
|
|
@@ -46,19 +46,21 @@
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
"IMAGENET = Path('data/imagenet/ILSVRC/Data/CLS-LOC/train')\n",
|
|
|
- "proj_id = 'defade'\n",
|
|
|
+ "proj_id = 'bwdefade'\n",
|
|
|
"TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)\n",
|
|
|
- "gpath = IMAGENET.parent/(proj_id + '_gen_192.h5')\n",
|
|
|
- "dpath = IMAGENET.parent/(proj_id + '_critic_192.h5')\n",
|
|
|
- "c_lr=2e-4\n",
|
|
|
+ "gpath = IMAGENET.parent/(proj_id + '_gen_64.h5')\n",
|
|
|
+ "dpath = IMAGENET.parent/(proj_id + '_critic_64.h5')\n",
|
|
|
+ "c_lr=5e-4\n",
|
|
|
"c_lrs = np.array([c_lr,c_lr,c_lr])\n",
|
|
|
- "g_lr=c_lr/4\n",
|
|
|
- "g_lrs = np.array([g_lr/1000,g_lr/100,g_lr])\n",
|
|
|
+ "\n",
|
|
|
+ "g_lr=c_lr/5\n",
|
|
|
+ "g_lrs = np.array([g_lr/100,g_lr/10,g_lr])\n",
|
|
|
+ "\n",
|
|
|
"keep_pcts=[0.25,0.25]\n",
|
|
|
"gen_freeze_tos=[-1,0]\n",
|
|
|
- "lrs_unfreeze_factor=1.0\n",
|
|
|
+ "lrs_unfreeze_factor=0.05\n",
|
|
|
"x_tfms = [RandomLighting(0.5, 0.5)]\n",
|
|
|
- "extra_aug_tfms = []\n",
|
|
|
+ "extra_aug_tfms = [BlackAndWhiteTransform(tfm_y=TfmType.PIXEL)]\n",
|
|
|
"torch.backends.cudnn.benchmark=True"
|
|
|
]
|
|
|
},
|
|
@@ -79,7 +81,7 @@
|
|
|
"#netGVis = ModelVisualizationHook(TENSORBOARD_PATH, netG, 'netG')\n",
|
|
|
"#load_model(netG, gpath)\n",
|
|
|
"\n",
|
|
|
- "netD = DCCritic(ni=3, nf=512).cuda()\n",
|
|
|
+ "netD = DCCritic(ni=3, nf=384).cuda()\n",
|
|
|
"#netDVis = ModelVisualizationHook(TENSORBOARD_PATH, netD, 'netD')\n",
|
|
|
"#load_model(netD, dpath)"
|
|
|
]
|
|
@@ -102,64 +104,83 @@
|
|
|
"source": [
|
|
|
"scheds=[]\n",
|
|
|
"\n",
|
|
|
- "scheds.extend(GANTrainSchedule.generate_schedules(szs=[64,64], bss=[32,32], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms,keep_pcts=keep_pcts, \n",
|
|
|
+ "scheds.extend(GANTrainSchedule.generate_schedules(szs=[64,64], bss=[128,128], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms,keep_pcts=[1.0,1.0], \n",
|
|
|
" save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=gen_freeze_tos))\n",
|
|
|
"\n",
|
|
|
+ "c_lrs=c_lrs/2\n",
|
|
|
+ "g_lrs=g_lrs/2\n",
|
|
|
+ "\n",
|
|
|
"#unshock\n",
|
|
|
- "scheds.extend(GANTrainSchedule.generate_schedules(szs=[96], bss=[16], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
|
|
|
- " save_base_name=proj_id, c_lrs=c_lrs/20, g_lrs=g_lrs/20, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
|
|
|
+ "scheds.extend(GANTrainSchedule.generate_schedules(szs=[96], bss=[64], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
|
|
|
+ " save_base_name=proj_id, c_lrs=c_lrs/10, g_lrs=g_lrs/10, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
|
|
|
"\n",
|
|
|
- "scheds.extend(GANTrainSchedule.generate_schedules(szs=[96,96], bss=[16,16], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=keep_pcts, \n",
|
|
|
- " save_base_name=proj_id, c_lrs=c_lrs/2, g_lrs=g_lrs/2, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=gen_freeze_tos))\n",
|
|
|
+ "scheds.extend(GANTrainSchedule.generate_schedules(szs=[96,96], bss=[64,64], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=keep_pcts, \n",
|
|
|
+ " save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=gen_freeze_tos))\n",
|
|
|
"\n",
|
|
|
+ "c_lrs=c_lrs/2\n",
|
|
|
+ "g_lrs=g_lrs/2\n",
|
|
|
"\n",
|
|
|
"#unshock\n",
|
|
|
- "scheds.extend(GANTrainSchedule.generate_schedules(szs=[128], bss=[8], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
|
|
|
- " save_base_name=proj_id, c_lrs=c_lrs/30, g_lrs=g_lrs/30, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
|
|
|
+ "scheds.extend(GANTrainSchedule.generate_schedules(szs=[128], bss=[32], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
|
|
|
+ " save_base_name=proj_id, c_lrs=c_lrs/10, g_lrs=g_lrs/10, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
|
|
|
+ "\n",
|
|
|
+ "scheds.extend(GANTrainSchedule.generate_schedules(szs=[128,128], bss=[32,32], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=keep_pcts, \n",
|
|
|
+ " save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=gen_freeze_tos))\n",
|
|
|
"\n",
|
|
|
- "scheds.extend(GANTrainSchedule.generate_schedules(szs=[128,128], bss=[8,8], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=keep_pcts, \n",
|
|
|
- " save_base_name=proj_id, c_lrs=c_lrs/3, g_lrs=g_lrs/3, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=gen_freeze_tos))\n",
|
|
|
+ "\n",
|
|
|
+ "c_lrs=c_lrs/2\n",
|
|
|
+ "g_lrs=g_lrs/2\n",
|
|
|
"\n",
|
|
|
"#unshock\n",
|
|
|
- "scheds.extend(GANTrainSchedule.generate_schedules(szs=[160], bss=[5], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
|
|
|
- " save_base_name=proj_id, c_lrs=c_lrs/80, g_lrs=g_lrs/80, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
|
|
|
+ "scheds.extend(GANTrainSchedule.generate_schedules(szs=[160], bss=[20], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
|
|
|
+ " save_base_name=proj_id, c_lrs=c_lrs/10, g_lrs=g_lrs/10, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
|
|
|
+ "\n",
|
|
|
+ "scheds.extend(GANTrainSchedule.generate_schedules(szs=[160], bss=[20], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
|
|
|
+ " save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
|
|
|
+ "\n",
|
|
|
+ "scheds.extend(GANTrainSchedule.generate_schedules(szs=[160], bss=[20], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
|
|
|
+ " save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[0]))\n",
|
|
|
"\n",
|
|
|
- "scheds.extend(GANTrainSchedule.generate_schedules(szs=[160], bss=[5], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
|
|
|
- " save_base_name=proj_id, c_lrs=c_lrs/8, g_lrs=g_lrs/8, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
|
|
|
"\n",
|
|
|
- "scheds.extend(GANTrainSchedule.generate_schedules(szs=[160], bss=[5], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
|
|
|
- " save_base_name=proj_id, c_lrs=c_lrs/16, g_lrs=g_lrs/16, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[0]))\n",
|
|
|
+ "c_lrs=c_lrs/2\n",
|
|
|
+ "g_lrs=g_lrs/2\n",
|
|
|
"\n",
|
|
|
"#unshock\n",
|
|
|
- "scheds.extend(GANTrainSchedule.generate_schedules(szs=[192], bss=[4], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
|
|
|
- " save_base_name=proj_id, c_lrs=c_lrs/80, g_lrs=g_lrs/80, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
|
|
|
+ "scheds.extend(GANTrainSchedule.generate_schedules(szs=[192], bss=[12], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
|
|
|
+ " save_base_name=proj_id, c_lrs=c_lrs/10, g_lrs=g_lrs/10, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
|
|
|
+ "\n",
|
|
|
+ "scheds.extend(GANTrainSchedule.generate_schedules(szs=[192], bss=[12], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
|
|
|
+ " save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
|
|
|
+ "\n",
|
|
|
+ "scheds.extend(GANTrainSchedule.generate_schedules(szs=[192], bss=[12], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
|
|
|
+ " save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[0]))\n",
|
|
|
"\n",
|
|
|
- "scheds.extend(GANTrainSchedule.generate_schedules(szs=[192], bss=[3], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
|
|
|
- " save_base_name=proj_id, c_lrs=c_lrs/8, g_lrs=g_lrs/8, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
|
|
|
"\n",
|
|
|
- "scheds.extend(GANTrainSchedule.generate_schedules(szs=[192], bss=[3], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
|
|
|
- " save_base_name=proj_id, c_lrs=c_lrs/16, g_lrs=g_lrs/16, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[0]))\n",
|
|
|
+ "c_lrs=c_lrs/2\n",
|
|
|
+ "g_lrs=g_lrs/2\n",
|
|
|
"\n",
|
|
|
"#unshock\n",
|
|
|
- "scheds.extend(GANTrainSchedule.generate_schedules(szs=[224], bss=[2], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
|
|
|
- " save_base_name=proj_id, c_lrs=c_lrs/160, g_lrs=g_lrs/160, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
|
|
|
+ "scheds.extend(GANTrainSchedule.generate_schedules(szs=[224], bss=[8], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
|
|
|
+ " save_base_name=proj_id, c_lrs=c_lrs/10, g_lrs=g_lrs/10, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
|
|
|
"\n",
|
|
|
- "scheds.extend(GANTrainSchedule.generate_schedules(szs=[224], bss=[2], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
|
|
|
- " save_base_name=proj_id, c_lrs=c_lrs/16, g_lrs=g_lrs/16, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
|
|
|
+ "scheds.extend(GANTrainSchedule.generate_schedules(szs=[224], bss=[8], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
|
|
|
+ " save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
|
|
|
"\n",
|
|
|
- "scheds.extend(GANTrainSchedule.generate_schedules(szs=[224], bss=[2], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
|
|
|
- " save_base_name=proj_id, c_lrs=c_lrs/32, g_lrs=g_lrs/32, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[0]))\n",
|
|
|
+ "scheds.extend(GANTrainSchedule.generate_schedules(szs=[224], bss=[8], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
|
|
|
+ " save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[0]))\n",
|
|
|
"\n",
|
|
|
+ "c_lrs=c_lrs/2\n",
|
|
|
+ "g_lrs=g_lrs/2\n",
|
|
|
"\n",
|
|
|
"#unshock\n",
|
|
|
- "scheds.extend(GANTrainSchedule.generate_schedules(szs=[256], bss=[2], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
|
|
|
- " save_base_name=proj_id, c_lrs=c_lrs/160, g_lrs=g_lrs/160, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
|
|
|
+ "scheds.extend(GANTrainSchedule.generate_schedules(szs=[256], bss=[6], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
|
|
|
+ " save_base_name=proj_id, c_lrs=c_lrs/10, g_lrs=g_lrs/10, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
|
|
|
"\n",
|
|
|
- "scheds.extend(GANTrainSchedule.generate_schedules(szs=[256], bss=[2], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
|
|
|
- " save_base_name=proj_id, c_lrs=c_lrs/16, g_lrs=g_lrs/16, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
|
|
|
+ "scheds.extend(GANTrainSchedule.generate_schedules(szs=[256], bss=[6], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
|
|
|
+ " save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
|
|
|
"\n",
|
|
|
- "scheds.extend(GANTrainSchedule.generate_schedules(szs=[256], bss=[2], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
|
|
|
- " save_base_name=proj_id, c_lrs=c_lrs/32, g_lrs=g_lrs/32, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[0]))\n"
|
|
|
+ "scheds.extend(GANTrainSchedule.generate_schedules(szs=[256], bss=[6], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
|
|
|
+ " save_base_name=proj_id, c_lrs=c_lrs, g_lrs=g_lrs, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[0]))"
|
|
|
]
|
|
|
},
|
|
|
{
|