|
@@ -46,14 +46,12 @@
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
"IMAGENET = Path('data/imagenet/ILSVRC/Data/CLS-LOC/train')\n",
|
|
|
- "OPENIMAGES = Path('data/openimages')\n",
|
|
|
- "CIFAR10 = Path('data/cifar10/train')\n",
|
|
|
"\n",
|
|
|
- "proj_id = 'defade_rc'\n",
|
|
|
- "TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id)\n",
|
|
|
+ "proj_id = 'defade'\n",
|
|
|
+ "TENSORBOARD_PATH = Path('data/tensorboard/' + proj_id + '_cont2')\n",
|
|
|
"\n",
|
|
|
- "#gpath = IMAGENET.parent/(proj_id + '_gen_64.h5')\n",
|
|
|
- "#dpath = IMAGENET.parent/(proj_id + '_critic_64.h5')\n",
|
|
|
+ "gpath = IMAGENET.parent/(proj_id + '_gen_192.h5')\n",
|
|
|
+ "dpath = IMAGENET.parent/(proj_id + '_critic_192.h5')\n",
|
|
|
"\n",
|
|
|
"c_lr=5e-4\n",
|
|
|
"c_lrs = np.array([c_lr,c_lr,c_lr])\n",
|
|
@@ -123,10 +121,10 @@
|
|
|
"\n",
|
|
|
"\n",
|
|
|
"#unshock\n",
|
|
|
- "scheds.extend(GANTrainSchedule.generate_schedules(szs=[128], bss=[12], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
|
|
|
+ "scheds.extend(GANTrainSchedule.generate_schedules(szs=[128], bss=[8], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.1], \n",
|
|
|
" save_base_name=proj_id, c_lrs=c_lrs/30, g_lrs=g_lrs/30, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
|
|
|
"\n",
|
|
|
- "scheds.extend(GANTrainSchedule.generate_schedules(szs=[128,128], bss=[12,12], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=keep_pcts, \n",
|
|
|
+ "scheds.extend(GANTrainSchedule.generate_schedules(szs=[128,128], bss=[8,8], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=keep_pcts, \n",
|
|
|
" save_base_name=proj_id, c_lrs=c_lrs/3, g_lrs=g_lrs/3, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=gen_freeze_tos))\n",
|
|
|
"\n",
|
|
|
"\n",
|
|
@@ -150,60 +148,14 @@
|
|
|
" save_base_name=proj_id, c_lrs=c_lrs/16, g_lrs=g_lrs/16, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[-1]))\n",
|
|
|
"\n",
|
|
|
"scheds.extend(GANTrainSchedule.generate_schedules(szs=[256], bss=[2], path=IMAGENET, x_tfms=x_tfms, extra_aug_tfms=extra_aug_tfms, keep_pcts=[0.25], \n",
|
|
|
- " save_base_name=proj_id, c_lrs=c_lrs/32, g_lrs=g_lrs/32, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[0]))"
|
|
|
+ " save_base_name=proj_id, c_lrs=c_lrs/32, g_lrs=g_lrs/32, lrs_unfreeze_factor=lrs_unfreeze_factor, gen_freeze_tos=[0]))\n"
|
|
|
]
|
|
|
},
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
"execution_count": null,
|
|
|
"metadata": {},
|
|
|
- "outputs": [
|
|
|
- {
|
|
|
- "name": "stdout",
|
|
|
- "output_type": "stream",
|
|
|
- "text": [
|
|
|
- " 19%|█▉ | 1820/9603 [16:41<1:02:45, 2.07it/s]\n",
|
|
|
- "HingeLoss 1.9387891292572021; RScore 1.3907694816589355; FScore 0.5480196475982666; GAddlLoss [2.99658]; Iters: 910; GCost: 0.04714561998844147;\n",
|
|
|
- " 19%|█▉ | 1840/9603 [16:52<1:01:58, 2.09it/s]\n",
|
|
|
- "HingeLoss 1.8926112651824951; RScore 1.554268479347229; FScore 0.3383428156375885; GAddlLoss [2.81035]; Iters: 920; GCost: -0.04119761660695076;\n",
|
|
|
- " 19%|█▉ | 1860/9603 [17:02<1:02:21, 2.07it/s]\n",
|
|
|
- "HingeLoss 1.9511566162109375; RScore 1.0059045553207397; FScore 0.9452521204948425; GAddlLoss [2.69479]; Iters: 930; GCost: 0.26951920986175537;\n",
|
|
|
- " 20%|█▉ | 1880/9603 [17:12<1:02:15, 2.07it/s]\n",
|
|
|
- "HingeLoss 2.095939874649048; RScore 1.069827914237976; FScore 1.0261119604110718; GAddlLoss [2.68969]; Iters: 940; GCost: 0.023471105843782425;\n",
|
|
|
- " 20%|█▉ | 1900/9603 [17:22<1:01:42, 2.08it/s]\n",
|
|
|
- "HingeLoss 1.884484887123108; RScore 0.8096139430999756; FScore 1.0748709440231323; GAddlLoss [2.65773]; Iters: 950; GCost: 0.20085659623146057;\n",
|
|
|
- " 20%|█▉ | 1920/9603 [17:33<1:01:44, 2.07it/s]\n",
|
|
|
- "HingeLoss 2.0832958221435547; RScore 1.148973822593689; FScore 0.9343219995498657; GAddlLoss [2.43384]; Iters: 960; GCost: -0.11352153867483139;\n",
|
|
|
- " 20%|██ | 1940/9603 [17:47<1:20:45, 1.58it/s]\n",
|
|
|
- "HingeLoss 1.8702123165130615; RScore 1.2256433963775635; FScore 0.6445689797401428; GAddlLoss [3.01765]; Iters: 970; GCost: -0.16625043749809265;\n",
|
|
|
- " 20%|██ | 1960/9603 [17:58<1:02:36, 2.03it/s]\n",
|
|
|
- "HingeLoss 1.84499990940094; RScore 0.8106263875961304; FScore 1.0343735218048096; GAddlLoss [2.60933]; Iters: 980; GCost: 0.18509630858898163;\n",
|
|
|
- " 21%|██ | 1980/9603 [18:08<1:02:28, 2.03it/s]\n",
|
|
|
- "HingeLoss 1.9964604377746582; RScore 0.8363078236579895; FScore 1.1601526737213135; GAddlLoss [2.49506]; Iters: 990; GCost: 0.26431044936180115;\n",
|
|
|
- " 21%|██ | 2000/9603 [18:18<1:01:02, 2.08it/s]\n",
|
|
|
- "HingeLoss 1.8580374717712402; RScore 1.0154190063476562; FScore 0.8426185250282288; GAddlLoss [3.00272]; Iters: 1000; GCost: 0.24961289763450623;\n",
|
|
|
- " 21%|██ | 2020/9603 [19:09<1:02:34, 2.02it/s]\n",
|
|
|
- "HingeLoss 2.065361261367798; RScore 0.9498858451843262; FScore 1.1154754161834717; GAddlLoss [2.75769]; Iters: 1010; GCost: -0.031428515911102295;\n",
|
|
|
- " 21%|██ | 2040/9603 [19:19<1:01:24, 2.05it/s]\n",
|
|
|
- "HingeLoss 1.7888972759246826; RScore 0.7048905491828918; FScore 1.0840067863464355; GAddlLoss [2.7787]; Iters: 1020; GCost: -0.028373755514621735;\n",
|
|
|
- " 21%|██▏ | 2060/9603 [19:30<1:02:16, 2.02it/s]\n",
|
|
|
- "HingeLoss 2.095205545425415; RScore 1.1978288888931274; FScore 0.8973767161369324; GAddlLoss [2.75315]; Iters: 1030; GCost: 0.18989555537700653;\n",
|
|
|
- " 22%|██▏ | 2080/9603 [19:40<59:53, 2.09it/s] \n",
|
|
|
- "HingeLoss 2.0072576999664307; RScore 1.4106954336166382; FScore 0.5965622663497925; GAddlLoss [2.80513]; Iters: 1040; GCost: 0.1219414696097374;\n",
|
|
|
- " 22%|██▏ | 2100/9603 [19:54<1:10:29, 1.77it/s]\n",
|
|
|
- "HingeLoss 2.008185863494873; RScore 0.7701600193977356; FScore 1.2380259037017822; GAddlLoss [2.75081]; Iters: 1050; GCost: 0.28288665413856506;\n",
|
|
|
- " 22%|██▏ | 2120/9603 [20:05<1:00:21, 2.07it/s]\n",
|
|
|
- "HingeLoss 1.9713478088378906; RScore 0.8099216222763062; FScore 1.1614261865615845; GAddlLoss [2.60936]; Iters: 1060; GCost: -0.3364897668361664;\n",
|
|
|
- " 22%|██▏ | 2140/9603 [20:15<59:40, 2.08it/s] \n",
|
|
|
- "HingeLoss 1.8121495246887207; RScore 1.5296125411987305; FScore 0.282537043094635; GAddlLoss [2.48548]; Iters: 1070; GCost: -0.668886125087738;\n",
|
|
|
- " 22%|██▏ | 2160/9603 [20:26<1:00:03, 2.07it/s]\n",
|
|
|
- "HingeLoss 2.076537609100342; RScore 1.172768235206604; FScore 0.9037694334983826; GAddlLoss [2.59395]; Iters: 1080; GCost: -0.06519972532987595;\n",
|
|
|
- " 23%|██▎ | 2180/9603 [20:36<59:42, 2.07it/s] \n",
|
|
|
- "HingeLoss 2.165712356567383; RScore 1.2539243698120117; FScore 0.9117878675460815; GAddlLoss [2.74424]; Iters: 1090; GCost: 0.268320769071579;\n",
|
|
|
- " 23%|██▎ | 2182/9603 [20:37<59:31, 2.08it/s] "
|
|
|
- ]
|
|
|
- }
|
|
|
- ],
|
|
|
+ "outputs": [],
|
|
|
"source": [
|
|
|
"trainer.train(scheds=scheds)"
|
|
|
]
|