Переглянути джерело

Miscellaneous stuff that I just didn't check in yet. Derp

Jason Antic 6 років тому
батько
коміт
e871496276

+ 8 - 457
.gitignore

@@ -1,27 +1,9 @@
-
-.ipynb_checkpoints/ColorizeTraining-checkpoint.ipynb
-.ipynb_checkpoints/ColorizeVisualization-checkpoint.ipynb
-.ipynb_checkpoints/DeFadeTraining-checkpoint.ipynb
-.ipynb_checkpoints/DeFadeVisualization-checkpoint.ipynb
-.ipynb_checkpoints/FinalVisualization-checkpoint.ipynb
 data
 fasterai/.ipynb_checkpoints/images-checkpoint.py
 fasterai/.ipynb_checkpoints/loss-checkpoint.py
 fasterai/.ipynb_checkpoints/transforms-checkpoint.py
 fasterai/.ipynb_checkpoints/visualize-checkpoint.py
-fasterai/__pycache__/wgan.cpython-36.pyc
-fasterai/__pycache__/visualize.cpython-36.pyc
-fasterai/__pycache__/transforms.cpython-36.pyc
-fasterai/__pycache__/training.cpython-36.pyc
-fasterai/__pycache__/structured.cpython-36.pyc
-fasterai/__pycache__/modules.cpython-36.pyc
-fasterai/__pycache__/models.cpython-36.pyc
-fasterai/__pycache__/loss.cpython-36.pyc
-fasterai/__pycache__/images.cpython-36.pyc
-fasterai/__pycache__/generators.cpython-36.pyc
-fasterai/__pycache__/files.cpython-36.pyc
-fasterai/__pycache__/dataset.cpython-36.pyc
-fasterai/__pycache__/callbacks.cpython-36.pyc
+fasterai/__pycache__/*.pyc
 fasterai/SymbolicLinks.sh
 SymbolicLinks.sh
 .ipynb_checkpoints/README-checkpoint.md
@@ -38,381 +20,10 @@ ColorizeTraining4.ipynb
 herp.jpg
 result_images/.ipynb_checkpoints/1864UnionSoldier-checkpoint.jpg
 test.py
-result_images/1850Geography.jpg
-result_images/1860Girls.jpg
-result_images/1860sSamauris.png
-result_images/1864UnionSoldier.jpg
-result_images/1867MusicianConstantinople.jpg
-result_images/1870Girl.jpg
-result_images/1870sSphinx.jpg
-result_images/1874Mexico.png
-result_images/1875Olds.jpg
-result_images/1880Paris.jpg
-result_images/1880sBrooklynBridge.jpg
-result_images/1888Slum.jpg
-result_images/1890BostonHospital.jpg
-result_images/1890CliffHouseSF.jpg
-result_images/1890sMedStudents.png
-result_images/1890sPingPong.jpg
-result_images/1890sShoeShopOhio.jpg
-result_images/1890sTouristsEgypt.png
-result_images/1890Surfer.png
-result_images/1892WaterLillies.jpg
-result_images/1895BikeMaidens.jpg
-result_images/1896NewsBoyGirl.jpg
-result_images/1897BlindmansBluff.jpg
-result_images/1899NycBlizzard.jpg
-result_images/1899SodaFountain.jpg
-result_images/1900sDaytonaBeach.png
-result_images/1900sSaloon.jpg
-result_images/1901Electrophone.jpg
-result_images/1907Cowboys.jpg
-result_images/1908FamilyPhoto.jpg
-result_images/1909Chicago.jpg
-result_images/1909ParisFirstFemaleTaxisDriver.jpg
-result_images/1910Bike.jpg
-result_images/1910Racket.png
-result_images/1916Sweeden.jpg
-result_images/1920CobblersShopLondon.jpg
-result_images/1920sDancing.jpg
-result_images/1920sFamilyPhoto.jpg
-result_images/1920sFarmKid.jpg
-result_images/1920sGuadalope.jpg
-result_images/1925Girl.jpg
-result_images/1929LondonOverFleetSt.jpg
-result_images/1930sGeorgia.jpg
-result_images/1933RockefellerCenter.jpg
-result_images/1938Reading.jpg
-result_images/1940sBeerRiver.jpg
-result_images/1946Wedding.jpg
-result_images/1948CarsGrandma.jpg
-result_images/20sWoman.jpg
-result_images/40sCouple.jpg
-result_images/abe.jpg
-result_images/AccordianKid1900Paris.jpg
-result_images/Agamemnon1919.jpg
-result_images/AirmanDad.jpg
-result_images/airmen1943.jpg
-result_images/AnselAdamsAdobe.jpg
-result_images/AnselAdamsBuildings.jpg
-result_images/AnselAdamsChurch.jpg
-result_images/AnselAdamsWhiteChurch.jpg
-result_images/AnselAdamsYosemite.jpg
-result_images/AppalachianLoggers1901.jpg
-result_images/Apsaroke1908.png
-result_images/ArkansasCowboys1880s.jpg
-result_images/AthleticClubParis1913.jpg
-result_images/AustriaHungaryWomen1890s.jpg
-result_images/BabyBigBoots.jpg
-result_images/Ballet1890Russia.jpg
-result_images/BellyLaughWWI.jpg
-result_images/bicycles.jpg
-result_images/BigManTavern1908NYC.jpg
-result_images/BombedLibraryLondon1940.jpg
-result_images/Boston1937.jpg
-result_images/BoulevardDuTemple1838.jpg
-result_images/BoxedBedEarly1900s.jpg
-result_images/BreadDelivery1920sIreland.jpg
-result_images/BritishDispatchRider.jpg
-result_images/BritishSlum.jpg
-result_images/BritishTeaBombay1890s.png
-result_images/brooklyn_girls_1940s.jpg
-result_images/BumperCarsParis1930.jpg
-result_images/CafeParis1928.jpg
-result_images/CafeTerrace1925Paris.jpg
-result_images/CalcuttaPoliceman1920.jpg
-result_images/camera_man.jpg
-result_images/Cars1890sIreland.jpg
-result_images/CatWash1931.jpg
-result_images/Chief.jpg
-result_images/ChinaOpiumc1880.jpg
-result_images/civil-war_2.jpg
-result_images/civil_war.jpg
-result_images/civil_war_3.jpg
-result_images/civil_war_4.jpg
-result_images/ClassDivide1930sBrittain.jpg
-result_images/CoalDeliveryParis1915.jpg
-result_images/Cork1905.jpg
-result_images/CorkKids1910.jpg
-result_images/CottonMill.jpg
-result_images/covered-wagons-traveling.jpg
-result_images/CricketLondon1930.jpg
-result_images/Deadwood1860s.png
-result_images/DeepSeaDiver1915.png
-result_images/Depression.jpg
-result_images/Dolores1920s.jpg
-result_images/Donegal1907Yarn.jpg
-result_images/dorothea-lange.jpg
-result_images/dorothea_lange_2.jpg
-result_images/DriveThroughGiantTree.jpg
-result_images/dustbowl_1.jpg
-result_images/dustbowl_2.jpg
-result_images/dustbowl_5.jpg
-result_images/dustbowl_people.jpg
-result_images/dustbowl_sd.jpg
-result_images/DutchBabyCoupleEllis.jpg
-result_images/EastEndLondonStreetKids1901.jpg
-result_images/EasterNyc1911.jpg
-result_images/Eddie-Adams.jpg
-result_images/Edinburgh1920s.jpg
-result_images/egypt-1.jpg
-result_images/egypt-2.jpg
-result_images/EgyptColosus.jpg
-result_images/EgyptianWomenLate1800s.jpg
-result_images/einstein_beach.jpg
-result_images/einstein_portrait.jpg
-result_images/ElectricScooter1915.jpeg
-result_images/ElephantLondon1934.png
-result_images/EmpireState1930.jpg
-result_images/Evelyn_Nesbit.jpg
-result_images/FadedDelores.PNG
-result_images/FadedDutchBabies.PNG
-result_images/FadedOvermiller.PNG
-result_images/FadedRacket.PNG
-result_images/FadedSphynx.PNG
-result_images/FarmWomen1895.jpg
-result_images/FatMenClub.jpg
-result_images/FatMensShop.jpg
-result_images/FreeportIL.jpg
-result_images/FreightTrainTeens1934.jpg
-result_images/FrenchVillage1950s.jpg
-result_images/GalwayIreland1902.jpg
-result_images/GasPrices1939.jpg
-result_images/GreatGrandparentsIrelandEarly1900s.jpg
-result_images/Greece1911.jpg
-result_images/GreekImmigrants1905.jpg
-result_images/HalloweenEarly1900s.jpg
-result_images/Harlem1932.jpg
-result_images/HarrodsLondon1920.jpg
-result_images/HealingTree.jpg
-result_images/HelenKeller.jpg
-result_images/helmut_newton-.jpg
-result_images/hemmingway.jpg
-result_images/Hemmingway2.jpg
-result_images/HerbSeller1899Paris.jpg
-result_images/HomeIreland1924.jpg
-result_images/HoovervilleSeattle1932.jpg
-result_images/HPLabelleOfficeMontreal.jpg
-result_images/HydeParkLondon1920s.jpg
-result_images/IceManLondon1919.jpg
-result_images/InuitWoman1903.png
-result_images/IrishLate1800s.jpg
-result_images/jacksonville.jpg
-result_images/Jane_Addams.jpg
-result_images/JerseyShore1905.jpg
-result_images/JudyGarland.jpeg
-result_images/Kabuki1870s.png
-result_images/KidCage1930s.png
-result_images/kids_pit.jpg
-result_images/Killarney1910.jpg
-result_images/last_samurai.jpg
-result_images/Late1800sNative.jpg
-result_images/LeBonMarcheParis1875.jpg
-result_images/LewisTomalinLondon1895.png
-result_images/LibraryOfCongress1910.jpg
-result_images/Lisbon1919.jpg
-result_images/LittleAirplane1934.jpg
-result_images/LivingRoom1920Sweeden.jpg
-result_images/Locomotive1880s.jpg
-result_images/London1850Coach.jpg
-result_images/London1900EastEndBlacksmith.jpg
-result_images/London1918WartimeClothesManufacture.jpg
-result_images/London1930sCheetah.jpg
-result_images/London1937.png
-result_images/LondonFireBrigadeMember1926.jpg
-result_images/LondonGarbageTruck1910.jpg
-result_images/LondonHeatWave1935.png
-result_images/LondonKidsEarly1900s.jpg
-result_images/LondonRailwayWork1931.jpg
-result_images/LondonSheep1920s.png
-result_images/LondonsSmallestShop1900.jpg
-result_images/LondonStreetDoctor1877.png
-result_images/LondonStreets1900.jpg
-result_images/LondonUnderground1860.jpg
-result_images/MadisonSquare1900.jpg
-result_images/MaioreWoman1895NZ.jpg
-result_images/ManPile.jpg
-result_images/marilyn_portrait.jpg
-result_images/marilyn_woods.jpg
-result_images/marktwain.jpg
-result_images/MementoMori1865.jpg
-result_images/MetropolitanDistrictRailway1869London.jpg
-result_images/Mid1800sSisters.jpg
-result_images/migrant_mother.jpg
-result_images/Mormons1870s.jpg
-result_images/MuffinManlLondon1910.jpg
-result_images/MuseauNacionalDosCoches.jpg
-result_images/NativeAmericans.jpg
-result_images/NativeCouple1912.jpg
-result_images/NativeWoman1926.jpg
-result_images/NewspaperCivilWar1863.jpg
-result_images/NewZealand1860s.jpg
-result_images/NorwegianBride1920s.jpg
-result_images/NYStreetClean1906.jpg
-result_images/opium.jpg
-result_images/OregonTrail1870s.jpg
-result_images/overmiller.jpg
-result_images/PaddingtonStationLondon1907.jpg
-result_images/PaddysMarketCork1900s.jpg
-result_images/paperboy.jpg
-result_images/Paris1899StreetDig.jpg
-result_images/Paris1920Cart.jpg
-result_images/Paris1926.jpg
-result_images/ParisLadies1910.jpg
-result_images/ParisLadies1930s.jpg
-result_images/ParisLate1800s.jpg
-result_images/ParisWomenFurs1920s.jpg
-result_images/PeddlerParis1899.jpg
-result_images/PetDucks1927.jpg
-result_images/PicadillyLate1800s.jpg
-result_images/PiggyBackRide.jpg
-result_images/pinkerton.jpg
-result_images/PlanesManhattan1931.jpg
-result_images/PostOfficeVermont1914.png
-result_images/poverty.jpg
-result_images/PuppyGify.jpg
-result_images/redwood_lumberjacks.jpg
-result_images/RepBrennanRadio1922.jpg
-result_images/rgs.jpg
-result_images/RossCorbettHouseCork.jpg
-result_images/Rottindean1890s.png
-result_images/royal_family.jpg
-result_images/RoyalUniversityMedStudent1900Ireland.jpg
-result_images/Rutherford_Hayes.jpg
-result_images/Sami1880s.jpg
-result_images/SanFran1851.jpg
-result_images/school_kids.jpg
-result_images/SchoolKidsConnemaraIreland1901.jpg
-result_images/Scotland1919.jpg
-result_images/SecondHandClothesLondonLate1800s.jpg
-result_images/SenecaNative1908.jpg
-result_images/ServantsBessboroughHouse1908Ireland.jpg
-result_images/Shack.jpg
-result_images/sioux.jpg
-result_images/skycrapper_lunch.jpg
-result_images/smoking_kid.jpg
-result_images/SoapBoxRacerParis1920s.jpg
-result_images/SoccerMotorcycles1923London.jpg
-result_images/soldier_kids.jpg
-result_images/Sphinx.jpeg
-result_images/SunHelmetsLondon1933.jpg
-result_images/SutroBaths1880s.jpg
-result_images/SynagogueInterior.PNG
-result_images/teddy_rubble.jpg
-result_images/Texas1938Woman.png
-result_images/TheatreGroupBombay1875.jpg
-result_images/TimesSquare1955.jpg
-result_images/TitanicGym.jpg
-result_images/TV1930s.jpg
-result_images/Unidentified1855.jpg
-result_images/unnamed.jpg
-result_images/VictorianDragQueen1880s.png
-result_images/VictorianLivingRoom.jpg
-result_images/ViennaBoys1880s.png
-result_images/w-b-yeats.jpg
-result_images/WalkingLibraryLondon1930.jpg
-result_images/WaltWhitman.jpg
-result_images/WaterfordIreland1909.jpg
-result_images/WestVirginiaHouse.jpg
-result_images/wh-auden.jpg
-result_images/wilson-slaverevivalmeeting.jpg
-result_images/women-bikers.png
-result_images/WomenTapingPlanes.jpg
-result_images/workers_canyon.jpg
-result_images/WorldsFair1900Paris.jpg
-result_images/WorriedKid1940sNyc.jpg
-result_images/ww1_trench.jpg
-result_images/WWIHospital.jpg
-result_images/WWIIPeeps.jpg
-result_images/WWISikhs.jpg
-result_images/ZoologischerGarten1898.jpg
-result_images/Twitter Social Icons.zip
-result_images/1850SchoolForGirls.jpg
-result_images/1852GatekeepersWindsor.jpg
-result_images/ArmisticeDay1918.jpg
-result_images/AtlanticCity1905.png
-result_images/AtlanticCityBeach1905.jpg
-result_images/BrooklynNavyYardHospital.jpg
-result_images/CottonMillWorkers1913.jpg
-result_images/DayAtSeaBelgium.jpg
-result_images/Drive1905.jpg
-result_images/FamilyWithDog.jpg
-result_images/FinnishPeasant1867.jpg
-result_images/FlyingMachinesParis1909.jpg
-result_images/GreatAunt1920.jpg
-result_images/HelenKeller.jpg
-result_images/IronLung.png
-result_images/NewBrunswick1915.jpg
-result_images/OldWomanSweden1904.jpg
-result_images/PushingCart-Copy1.jpg
-result_images/PushingCart.jpg
-result_images/Sami1880s.jpg
-result_images/Scotland1919.jpg
-result_images/SenecaNative1908.jpg
-result_images/ShoeMakerLate1800s.jpg
-result_images/SpottedBull1908.jpg
-result_images/TitanicGym.jpg
-result_images/TouristsGermany1904.jpg
-result_images/TunisianStudents1914.jpg
-result_images/Yorktown1862.jpg
-result_images/TitanicGym.jpg
-result_images/SenecaNative1908.jpg
-result_images/Scotland1919.jpg
-result_images/Sami1880s.jpg
-result_images/HelenKeller.jpg
-result_images/HelenKeller.jpg
-result_images/Sami1880s.jpg
-result_images/Scotland1919.jpg
-result_images/SenecaNative1908.jpg
-result_images/TitanicGym.jpg
-.~ColorizeVisualization.ipynb
-result_images/1886Hoop.jpg
-result_images/1886ProspectPark.jpg
-result_images/1890sChineseImmigrants.jpg
-result_images/1897VancouverAmberlamps.jpg
-result_images/1900NJThanksgiving.jpg
-result_images/1900ParkDog.jpg
-result_images/1900sLadiesTeaParty.jpg
-result_images/1902FrenchYellowBellies.jpg
-result_images/1904ChineseMan.jpg
-result_images/1908RookeriesLondon.jpg
-result_images/1910Finland.jpg
-result_images/1910ThanksgivingMaskersII.jpg
-result_images/1911ThanksgivingMaskers.jpg
-result_images/1913NewYorkConstruction.jpg
-result_images/1919WWIAviationOxygenMask.jpg
-result_images/1923HollandTunnel.jpg
-result_images/1925GypsyCampMaryland.jpg
-result_images/1929VictorianCosplayLondon.jpg
-result_images/1930MaineSchoolBus.jpg
-result_images/1930sRooftopPoland.jpg
-result_images/1934UmbriaItaly.jpg
-result_images/1936OpiumShanghai.jpg
-result_images/1936ParisCafe.jpg
-result_images/1936PetToad.jpg
-result_images/1939GypsyKids.jpg
-result_images/1939SewingBike.png
-result_images/1939YakimaWAGirl.jpg
-result_images/1940Connecticut.jpg
-result_images/1940PAFamily.jpg
-result_images/1941GeorgiaFarmhouse.jpg
-result_images/1941PoolTableGeorgia.jpg
-result_images/1945HiroshimaChild.jpg
-result_images/1950sLondonPoliceChild.jpg
-result_images/1959ParisFriends.png
-result_images/GoldenGateConstruction.jpg
-result_images/HelenKeller.jpg
-result_images/LondonFashion1911.png
-result_images/PostCivilWarAncestors.jpg
-result_images/Sami1880s.jpg
-result_images/Scotland1919.jpg
-result_images/SenecaNative1908.jpg
-result_images/TitanicGym.jpg
-result_images/HelenKeller.jpg
-result_images/Sami1880s.jpg
-result_images/Scotland1919.jpg
-result_images/SenecaNative1908.jpg
-result_images/TitanicGym.jpg
+result_images/*.jpg
+result_images/*.jpeg
+result_images/*.png
+
 fasterai/fastai
 .ipynb_checkpoints/ SuperResolutionVisualization-checkpoint.ipynb
 .ipynb_checkpoints/SuperResolutionTraining-checkpoint.ipynb
@@ -422,61 +33,7 @@ fasterai/.ipynb_checkpoints/modules-checkpoint.py
 result_images/.ipynb_checkpoints/ILSVRC2012_test_00000002-checkpoint.JPEG
 superres/result_images/ILSVRC2012_test_00000643.JPEG
 superres/result_images/Siamese_178.jpg
-superres/test_images/ILSVRC2012_test_00000642.JPEG
-superres/test_images/ILSVRC2012_test_00000643.JPEG
-superres/test_images/ILSVRC2012_test_00000644.JPEG
-superres/test_images/ILSVRC2012_test_00000645.JPEG
-superres/test_images/ILSVRC2012_test_00000646.JPEG
-superres/test_images/ILSVRC2012_test_00000647.JPEG
-superres/test_images/ILSVRC2012_test_00000648.JPEG
-superres/test_images/ILSVRC2012_test_00000649.JPEG
-superres/test_images/ILSVRC2012_test_00000650.JPEG
-superres/test_images/ILSVRC2012_test_00000651.JPEG
-superres/test_images/ILSVRC2012_test_00000652.JPEG
-superres/test_images/ILSVRC2012_test_00000653.JPEG
-superres/test_images/ILSVRC2012_test_00000654.JPEG
-superres/test_images/ILSVRC2012_test_00000655.JPEG
-superres/test_images/ILSVRC2012_test_00000656.JPEG
-superres/test_images/ILSVRC2012_test_00000657.JPEG
-superres/test_images/ILSVRC2012_test_00000658.JPEG
-superres/test_images/ILSVRC2012_test_00000659.JPEG
-superres/test_images/ILSVRC2012_test_00000660.JPEG
-superres/test_images/ILSVRC2012_test_00000661.JPEG
-superres/test_images/ILSVRC2012_test_00000662.JPEG
-superres/test_images/ILSVRC2012_test_00000663.JPEG
-superres/test_images/ILSVRC2012_test_00000664.JPEG
-superres/test_images/ILSVRC2012_test_00000665.JPEG
-superres/test_images/ILSVRC2012_test_00001067.JPEG
-superres/test_images/ILSVRC2012_test_00001068.JPEG
-superres/test_images/ILSVRC2012_test_00001069.JPEG
-superres/test_images/ILSVRC2012_test_00001070.JPEG
-superres/test_images/ILSVRC2012_test_00001071.JPEG
-superres/test_images/ILSVRC2012_test_00001072.JPEG
-superres/test_images/ILSVRC2012_test_00001073.JPEG
-superres/test_images/ILSVRC2012_test_00001074.JPEG
-superres/test_images/ILSVRC2012_test_00001075.JPEG
-superres/test_images/ILSVRC2012_test_00001076.JPEG
-superres/test_images/ILSVRC2012_test_00001077.JPEG
-superres/test_images/ILSVRC2012_test_00001078.JPEG
-superres/test_images/ILSVRC2012_test_00001079.JPEG
-superres/test_images/ILSVRC2012_test_00001080.JPEG
-superres/test_images/ILSVRC2012_test_00001081.JPEG
-superres/test_images/ILSVRC2012_test_00001082.JPEG
-superres/test_images/ILSVRC2012_test_00001083.JPEG
-superres/test_images/ILSVRC2012_test_00002343.JPEG
-superres/test_images/ILSVRC2012_test_00002344.JPEG
-superres/test_images/ILSVRC2012_test_00002345.JPEG
-superres/test_images/ILSVRC2012_test_00002346.JPEG
-superres/test_images/ILSVRC2012_test_00002347.JPEG
-superres/test_images/ILSVRC2012_test_00002348.JPEG
-superres/test_images/ILSVRC2012_test_00002349.JPEG
-superres/test_images/scottish_terrier_159.jpg
-superres/test_images/scottish_terrier_161.jpg
-superres/test_images/scottish_terrier_162.jpg
-superres/test_images/shiba_inu_137.jpg
-superres/test_images/shiba_inu_139.jpg
-superres/test_images/Siamese_178.jpg
-superres/test_images/Siamese_182.jpg
+superres/test_images/*.JPEG
 superres2x34_gen_pretrain.h5
 superres2x_gen_pretrain.h5
 superres_crit_pretrain.h5
@@ -485,14 +42,8 @@ test_images/Andy.jpg
 *.prof
 fastai
 *.pth
-.ipynb_checkpoints/SuperResTraining-checkpoint.ipynb
-.ipynb_checkpoints/ColorizeTrainingNew2-checkpoint.ipynb
-.ipynb_checkpoints/ColorizeTrainingNew-checkpoint.ipynb
-.ipynb_checkpoints/ColorizeTrainingNew3-checkpoint.ipynb
-.ipynb_checkpoints/ColorizeTrainingNew4-checkpoint.ipynb
 ColorizeTrainingNew2.ipynb
 ColorizeTrainingNew3.ipynb
 ColorizeTrainingNew4.ipynb
-.ipynb_checkpoints/ColorizeTraining1-checkpoint.ipynb
-.ipynb_checkpoints/ColorizeVisualization2-checkpoint.ipynb
-.ipynb_checkpoints/DeOldify-video-checkpoint.ipynb
+.ipynb_checkpoints/*-checkpoint.ipynb
+video

+ 2 - 5
DeOldify_colab.ipynb → ImageColorizerColab.ipynb

@@ -7,7 +7,7 @@
     "id": "view-in-github"
    },
    "source": [
-    "<a href=\"https://colab.research.google.com/github/jantic/DeOldify/blob/master/DeOldify_colab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
+    "<a href=\"https://colab.research.google.com/github/jantic/DeOldify/blob/master/ImageColorizerColab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
    ]
   },
   {
@@ -88,8 +88,7 @@
    },
    "outputs": [],
    "source": [
-    "!pip install PyDrive\n",
-    "!pip install tensorboardX"
+    "!pip install PyDrive"
    ]
   },
   {
@@ -112,11 +111,9 @@
     "import fastai\n",
     "from fastai import *\n",
     "from fastai.vision import *\n",
-    "from fastai.callbacks import *\n",
     "from fastai.vision.gan import *\n",
     "from fasterai.dataset import *\n",
     "from fasterai.visualize import *\n",
-    "from fasterai.tensorboard import *\n",
     "from fasterai.loss import *\n",
     "from fasterai.filters import *\n",
     "from fasterai.generators import *\n",

+ 1 - 1
README.md

@@ -1,6 +1,6 @@
 # DeOldify
 
-[<img src="https://colab.research.google.com/assets/colab-badge.svg" align="center">](https://colab.research.google.com/github/jantic/DeOldify/blob/master/DeOldify_colab.ipynb) 
+[<img src="https://colab.research.google.com/assets/colab-badge.svg" align="center">](https://colab.research.google.com/github/jantic/DeOldify/blob/master/ImageColorizerColab.ipynb) 
 
 [Get more updates on Twitter <img src="result_images/Twitter_Social_Icon_Rounded_Square_Color.svg" width="16">](https://twitter.com/citnaj)
 

+ 373 - 0
VideoColorizer.ipynb

@@ -0,0 +1,373 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "os.environ['CUDA_VISIBLE_DEVICES']='3' "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import fastai\n",
+    "import ffmpeg\n",
+    "from fastai import *\n",
+    "from fastai.vision import *\n",
+    "from fastai.callbacks.tensorboard import *\n",
+    "from fastai.vision.gan import *\n",
+    "from fasterai.dataset import *\n",
+    "from fasterai.visualize import *\n",
+    "from fasterai.loss import *\n",
+    "from fasterai.filters import *\n",
+    "from fasterai.generators import *\n",
+    "from pathlib import Path\n",
+    "from itertools import repeat\n",
+    "from IPython.display import HTML, display\n",
+    "plt.style.use('dark_background')\n",
+    "torch.backends.cudnn.benchmark=True"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#Adjust render_factor (int) if image doesn't look quite right (max 64 on 11GB GPU).  The default here works for most photos.  \n",
+    "#It literally just is a number multiplied by 16 to get the square render resolution.  \n",
+    "#Note that this doesn't affect the resolution of the final output- the output is the same resolution as the input.\n",
+    "#Example:  render_factor=21 => color is rendered at 16x21 = 336x336 px.  \n",
+    "render_factor=25\n",
+    "root_folder =  Path('data/imagenet/ILSVRC/Data/CLS-LOC/bandw')\n",
+    "weights_name = 'ColorizeNew44_gen19205'\n",
+    "#weights_name = 'ColorizeNew32_gen'\n",
+    "nf_factor = 1.25\n",
+    "\n",
+    "workfolder = Path('./video')\n",
+    "source_folder = workfolder/\"source\"\n",
+    "bwframes_root = workfolder/\"bwframes\"\n",
+    "colorframes_root = workfolder/\"colorframes\"\n",
+    "result_folder = workfolder/\"result\"\n",
+    "#Make source_url None to just read from source_path directly without modification\n",
+    "source_url = 'https://twitter.com/silentmoviegifs/status/1092793719173115905'\n",
+    "#source_url=None\n",
+    "source_name = 'video5.mp4'\n",
+    "source_path =  source_folder/source_name\n",
+    "bwframes_folder = bwframes_root/(source_path.stem)\n",
+    "colorframes_folder = colorframes_root/(source_path.stem)\n",
+    "result_path = result_folder/source_name"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def progress(value, max=100):\n",
+    "    return HTML(\"\"\"\n",
+    "        <progress\n",
+    "            value='{value}'\n",
+    "            max='{max}',\n",
+    "            style='width: 40%'\n",
+    "        >\n",
+    "            {value}\n",
+    "        </progress>\n",
+    "    \"\"\".format(value=value, max=max))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_fps():\n",
+    "    probe = ffmpeg.probe(str(source_path))\n",
+    "    stream_data = next((stream for stream in probe['streams'] if stream['codec_type'] == 'video'), None)\n",
+    "    avg_frame_rate = stream_data['avg_frame_rate']\n",
+    "    print(avg_frame_rate)\n",
+    "    fps_num=avg_frame_rate.split(\"/\")[0]\n",
+    "    fps_den = avg_frame_rate.rsplit(\"/\")[1]\n",
+    "    return round(float(fps_num)/float(fps_den))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def purge_images(dir):\n",
+    "    for f in os.listdir(dir):\n",
+    "        if re.search('.*?\\.jpg', f):\n",
+    "            os.remove(os.path.join(dir, f))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Download Video (optional via setting source_url)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "##### Specify media_url.  Many sources will work (YouTube, Imgur, Twitter, Reddit, etc). Complete list here:  https://rg3.github.io/youtube-dl/supportedsites.html .  The resulting file path can be used later."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "youtube-dl \"https://twitter.com/silentmoviegifs/status/1092793719173115905\" -o \"video/source/video5.mp4\"\n",
+      "\n",
+      "\n",
+      "[twitter] 1092793719173115905: Downloading webpage\n",
+      "[twitter:card] 1092793719173115905: Downloading webpage\n",
+      "[twitter:card] 1092793719173115905: Downloading guest token\n",
+      "[twitter:card] 1092793719173115905: Downloading JSON metadata\n",
+      "[download] Destination: video/source/video5.mp4\n",
+      "\n",
+      "\u001b[K[download]   0.4% of 252.07KiB at 567.03KiB/s ETA 00:00\n",
+      "\u001b[K[download]   1.2% of 252.07KiB at  1.55MiB/s ETA 00:00\n",
+      "\u001b[K[download]   2.8% of 252.07KiB at  3.37MiB/s ETA 00:00\n",
+      "\u001b[K[download]   6.0% of 252.07KiB at  6.52MiB/s ETA 00:00\n",
+      "\u001b[K[download]  12.3% of 252.07KiB at  5.96MiB/s ETA 00:00\n",
+      "\u001b[K[download]  25.0% of 252.07KiB at  3.37MiB/s ETA 00:00\n",
+      "\u001b[K[download]  50.4% of 252.07KiB at  3.60MiB/s ETA 00:00\n",
+      "\u001b[K[download] 100.0% of 252.07KiB at  5.21MiB/s ETA 00:00\n",
+      "\u001b[K[download] 100% of 252.07KiB in 00:00\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "if source_url is not None:\n",
+    "    if source_path.exists(): source_path.unlink()\n",
+    "    youtubdl_command = 'youtube-dl \"' + source_url + '\" -o \"' + str(source_path) + '\"'\n",
+    "    print(youtubdl_command)\n",
+    "    print('\\n')\n",
+    "    output = Path(os.popen(youtubdl_command).read())\n",
+    "    print(str(output))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Extract Raw Frames"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(b'', None)"
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "bwframe_path_template = str(bwframes_folder/'%5d.jpg')\n",
+    "bwframes_folder.mkdir(parents=True, exist_ok=True)\n",
+    "purge_images(bwframes_folder)\n",
+    "ffmpeg.input(str(source_path)).output(str(bwframe_path_template), format='image2', vcodec='mjpeg', qscale=0).run(capture_stdout=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "framecount = len(os.listdir(str(bwframes_folder)))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## DeOldify / Colorize"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/media/jason/Projects/Deep Learning/DeOldifyV2/DeOldify/fastai/data_block.py:414: UserWarning: Your training set is empty. Is this is by design, pass `ignore_empty=True` to remove this warning.\n",
+      "  warn(\"Your training set is empty. Is this is by design, pass `ignore_empty=True` to remove this warning.\")\n",
+      "/media/jason/Projects/Deep Learning/DeOldifyV2/DeOldify/fastai/data_block.py:417: UserWarning: Your validation set is empty. Is this is by design, use `no_split()` \n",
+      "                 or pass `ignore_empty=True` when labelling to remove this warning.\n",
+      "  or pass `ignore_empty=True` when labelling to remove this warning.\"\"\")\n"
+     ]
+    }
+   ],
+   "source": [
+    "vis = get_colorize_visualizer(root_folder=root_folder, weights_name=weights_name, nf_factor=nf_factor, render_factor=render_factor)\n",
+    "#vis = get_colorize_visualizer(render_factor=render_factor)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "\n",
+       "        <progress\n",
+       "            value='71'\n",
+       "            max='71',\n",
+       "            style='width: 40%'\n",
+       "        >\n",
+       "            71\n",
+       "        </progress>\n",
+       "    "
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "prog = 0\n",
+    "out = display(progress(0, 100), display_id=True)\n",
+    "colorframes_folder.mkdir(parents=True, exist_ok=True)\n",
+    "purge_images(colorframes_folder)\n",
+    "\n",
+    "for img in os.listdir(str(bwframes_folder)):\n",
+    "    img_path = bwframes_folder/img\n",
+    "    if os.path.isfile(str(img_path)):\n",
+    "        color_image = vis.get_transformed_image(str(img_path), render_factor)\n",
+    "        color_image.save(str(colorframes_folder/img))\n",
+    "    prog += 1\n",
+    "    out.update(progress(prog, framecount))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Build Video"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "71/7\n",
+      "10\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "(b'', None)"
+      ]
+     },
+     "execution_count": 12,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "colorframes_path_template = str(colorframes_folder/'%5d.jpg')\n",
+    "result_path.parent.mkdir(parents=True, exist_ok=True)\n",
+    "\n",
+    "if result_path.exists(): result_path.unlink()\n",
+    "\n",
+    "fps = get_fps()\n",
+    "print(fps)\n",
+    "ffmpeg.input(str(colorframes_path_template), format='image2', vcodec='mjpeg', framerate=str(fps)).output(str(result_path), crf=17, vcodec='libx264').run(capture_stdout=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.0"
+  },
+  "toc": {
+   "colors": {
+    "hover_highlight": "#DAA520",
+    "navigate_num": "#000000",
+    "navigate_text": "#333333",
+    "running_highlight": "#FF0000",
+    "selected_highlight": "#FFD700",
+    "sidebar_border": "#EEEEEE",
+    "wrapper_background": "#FFFFFF"
+   },
+   "moveMenuLeft": true,
+   "nav_menu": {
+    "height": "67px",
+    "width": "252px"
+   },
+   "navigate_menu": true,
+   "number_sections": true,
+   "sideBar": true,
+   "threshold": 4,
+   "toc_cell": false,
+   "toc_section_display": "block",
+   "toc_window_display": false,
+   "widenNotebook": false
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}

+ 3 - 5
DeOldify-video.ipynb → VideoColorizerColab.ipynb

@@ -7,7 +7,7 @@
     "id": "view-in-github"
    },
    "source": [
-    "<a href=\"https://colab.research.google.com/github/jantic/DeOldify/blob/master/DeOldify-video.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
+    "<a href=\"https://colab.research.google.com/github/jantic/DeOldify/blob/master/VideoColorizerColab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
    ]
   },
   {
@@ -124,8 +124,7 @@
    "source": [
     "!pip install PyDrive\n",
     "!pip install ffmpeg-python\n",
-    "!pip install youtube-dl\n",
-    "!pip install tensorboardX"
+    "!pip install youtube-dl"
    ]
   },
   {
@@ -148,11 +147,10 @@
     "import fastai\n",
     "from fastai import *\n",
     "from fastai.vision import *\n",
-    "from fastai.callbacks import *\n",
+    "from fastai.callbacks.tensorboard import *\n",
     "from fastai.vision.gan import *\n",
     "from fasterai.dataset import *\n",
     "from fasterai.visualize import *\n",
-    "from fasterai.tensorboard import *\n",
     "from fasterai.loss import *\n",
     "from fasterai.filters import *\n",
     "from fasterai.generators import *\n",

+ 3 - 0
environment.yml

@@ -1,5 +1,8 @@
 name: deoldify
 dependencies:
 - fastai>=1.0.42
+- ffmpeg >= 4.0
 - pip:
   - tensorboardX>=1.4
+  - youtube-dl
+  - ffmpeg-python

+ 22 - 1
fasterai/critics.py

@@ -26,4 +26,25 @@ def custom_gan_critic(n_channels:int=3, nf:int=256, n_blocks:int=3, p:int=0.15):
     return nn.Sequential(*layers)
 
 def colorize_crit_learner(data:ImageDataBunch, loss_critic=AdaptiveLoss(nn.BCEWithLogitsLoss()), nf:int=256)->Learner:
-    return Learner(data, custom_gan_critic(nf=nf), metrics=accuracy_thresh_expand, loss_func=loss_critic, wd=1e-3)
+    return Learner(data, custom_gan_critic(nf=nf), metrics=accuracy_thresh_expand, loss_func=loss_critic, wd=1e-3)
+
+
+
+def custom_gan_critic2(n_channels:int=3, nf:int=256, n_blocks:int=3, p:int=0.15):
+    "Critic to train a `GAN`."
+    layers = [
+        _conv(n_channels, nf, ks=4, stride=2),
+        nn.Dropout2d(p/2),
+        _conv(nf, nf, ks=3, stride=1)]
+    for i in range(n_blocks):
+        layers += [
+            nn.Dropout2d(p),
+            _conv(nf, nf*2, ks=4, stride=2, self_attention=(i==0))]
+        nf *= 2
+    layers += [
+        _conv(nf, 1, ks=4, bias=False, padding=0, use_activ=False),
+        Flatten()]
+    return nn.Sequential(*layers)
+
+def colorize_crit_learner2(data:ImageDataBunch, loss_critic=AdaptiveLoss(nn.BCEWithLogitsLoss()), nf:int=256)->Learner:
+    return Learner(data, custom_gan_critic2(nf=nf), metrics=accuracy_thresh_expand, loss_func=loss_critic, wd=1e-3)

+ 33 - 1
fasterai/generators.py

@@ -1,6 +1,6 @@
 from fastai.vision import *
 from fastai.vision.learner import cnn_config
-from .unet import CustomDynamicUnet
+from .unet import CustomDynamicUnet, CustomDynamicUnet2
 from .loss import FeatureLoss
 from .dataset import *
 
@@ -32,4 +32,36 @@ def custom_unet_learner(data:DataBunch, arch:Callable, pretrained:bool=True, blu
     learn.split(ifnone(split_on,meta['split']))
     if pretrained: learn.freeze()
     apply_init(model[2], nn.init.kaiming_normal_)
+    return learn
+
+#-----------------------------
+
+#Weights are implicitly loaded from the ./models/ folder
+def colorize_gen_inference2(root_folder:Path, weights_name:str, nf_factor:int)->Learner:
+      data = get_dummy_databunch()
+      learn = colorize_gen_learner2(data=data, gen_loss=F.l1_loss, nf_factor=nf_factor)
+      learn.path = root_folder
+      learn.load(weights_name)
+      learn.model.eval()
+      return learn
+
+def colorize_gen_learner2(data:ImageDataBunch, gen_loss=FeatureLoss(), arch=models.resnet34, nf_factor:int=1)->Learner:
+    return custom_unet_learner2(data, arch, wd=1e-3, blur=True, norm_type=NormType.Spectral,
+                        self_attention=True, y_range=(-3.,3.), loss_func=gen_loss, nf_factor=nf_factor)
+
+#The code below is intended to eventually be merged into fastai v1
+def custom_unet_learner2(data:DataBunch, arch:Callable, pretrained:bool=True, blur_final:bool=True,
+                 norm_type:Optional[NormType]=NormType, split_on:Optional[SplitFuncOrIdxList]=None, 
+                 blur:bool=False, self_attention:bool=False, y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True,
+                 bottle:bool=False, nf_factor:int=1, **kwargs:Any)->Learner:
+    "Build Unet learner from `data` and `arch`."
+    meta = cnn_config(arch)
+    body = create_body(arch, pretrained)
+    model = to_device(CustomDynamicUnet2(body, n_classes=data.c, blur=blur, blur_final=blur_final,
+          self_attention=self_attention, y_range=y_range, norm_type=norm_type, last_cross=last_cross,
+          bottle=bottle, nf_factor=nf_factor), data.device)
+    learn = Learner(data, model, **kwargs)
+    learn.split(ifnone(split_on,meta['split']))
+    if pretrained: learn.freeze()
+    apply_init(model[2], nn.init.kaiming_normal_)
     return learn

+ 177 - 0
fasterai/loss.py

@@ -4,6 +4,68 @@ from fastai.torch_core import *
 from fastai.callbacks  import hook_outputs
 import torchvision.models as models
 
+class FeatureLoss3(nn.Module):
+    def __init__(self, layer_wgts=[5,15,2]):
+        super().__init__()
+
+        self.m_feat = models.vgg16_bn(True).features.cuda().eval()
+        requires_grad(self.m_feat, False)
+        blocks = [i-1 for i,o in enumerate(children(self.m_feat)) if isinstance(o,nn.MaxPool2d)]
+        layer_ids = blocks[2:5]
+        self.loss_features = [self.m_feat[i] for i in layer_ids]
+        self.hooks = hook_outputs(self.loss_features, detach=False)
+        self.wgts = layer_wgts
+        self.metric_names = ['pixel',] + [f'feat_{i}' for i in range(len(layer_ids))] 
+        self.base_loss = F.l1_loss
+
+    def _make_features(self, x, clone=False):
+        self.m_feat(x)
+        return [(o.clone() if clone else o) for o in self.hooks.stored]
+
+    def forward(self, input, target):
+        out_feat = self._make_features(target, clone=True)
+        in_feat = self._make_features(input)
+        self.feat_losses = [self.base_loss(input,target)]
+        self.feat_losses += [self.base_loss(f_in, f_out)*w
+                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
+        
+        self.metrics = dict(zip(self.metric_names, self.feat_losses))
+        return sum(self.feat_losses)
+
+        
+    
+    def __del__(self): self.hooks.remove()
+
+class FeatureLoss2(nn.Module):
+    def __init__(self, layer_wgts=[20,70,10]):
+        super().__init__()
+
+        self.m_feat = models.vgg16_bn(True).features.cuda().eval()
+        requires_grad(self.m_feat, False)
+        blocks = [i-1 for i,o in enumerate(children(self.m_feat)) if isinstance(o,nn.MaxPool2d)]
+        layer_ids = blocks[2:5]
+        self.loss_features = [self.m_feat[i] for i in layer_ids]
+        self.hooks = hook_outputs(self.loss_features, detach=False)
+        self.wgts = layer_wgts
+        self.metric_names = ['pixel',] + [f'feat_{i}' for i in range(len(layer_ids))] 
+        self.base_loss = F.l1_loss
+
+    def _make_features(self, x, clone=False):
+        self.m_feat(x)
+        return [(o.clone() if clone else o) for o in self.hooks.stored]
+
+    def forward(self, input, target):
+        out_feat = self._make_features(target, clone=True)
+        in_feat = self._make_features(input)
+        self.feat_losses = [self.base_loss(input,target)]
+        self.feat_losses += [self.base_loss(f_in, f_out)*w
+                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
+        
+        self.metrics = dict(zip(self.metric_names, self.feat_losses))
+        return sum(self.feat_losses)
+    
+    def __del__(self): self.hooks.remove()
+
 
 #"Before activations" in ESRGAN paper
 class FeatureLoss(nn.Module):
@@ -33,6 +95,121 @@ class FeatureLoss(nn.Module):
         
         self.metrics = dict(zip(self.metric_names, self.feat_losses))
         return sum(self.feat_losses)
+
+        
+    
+    def __del__(self): self.hooks.remove()
+
+
+
+class PretrainFeatureLoss(nn.Module):
+    def __init__(self, layer_wgts=[5,15,2], gram_wgt:float=5e3):
+        super().__init__()
+        self.gram_wgt = gram_wgt
+        self.m_feat = models.vgg16_bn(True).features.cuda().eval()
+        requires_grad(self.m_feat, False)
+        blocks = [i-2 for i,o in enumerate(children(self.m_feat)) if isinstance(o,nn.MaxPool2d)]
+        layer_ids = blocks[2:5]
+        self.loss_features = [self.m_feat[i] for i in layer_ids]
+        self.hooks = hook_outputs(self.loss_features, detach=False)
+        self.wgts = layer_wgts
+        self.metric_names = ['pixel',] + [f'feat_{i}' for i in range(len(layer_ids))] 
+        self.base_loss = F.l1_loss
+
+    def _gram_matrix(self, x:torch.Tensor):
+        n,c,h,w = x.size()
+        x = x.view(n, c, -1)
+        return (x @ x.transpose(1,2))/(c*h*w)
+
+    def _make_features(self, x, clone=False):
+        self.m_feat(x)
+        return [(o.clone() if clone else o) for o in self.hooks.stored]
+
+    def forward(self, input, target):
+        out_feat = self._make_features(target, clone=True)
+        in_feat = self._make_features(input)
+        self.feat_losses = [self.base_loss(input,target)]
+
+        self.feat_losses += [self.base_loss(self._gram_matrix(f_in), self._gram_matrix(f_out))*w**2 * self.gram_wgt
+                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
+
+        self.feat_losses += [self.base_loss(f_in, f_out)*w
+                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
+        
+        self.metrics = dict(zip(self.metric_names, self.feat_losses))
+        return sum(self.feat_losses)
     
     def __del__(self): self.hooks.remove()
 
+
+#Feature loss that also includes a Wasserstein-distance style term per feature layer
+class WassFeatureLoss(nn.Module):
+    def __init__(self, layer_wgts=[5,15,2], wass_wgts=[3.0,0.7,0.01]):
+        super().__init__()
+        self.m_feat = models.vgg16_bn(True).features.cuda().eval()
+        requires_grad(self.m_feat, False)
+        blocks = [i-1 for i,o in enumerate(children(self.m_feat)) if isinstance(o,nn.MaxPool2d)]
+        layer_ids = blocks[2:5]
+        self.loss_features = [self.m_feat[i] for i in layer_ids]
+        self.hooks = hook_outputs(self.loss_features, detach=False)
+        self.wgts = layer_wgts
+        self.wass_wgts = wass_wgts
+        self.metric_names = ['pixel',] + [f'feat_{i}' for i in range(len(layer_ids))] + [f'wass_{i}' for i in range(len(layer_ids))]
+        self.base_loss = F.l1_loss
+
+    def _make_features(self, x, clone=False):
+        self.m_feat(x)
+        return [(o.clone() if clone else o) for o in self.hooks.stored]
+
+    def _calc_2_moments(self, tensor):
+        chans = tensor.shape[1]
+        tensor = tensor.view(1, chans, -1)
+        n = tensor.shape[2] 
+        mu = tensor.mean(2)
+        tensor = (tensor - mu[:,:,None]).squeeze(0)
+        #Prevents a nasty bug that happens very occasionally: a divide by zero when the tensor has no elements.
+        if n == 0: return None, None
+        cov = torch.mm(tensor, tensor.t()) / float(n) 
+        return mu, cov
+
+    def _get_style_vals(self, tensor):
+        mean, cov = self._calc_2_moments(tensor) 
+        if mean is None:
+            return None, None, None
+        eigvals, eigvects = torch.symeig(cov, eigenvectors=True)
+        eigroot_mat = torch.diag(torch.sqrt(eigvals.clamp(min=0)))     
+        root_cov = torch.mm(torch.mm(eigvects, eigroot_mat), eigvects.t())  
+        tr_cov = eigvals.clamp(min=0).sum() 
+        return mean, tr_cov, root_cov
+
+    def _calc_l2wass_dist(self, mean_stl, tr_cov_stl, root_cov_stl, mean_synth, cov_synth):
+        tr_cov_synth = torch.symeig(cov_synth, eigenvectors=True)[0].clamp(min=0).sum()
+        mean_diff_squared = (mean_stl - mean_synth).pow(2).sum()
+        cov_prod = torch.mm(torch.mm(root_cov_stl, cov_synth), root_cov_stl)
+        var_overlap = torch.sqrt(torch.symeig(cov_prod, eigenvectors=True)[0].clamp(min=0)+1e-8).sum()
+        dist = mean_diff_squared + tr_cov_stl + tr_cov_synth - 2*var_overlap
+        return dist
+
+    def _single_wass_loss(self, pred, targ):
+        mean_test, tr_cov_test, root_cov_test = targ
+        mean_synth, cov_synth = self._calc_2_moments(pred)
+        loss = self._calc_l2wass_dist(mean_test, tr_cov_test, root_cov_test, mean_synth, cov_synth)
+        return loss
+    
+    def forward(self, input, target):
+        out_feat = self._make_features(target, clone=True)
+        in_feat = self._make_features(input)
+        self.feat_losses = [self.base_loss(input,target)]
+        self.feat_losses += [self.base_loss(f_in, f_out)*w
+                             for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
+        
+        styles = [self._get_style_vals(i) for i in out_feat]
+
+        if styles[0][0] is not None:
+            self.feat_losses += [self._single_wass_loss(f_pred, f_targ)*w
+                                for f_pred, f_targ, w in zip(in_feat, styles, self.wass_wgts)]
+        
+        self.metrics = dict(zip(self.metric_names, self.feat_losses))
+        return sum(self.feat_losses)
+    
+    def __del__(self): self.hooks.remove()

+ 0 - 402
fasterai/tensorboard.py

@@ -1,402 +0,0 @@
-"Provides convenient callbacks for Learners that write model images, metrics/losses, stats and histograms to Tensorboard"
-from fastai.basic_train import Learner
-from fastai.basic_data import DatasetType, DataBunch
-from fastai.vision import Image
-from fastai.callbacks import LearnerCallback
-from fastai.core import *
-from fastai.torch_core import *
-from threading import Thread, Event
-from time import sleep
-from queue import Queue
-import statistics
-import torchvision.utils as vutils
-from abc import ABC, abstractmethod
-from tensorboardX import SummaryWriter
-
-
-__all__=['LearnerTensorboardWriter', 'GANTensorboardWriter', 'ImageGenTensorboardWriter']
-
-
-class LearnerTensorboardWriter(LearnerCallback):
-    def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=500, stats_iters:int=100):
-        super().__init__(learn=learn)
-        self.base_dir = base_dir
-        self.name = name
-        log_dir = base_dir/name
-        self.tbwriter = SummaryWriter(log_dir=str(log_dir))
-        self.loss_iters = loss_iters
-        self.hist_iters = hist_iters
-        self.stats_iters = stats_iters
-        self.hist_writer = HistogramTBWriter()
-        self.stats_writer = ModelStatsTBWriter()
-        self.data = None
-        self.metrics_root = '/metrics/'
-        self._update_batches_if_needed()
-
-    def _update_batches_if_needed(self):
-        # one_batch function is extremely slow with large datasets.  This is an optimization.
-        # Note that also we want to always show the same batches so we can see changes 
-        # in tensorboard
-        update_batches = self.data is not self.learn.data
-
-        if update_batches:
-            self.data = self.learn.data
-            self.trn_batch = self.learn.data.one_batch(
-                ds_type=DatasetType.Train, detach=True, denorm=False, cpu=False)
-            self.val_batch = self.learn.data.one_batch(
-                ds_type=DatasetType.Valid, detach=True, denorm=False, cpu=False)
-
-    def _write_model_stats(self, iteration:int):
-        self.stats_writer.write(
-            model=self.learn.model, iteration=iteration, tbwriter=self.tbwriter)
-
-    def _write_training_loss(self, iteration:int, last_loss:Tensor):
-        scalar_value = to_np(last_loss)
-        tag = self.metrics_root + 'train_loss'
-        self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)
-
-    def _write_weight_histograms(self, iteration:int):
-        self.hist_writer.write(
-            model=self.learn.model, iteration=iteration, tbwriter=self.tbwriter)
-
-    #TODO:  Relying on a specific hardcoded start_idx here isn't great.  Is there a better solution?
-    def _write_metrics(self, iteration:int, last_metrics:MetricsList, start_idx:int=2):
-        recorder = self.learn.recorder
-
-        for i, name in enumerate(recorder.names[start_idx:]):
-            if len(last_metrics) < i+1: return
-            scalar_value = last_metrics[i]
-            tag = self.metrics_root + name
-            self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)
-
-    def on_batch_end(self, last_loss:Tensor, iteration:int, **kwargs):
-        if iteration == 0: return
-        self._update_batches_if_needed()
-
-        if iteration % self.loss_iters == 0:
-            self._write_training_loss(iteration=iteration, last_loss=last_loss)
-
-        if iteration % self.hist_iters == 0:
-            self._write_weight_histograms(iteration=iteration)
-
-    # Doing stuff here that requires gradient info, because they get zeroed out afterwards in training loop
-    def on_backward_end(self, iteration:int, **kwargs):
-        if iteration == 0: return
-        self._update_batches_if_needed()
-
-        if iteration % self.stats_iters == 0:
-            self._write_model_stats(iteration=iteration)
-
-    def on_epoch_end(self, last_metrics:MetricsList, iteration:int, **kwargs):
-        self._write_metrics(iteration=iteration, last_metrics=last_metrics)
-
-# TODO:  We're overriding almost everything here.  Seems like a good idea to question that ("is a" vs "has a")
-class GANTensorboardWriter(LearnerTensorboardWriter):
-    def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=500,
-                 stats_iters:int=100, visual_iters:int=100):
-        super().__init__(learn=learn, base_dir=base_dir, name=name, loss_iters=loss_iters,
-                         hist_iters=hist_iters, stats_iters=stats_iters)
-        self.visual_iters = visual_iters
-        self.img_gen_vis = ImageTBWriter()
-        self.gen_stats_updated = True
-        self.crit_stats_updated = True
-
-    # override
-    def _write_weight_histograms(self, iteration:int):
-        trainer = self.learn.gan_trainer
-        generator = trainer.generator
-        critic = trainer.critic
-        self.hist_writer.write(
-            model=generator, iteration=iteration, tbwriter=self.tbwriter, name='generator')
-        self.hist_writer.write(
-            model=critic, iteration=iteration, tbwriter=self.tbwriter, name='critic')
-
-    # override
-    def _write_model_stats(self, iteration:int):
-        trainer = self.learn.gan_trainer
-        generator = trainer.generator
-        critic = trainer.critic
-
-        # Don't want to write stats when model is not iterated on and hence has zeroed out gradients
-        gen_mode = trainer.gen_mode
-
-        if gen_mode and not self.gen_stats_updated:
-            self.stats_writer.write(
-                model=generator, iteration=iteration, tbwriter=self.tbwriter, name='gen_model_stats')
-            self.gen_stats_updated = True
-
-        if not gen_mode and not self.crit_stats_updated:
-            self.stats_writer.write(
-                model=critic, iteration=iteration, tbwriter=self.tbwriter, name='crit_model_stats')
-            self.crit_stats_updated = True
-
-    # override
-    def _write_training_loss(self, iteration:int, last_loss:Tensor):
-        trainer = self.learn.gan_trainer
-        recorder = trainer.recorder
-
-        if len(recorder.losses) > 0:
-            scalar_value = to_np((recorder.losses[-1:])[0])
-            tag = self.metrics_root + 'train_loss'
-            self.tbwriter.add_scalar(tag=tag, scalar_value=scalar_value, global_step=iteration)
-
-    def _write(self, iteration:int):
-        trainer = self.learn.gan_trainer
-        #TODO:  Switching gen_mode temporarily seems a bit hacky here.  Certainly not a good side-effect.  Is there a better way?
-        gen_mode = trainer.gen_mode
-
-        try:
-            trainer.switch(gen_mode=True)
-            self.img_gen_vis.write(learn=self.learn, trn_batch=self.trn_batch, val_batch=self.val_batch,
-                                                    iteration=iteration, tbwriter=self.tbwriter)
-        finally:                                      
-            trainer.switch(gen_mode=gen_mode)
-
-    # override
-    def on_batch_end(self, iteration:int, **kwargs):
-        super().on_batch_end(iteration=iteration, **kwargs)
-        if iteration == 0: return
-        if iteration % self.visual_iters == 0:
-            self._write(iteration=iteration)
-
-    # override
-    def on_backward_end(self, iteration:int, **kwargs):
-        if iteration == 0: return
-        self._update_batches_if_needed()
-
-        #TODO:  This could perhaps be implemented as queues of requests instead but that seemed like overkill. 
-        # But I'm not the biggest fan of maintaining these boolean flags either... Review pls.
-        if iteration % self.stats_iters == 0:
-            self.gen_stats_updated = False
-            self.crit_stats_updated = False
-
-        if not (self.gen_stats_updated and self.crit_stats_updated):
-            self._write_model_stats(iteration=iteration)
-
-
-class ImageGenTensorboardWriter(LearnerTensorboardWriter):
-    def __init__(self, learn:Learner, base_dir:Path, name:str, loss_iters:int=25, hist_iters:int=500,
-                 stats_iters: int = 100, visual_iters: int = 100):
-        super().__init__(learn=learn, base_dir=base_dir, name=name, loss_iters=loss_iters, hist_iters=hist_iters,
-                         stats_iters=stats_iters)
-        self.visual_iters = visual_iters
-        self.img_gen_vis = ImageTBWriter()
-
-    def _write(self, iteration:int):
-        self.img_gen_vis.write(learn=self.learn, trn_batch=self.trn_batch, val_batch=self.val_batch,
-                                                  iteration=iteration, tbwriter=self.tbwriter)
-
-    # override
-    def on_batch_end(self, iteration:int, **kwargs):
-        super().on_batch_end(iteration=iteration, **kwargs)
-        if iteration == 0: return
-
-        if iteration % self.visual_iters == 0:
-            self._write(iteration=iteration)
-
-
-#------PRIVATE-----------
-
-class TBWriteRequest(ABC):
-    def __init__(self, tbwriter: SummaryWriter, iteration:int):
-        super().__init__()
-        self.tbwriter = tbwriter
-        self.iteration = iteration
-
-    @abstractmethod
-    def write(self):
-        pass   
-
-
-# SummaryWriter writes tend to block quite a bit.  This gets around that and greatly boosts performance.
-# Not all tensorboard writes are using this- just the ones that take a long time.  Note that the 
-# SummaryWriter does actually use a threadsafe consumer/producer design ultimately to write to Tensorboard, 
-# so writes done outside of this async loop should be fine.
-class AsyncTBWriter():
-    def __init__(self):
-        super().__init__()
-        self.stop_request = Event()
-        self.queue = Queue()
-        self.thread = Thread(target=self._queue_processor, daemon=True)
-        self.thread.start()
-
-    def request_write(self, request: TBWriteRequest):
-        if self.stop_request.isSet():
-            raise Exception('Close was already called!  Cannot perform this operation.')
-        self.queue.put(request)
-
-    def _queue_processor(self):
-        while not self.stop_request.isSet():
-            while not self.queue.empty():
-                request = self.queue.get()
-                request.write()
-            sleep(0.2)
-
-    #Provided this to stop thread explicitly or by context management (with statement) but thread should end on its own 
-    # upon program exit, due to being a daemon.  So using this is probably unecessary.
-    def close(self):
-        self.stop_request.set()
-        self.thread.join()
-
-    def __enter__(self):
-        # Nothing to do, thread already started.  Could start thread here to enforce use of context manager 
-        # (but that sounds like a pain and a bit unweildy and unecessary for actual usage)
-        pass
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        self.close()
-
-asyncTBWriter = AsyncTBWriter() 
-
-class ModelImageSet():
-    @staticmethod
-    def get_list_from_model(learn:Learner, ds_type:DatasetType, batch:Tuple)->[]:
-        image_sets = []
-        x,y = batch[0],batch[1]
-        preds = learn.pred_batch(ds_type=ds_type, batch=(x,y), reconstruct=True)
-        
-        for orig_px, real_px, gen in zip(x,y,preds):
-            orig = Image(px=orig_px)
-            real = Image(px=real_px)
-            image_set = ModelImageSet(orig=orig, real=real, gen=gen)
-            image_sets.append(image_set)
-
-        return image_sets  
-
-    def __init__(self, orig:Image, real:Image, gen:Image):
-        self.orig = orig
-        self.real = real
-        self.gen = gen
-
-
-class HistogramTBRequest(TBWriteRequest):
-    def __init__(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str):
-        super().__init__(tbwriter=tbwriter, iteration=iteration)
-        self.params = [(name, values.clone().detach().cpu()) for (name, values) in model.named_parameters()]
-        self.name = name
-
-    # override
-    def write(self):
-        try:
-            for param_name, values in self.params:
-                tag = self.name + '/weights/' + param_name
-                self.tbwriter.add_histogram(tag=tag, values=values, global_step=self.iteration)
-        except Exception as e:
-            print(("Failed to write model histograms to Tensorboard:  {0}").format(e))
-
-#If this isn't done async then this is sloooooow
-class HistogramTBWriter():
-    def __init__(self):
-        super().__init__()
-
-    def write(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str='model'):
-        request = HistogramTBRequest(model=model, iteration=iteration, tbwriter=tbwriter, name=name)
-        asyncTBWriter.request_write(request)
-
-class ModelStatsTBRequest(TBWriteRequest):
-    def __init__(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str):
-        super().__init__(tbwriter=tbwriter, iteration=iteration)
-        self.gradients = [x.grad.clone().detach().cpu() for x in model.parameters() if x.grad is not None]
-        self.name = name
-        self.gradients_root = '/gradients/'
-
-    # override
-    def write(self):
-        try:
-            if len(self.gradients) == 0: return
-
-            gradient_nps = [to_np(x.data) for x in self.gradients]
-            avg_norm = sum(x.data.norm() for x in self.gradients)/len(self.gradients)
-            self.tbwriter.add_scalar(
-                tag=self.name + self.gradients_root + 'avg_norm', scalar_value=avg_norm, global_step=self.iteration)
-
-            median_norm = statistics.median(x.data.norm() for x in self.gradients)
-            self.tbwriter.add_scalar(
-                tag=self.name + self.gradients_root + 'median_norm', scalar_value=median_norm, global_step=self.iteration)
-
-            max_norm = max(x.data.norm() for x in self.gradients)
-            self.tbwriter.add_scalar(
-                tag=self.name + self.gradients_root + 'max_norm', scalar_value=max_norm, global_step=self.iteration)
-
-            min_norm = min(x.data.norm() for x in self.gradients)
-            self.tbwriter.add_scalar(
-                tag=self.name + self.gradients_root + 'min_norm', scalar_value=min_norm, global_step=self.iteration)
-
-            num_zeros = sum((np.asarray(x) == 0.0).sum() for x in gradient_nps)
-            self.tbwriter.add_scalar(
-                tag=self.name + self.gradients_root + 'num_zeros', scalar_value=num_zeros, global_step=self.iteration)
-
-            avg_gradient = sum(x.data.mean() for x in self.gradients)/len(self.gradients)
-            self.tbwriter.add_scalar(
-                tag=self.name + self.gradients_root + 'avg_gradient', scalar_value=avg_gradient, global_step=self.iteration)
-
-            median_gradient = statistics.median(x.data.median() for x in self.gradients)
-            self.tbwriter.add_scalar(
-                tag=self.name + self.gradients_root + 'median_gradient', scalar_value=median_gradient, global_step=self.iteration)
-
-            max_gradient = max(x.data.max() for x in self.gradients)
-            self.tbwriter.add_scalar(
-                tag=self.name + self.gradients_root + 'max_gradient', scalar_value=max_gradient, global_step=self.iteration)
-
-            min_gradient = min(x.data.min() for x in self.gradients)
-            self.tbwriter.add_scalar(
-                tag=self.name + self.gradients_root + 'min_gradient', scalar_value=min_gradient, global_step=self.iteration)
-        except Exception as e:
-            print(("Failed to write model stats to Tensorboard:  {0}").format(e))
-
-
-class ModelStatsTBWriter():
-    def write(self, model:nn.Module, iteration:int, tbwriter:SummaryWriter, name:str='model_stats'):
-        request = ModelStatsTBRequest(model=model, iteration=iteration, tbwriter=tbwriter, name=name)
-        asyncTBWriter.request_write(request)
-
-
-class ImageTBRequest(TBWriteRequest):
-    def __init__(self, learn:Learner, batch:Tuple, iteration:int, tbwriter:SummaryWriter, ds_type:DatasetType):
-        super().__init__(tbwriter=tbwriter, iteration=iteration)
-        self.image_sets = ModelImageSet.get_list_from_model(learn=learn, batch=batch, ds_type=ds_type)
-        self.ds_type = ds_type
-
-    # override
-    def write(self):
-        try:
-            orig_images = []
-            gen_images = []
-            real_images = []
-
-            for image_set in self.image_sets:
-                orig_images.append(image_set.orig.px)
-                gen_images.append(image_set.gen.px)
-                real_images.append(image_set.real.px)
-
-            prefix = self.ds_type.name
-
-            self.tbwriter.add_image(
-                tag=prefix + ' orig images', img_tensor=vutils.make_grid(orig_images, normalize=True), 
-                global_step=self.iteration)
-            self.tbwriter.add_image(
-                tag=prefix + ' gen images', img_tensor=vutils.make_grid(gen_images, normalize=True), 
-                global_step=self.iteration)
-            self.tbwriter.add_image(
-                tag=prefix + ' real images', img_tensor=vutils.make_grid(real_images, normalize=True), 
-                global_step=self.iteration)
-        except Exception as e:
-            print(("Failed to write images to Tensorboard:  {0}").format(e))
-
-#If this isn't done async then this is noticeably slower
-class ImageTBWriter():
-    def __init__(self):
-        super().__init__()
-
-    def write(self, learn:Learner, trn_batch:Tuple, val_batch:Tuple, iteration:int, tbwriter:SummaryWriter):
-        self._write_for_dstype(learn=learn, batch=val_batch, iteration=iteration,
-                             tbwriter=tbwriter, ds_type=DatasetType.Valid)
-        self._write_for_dstype(learn=learn, batch=trn_batch, iteration=iteration,
-                             tbwriter=tbwriter, ds_type=DatasetType.Train)
-
-    def _write_for_dstype(self, learn:Learner, batch:Tuple, iteration:int, tbwriter:SummaryWriter, ds_type:DatasetType):
-        request = ImageTBRequest(learn=learn, batch=batch, iteration=iteration, tbwriter=tbwriter, ds_type=ds_type)
-        asyncTBWriter.request_write(request)
-
-
-

+ 72 - 0
fasterai/unet.py

@@ -2,6 +2,8 @@ from fastai.layers import *
 from .layers import *
 from fastai.torch_core import *
 from fastai.callbacks.hooks import *
+from fastai.vision import *
+
 
 #The code below is meant to be merged into fastaiv1 ideally
 
@@ -100,3 +102,73 @@ class CustomDynamicUnet(SequentialEx):
 
 
 
+
+#------------------------------------------------------
+class CustomUnetBlock2(nn.Module):
+    "A quasi-UNet block, using `PixelShuffle_ICNR upsampling`."
+    def __init__(self, up_in_c:int, x_in_c:int, n_out:int,  hook:Hook, final_div:bool=True, blur:bool=False, leaky:float=None,
+                 self_attention:bool=False,  **kwargs):
+        super().__init__()
+        self.hook = hook
+        up_out = x_out = n_out//2
+        self.shuf = CustomPixelShuffle_ICNR(up_in_c, up_out, blur=blur, leaky=leaky, **kwargs)
+        self.bn = batchnorm_2d(x_in_c)
+        ni = up_out + x_in_c
+        self.conv = custom_conv_layer(ni, x_out, leaky=leaky, self_attention=self_attention, **kwargs)
+        self.relu = relu(leaky=leaky)
+
+    def forward(self, up_in:Tensor) -> Tensor:
+        s = self.hook.stored
+        up_out = self.shuf(up_in)
+        ssh = s.shape[-2:]
+        if ssh != up_out.shape[-2:]:
+            up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')
+        cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
+        return self.conv(cat_x)
+
+
+class CustomDynamicUnet2(SequentialEx):
+    "Create a U-Net from a given architecture."
+    def __init__(self, encoder:nn.Module, n_classes:int, blur:bool=False, blur_final=True, self_attention:bool=False,
+                 y_range:Optional[Tuple[float,float]]=None, last_cross:bool=True, bottle:bool=False,
+                 norm_type:Optional[NormType]=NormType.Batch, nf_factor:int=1, **kwargs):
+        
+        nf = 512 * nf_factor
+        extra_bn =  norm_type == NormType.Spectral
+        imsize = (256,256)
+        sfs_szs = model_sizes(encoder, size=imsize)
+        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
+        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs])
+        x = dummy_eval(encoder, imsize).detach()
+
+        ni = sfs_szs[-1][1]
+        middle_conv = nn.Sequential(custom_conv_layer(ni, ni*2, norm_type=norm_type, extra_bn=extra_bn, **kwargs),
+                                    custom_conv_layer(ni*2, ni, norm_type=norm_type, extra_bn=extra_bn, **kwargs)).eval()
+        x = middle_conv(x)
+        layers = [encoder, batchnorm_2d(ni), nn.ReLU(), middle_conv]
+
+        for i,idx in enumerate(sfs_idxs):
+            not_final = i!=len(sfs_idxs)-1
+            up_in_c, x_in_c = int(x.shape[1]), int(sfs_szs[idx][1])
+            do_blur = blur and (not_final or blur_final)
+            sa = self_attention and (i==len(sfs_idxs)-3)
+
+            n_out = nf if not_final else nf//2
+
+            unet_block = CustomUnetBlock2(up_in_c, x_in_c, n_out, self.sfs[i], final_div=not_final, blur=blur, self_attention=sa,
+                                   norm_type=norm_type, extra_bn=extra_bn, **kwargs).eval()
+            layers.append(unet_block)
+            x = unet_block(x)
+
+        ni = x.shape[1]
+        if imsize != sfs_szs[0][-2:]: layers.append(PixelShuffle_ICNR(ni, **kwargs))
+        if last_cross:
+            layers.append(MergeLayer(dense=True))
+            ni += in_channels(encoder)
+            layers.append(res_block(ni, bottle=bottle, norm_type=norm_type, **kwargs))
+        layers += [custom_conv_layer(ni, n_classes, ks=1, use_activ=False, norm_type=norm_type)]
+        if y_range is not None: layers.append(SigmoidRange(*y_range))
+        super().__init__(*layers)
+
+    def __del__(self):
+        if hasattr(self, "sfs"): self.sfs.remove()

+ 8 - 1
fasterai/visualize.py

@@ -4,7 +4,7 @@ from matplotlib.axes import Axes
 from matplotlib.figure import Figure
 from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
 from .filters import IFilter, MasterFilter, ColorizerFilter
-from .generators import colorize_gen_inference
+from .generators import colorize_gen_inference, colorize_gen_inference2
 from IPython.display import display
 from tensorboardX import SummaryWriter
 from scipy import misc
@@ -59,6 +59,13 @@ def get_colorize_visualizer(root_folder:Path=Path('./'), weights_name:str='color
     vis = ModelImageVisualizer(filtr, results_dir=results_dir)
     return vis
 
+def get_colorize_visualizer2(root_folder:Path=Path('./'), weights_name:str='colorize_gen', 
+        results_dir = 'result_images', nf_factor:int=1, render_factor:int=21)->ModelImageVisualizer:
+    learn = colorize_gen_inference2(root_folder=root_folder, weights_name=weights_name, nf_factor=nf_factor)
+    filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
+    vis = ModelImageVisualizer(filtr, results_dir=results_dir)
+    return vis
+
 
 
 

+ 2 - 0
requirements.txt

@@ -1,2 +1,4 @@
 fastai>=1.0.42
 tensorboardX>=1.4
+ffmpeg-python
+youtube-dl