Sfoglia il codice sorgente

Doing overhaul of image colorizer colab and cleaning up video colab; adding save callback for gan training

Jason Antic 6 anni fa
parent
commit
15884cefd6

+ 25 - 2
ImageColorizerArtistic.ipynb

@@ -7,7 +7,7 @@
    "outputs": [],
    "source": [
     "import os\n",
-    "os.environ['CUDA_VISIBLE_DEVICES']='0' "
+    "os.environ['CUDA_VISIBLE_DEVICES']='2' "
    ]
   },
   {
@@ -42,6 +42,15 @@
     "vis = get_image_colorizer(render_factor=render_factor, artistic=True)"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "vis.plot_transformed_image(\"test_images/poolparty.jpg\", render_factor=38)"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -120,7 +129,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/PushingCart.jpg\", render_factor=30)"
+    "vis.plot_transformed_image(\"test_images/PushingCart.jpg\", render_factor=21)"
    ]
   },
   {
@@ -897,6 +906,20 @@
     "vis.plot_transformed_image(\"test_images/1890Surfer.png\", render_factor=30)"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
   {
    "cell_type": "code",
    "execution_count": null,

+ 106 - 160
ImageColorizerColab.ipynb

@@ -14,81 +14,71 @@
    "cell_type": "markdown",
    "metadata": {
     "colab_type": "text",
-    "id": "V-nHLm4jWpnV"
+    "id": "663IVxfrpIAb"
    },
    "source": [
-    "# DeOldify on Colab #\n",
+    "#◢ DeOldify - Colorize Your Own Photos!\n",
     "\n",
-    "This notebook allows you to colorize your own images using Google Colab!\n",
+    "##Use this Colab notebook to colorize black & white photos in four simple steps.\n",
+    "1. Specify photo URL- make sure it's a direct link to the photo (with extension of .jpg, .png, etc).\n",
+    "2. Select 'Render Factor'.  Generally, older and lower quality photos will render better with lower render factors (14-21 range) while higher quality photos will do better on higher render factors.\n",
+    "3. Colorize your photo with DeOldify\n",
+    "4. Save a copy to your device by right clicking on the rendered image and selecting save.\n",
     "\n",
-    "Special thanks to the that made this possible!\n",
+    "---\n",
     "\n",
-    "Original Author:  Matt Robinson, <matthew67robinson@gmail.com>\n",
+    "####**Credits:**\n",
     "\n",
-    "Additional Contributions: Maria Benavente"
+    "Special thanks to:\n",
+    "\n",
+    "Matt Robinson and María Benavente for pioneering the DeOldify image colab notebook.  \n",
+    "\n",
+    "Dana Kelley for doing things, breaking stuff & having an opinion on everything."
    ]
   },
   {
-   "cell_type": "code",
-   "execution_count": null,
+   "cell_type": "markdown",
    "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/",
-     "height": 119
-    },
-    "colab_type": "code",
-    "id": "-dSDNbBNb-N8",
-    "outputId": "4ef40df9-46dd-44a4-b54a-e720a3eee232"
+    "colab_type": "text",
+    "id": "ZjPqTBNoohK9"
    },
-   "outputs": [],
-   "source": [
-    "!git clone -b FastAIv1 --single-branch https://github.com/jantic/DeOldify.git DeOldify"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
    "source": [
-    "from os import path\n",
-    "import torch\n",
-    "print(torch.__version__)\n",
-    "print(torch.cuda.is_available())"
+    "\n",
+    "\n",
+    "---\n",
+    "\n",
+    "\n",
+    "#◢ Verify Correct Runtime Settings\n",
+    "\n",
+    "**<font color='#FF000'> IMPORTANT </font>**\n",
+    "\n",
+    "In the \"Runtime\" menu for the notebook window, select \"Change runtime type.\" Ensure that the following are selected:\n",
+    "* Runtime Type = Python 3\n",
+    "* Hardware Accelerator = GPU \n"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/",
-     "height": 34
-    },
+    "colab": {},
     "colab_type": "code",
-    "id": "k19F34Tsd7CX",
-    "outputId": "81828b10-6678-4eec-ec53-5b0b9d645782"
+    "id": "00_GcC_trpdE"
    },
    "outputs": [],
    "source": [
-    "cd DeOldify"
+    "from os import path\n",
+    "import torch"
    ]
   },
   {
-   "cell_type": "code",
-   "execution_count": null,
+   "cell_type": "markdown",
    "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/",
-     "height": 187
-    },
-    "colab_type": "code",
-    "id": "lhejMeOxghBM",
-    "outputId": "d063f1dc-1286-4355-f8aa-838a7dfc29ee"
+    "colab_type": "text",
+    "id": "gaEJBGDlptEo"
    },
-   "outputs": [],
    "source": [
-    "!pip install PyDrive"
+    "#◢ Git clone and install DeOldify"
    ]
   },
   {
@@ -97,60 +87,31 @@
    "metadata": {
     "colab": {},
     "colab_type": "code",
-    "id": "yODBFi8MgoLZ"
+    "id": "-T-svuHytJ-8"
    },
    "outputs": [],
    "source": [
-    "import os\n",
-    "from pydrive.auth import GoogleAuth\n",
-    "from pydrive.drive import GoogleDrive\n",
-    "from google.colab import auth\n",
-    "from oauth2client.client import GoogleCredentials\n",
-    "from google.colab import drive\n",
-    "from IPython.display import Image\n",
-    "import fastai\n",
-    "from fastai import *\n",
-    "from fasterai.visualize import *\n",
-    "from pathlib import Path\n",
-    "from itertools import repeat\n",
-    "from google.colab import drive\n",
-    "torch.backends.cudnn.benchmark=True"
+    "!git clone -b FastAIv1 --single-branch https://github.com/jantic/DeOldify.git DeOldify\n",
+    "#!git clone https://github.com/jantic/DeOldify.git DeOldify"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "colab": {},
-    "colab_type": "code",
-    "id": "Ow9Qhf4jgrgd"
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
-    "auth.authenticate_user()\n",
-    "gauth = GoogleAuth()\n",
-    "gauth.credentials = GoogleCredentials.get_application_default()\n",
-    "drive = GoogleDrive(gauth)"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "colab_type": "text",
-    "id": "aVyIYMqrg-SM"
-   },
-   "source": [
-    "Note that the above requires a verification step. It isn't too bad."
+    "cd DeOldify"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {
     "colab_type": "text",
-    "id": "Doru7d3rVYr7"
+    "id": "BDFjbNxaadNJ"
    },
    "source": [
-    "With access to your Google Drive, the \"deOldifyImages\" directory will be created. Drop there your personal images, and after the full execution of the notebook find the results at its subdirectory \"results\""
+    "#◢ Setup"
    ]
   },
   {
@@ -159,33 +120,30 @@
    "metadata": {
     "colab": {},
     "colab_type": "code",
-    "id": "sU2yQAfqhNJv"
+    "id": "Lsx7xCXNSVt6"
    },
    "outputs": [],
    "source": [
-    "results_dir=Path('/content/drive/My Drive/deOldifyImages/results')\n",
-    "\n",
-    "#Adjust this if image doesn't look quite right (max 64 on 11GB GPU).  The default here works for most photos.  \n",
-    "#It literally just is a number multiplied by 16 to get the square render resolution.  \n",
-    "#Note that this doesn't affect the resolution of the final output- the output is the same resolution as the input.\n",
-    "#Example:  render_factor=21 => color is rendered at 16x21 = 336x336 px.  \n",
-    "render_factor=36"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "#◢ Artistic vs Stable Model"
+    "!pip install PyDrive\n",
+    "!pip install ffmpeg-python\n",
+    "!pip install youtube-dl\n",
+    "!pip install tensorboardX"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {},
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "MsJa69CMwj3l"
+   },
    "outputs": [],
    "source": [
-    "artistic = True #@param {type:\"boolean\"}"
+    "import fastai\n",
+    "from fasterai.visualize import *\n",
+    "from pathlib import Path\n",
+    "torch.backends.cudnn.benchmark=True"
    ]
   },
   {
@@ -195,38 +153,32 @@
    "outputs": [],
    "source": [
     "!mkdir 'models'\n",
-    "if artistic:\n",
-    "    !wget https://www.dropbox.com/s/zkehq1uwahhbc2o/ColorizeArtistic_gen.pth?dl=0 -O ./models/ColorizeArtistic_gen.pth\n",
-    "else:\n",
-    "    !wget https://www.dropbox.com/s/mwjep3vyqk5mkjc/ColorizeStable_gen.pth?dl=0 -O ./models/ColorizeStable_gen.pth"
+    "!wget wget https://www.dropbox.com/s/zkehq1uwahhbc2o/ColorizeArtistic_gen.pth?dl=0 -O ./models/ColorizeArtistic_gen.pth"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/",
-     "height": 68
-    },
+    "colab": {},
     "colab_type": "code",
-    "id": "NlOT9IlBHkk7",
-    "outputId": "1bb5cc4d-15a2-4174-a37c-9e11f3a2dce3"
+    "id": "tzHVnegp21hC"
    },
    "outputs": [],
    "source": [
-    "from google.colab import drive\n",
-    "drive.mount('/content/drive')"
+    "colorizer = get_image_colorizer(artistic=True)"
    ]
   },
   {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "z5rSDjZbTntY"
+   },
    "source": [
-    "!mkdir \"/content/drive/My Drive/deOldifyImages\"\n",
-    "!mkdir \"/content/drive/My Drive/deOldifyImages/results\""
+    "#◢ Image URL\n",
+    "\n",
+    "Any direct link to an image should do (will end with extension .jpg, .png, etc).  NOTE: If you want to use your own image, upload it first to a site like Imgur. "
    ]
   },
   {
@@ -235,81 +187,74 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis = get_image_colorizer(render_factor=render_factor, artistic=artistic)"
+    "source_url = '' #@param {type:\"string\"}"
    ]
   },
   {
    "cell_type": "markdown",
-   "metadata": {
-    "colab_type": "text",
-    "id": "ZpCf0qbxicVK"
-   },
+   "metadata": {},
    "source": [
-    "Here's an example of colorizing an image downloaded from the internet:"
+    "#◢ Render Factor\n",
+    "\n",
+    "The default value of 35 has been carefully chosen and should work -ok- for most scenarios (but probably won't be the -best-). This determines resolution at which the photo is rendered. Lower resolution will render faster, and colors also tend to look more vibrant.  Older and lower quality photos in particular will generally benefit by lowering the render factor. Higher render factors are often better for higher quality photos."
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "render_factor = 35  #@param {type: \"slider\", min: 7, max: 46}"
+   ]
+  },
+  {
+   "cell_type": "markdown",
    "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/",
-     "height": 204
-    },
-    "colab_type": "code",
-    "id": "8-hQC-AfiCKq",
-    "outputId": "f52e94a1-48bd-4b27-bc21-2306ee85ef9d"
+    "colab_type": "text",
+    "id": "sUQrbSYipiJn"
    },
-   "outputs": [],
    "source": [
-    "!wget \"https://media.githubusercontent.com/media/jantic/DeOldify/master/test_images/TV1930s.jpg\" -O \"family_TV.jpg\""
+    "#◢ Run DeOldify"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/",
-     "height": 434
-    },
-    "colab_type": "code",
-    "id": "O6kfUN0GiJsq",
-    "outputId": "3e37f84d-6d27-4e4e-b7b6-9feae8be40ae"
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image('family_TV.jpg')"
+    "if source_url is not None and source_url !='':\n",
+    "    !wget $source_url -O \"family_TV.jpg\"\n",
+    "    colorizer.plot_transformed_image_from_url(path=\"test_images/image.jpg\", url=source_url, render_factor=render_factor)\n",
+    "else:\n",
+    "    print('Provide an image url and try again.')"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {
     "colab_type": "text",
-    "id": "dBY1N_bYaIxq"
+    "id": "A5WMS_GgP4fm"
    },
    "source": [
-    "If you want to colorize pictures from your Google drive, drop them in a directory named deOldifyImages (in the root of your drive) and the next cell will colorize all of them and save the resulting images in deOldifyImages/results."
+    "#◢ Download\n",
+    "\n",
+    "* In the menu to the left, click **Files**\n",
+    "* If you don't see the 'DeOldify' folder, click \"Refresh\"\n",
+    "* By default, rendered image will be in /DeOldify/result_images/"
    ]
   },
   {
-   "cell_type": "code",
-   "execution_count": null,
+   "cell_type": "markdown",
    "metadata": {
-    "colab": {
-     "base_uri": "https://localhost:8080/",
-     "height": 559
-    },
-    "colab_type": "code",
-    "id": "vL7775DWaFJ4",
-    "outputId": "cdb35172-d43a-4d45-f23c-d86c993aaf9b"
+    "colab_type": "text",
+    "id": "X7Ycv_Y9xAHp"
    },
-   "outputs": [],
    "source": [
-    "for img in os.listdir(\"/content/drive/My Drive/deOldifyImages/\"):\n",
-    "  img_path = str(\"/content/drive/My Drive/deOldifyImages/\") + img\n",
-    "  if os.path.isfile(img_path):\n",
-    "    vis.plot_transformed_image(img_path)"
+    "---\n",
+    "#⚙ Recommended image sources \n",
+    "* [/r/TheWayWeWere](https://www.reddit.com/r/TheWayWeWere/)"
    ]
   }
  ],
@@ -317,8 +262,9 @@
   "accelerator": "GPU",
   "colab": {
    "collapsed_sections": [],
-   "name": "DeOldify_colab.ipynb",
+   "name": "DeOldify-video.ipynb",
    "provenance": [],
+   "toc_visible": true,
    "version": "0.3.2"
   },
   "kernelspec": {

+ 12 - 2
ImageColorizerStable.ipynb

@@ -39,7 +39,17 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis = get_image_colorizer(render_factor=render_factor, artistic=False)"
+    "vis = get_image_colorizer(render_factor=render_factor, artistic=False)\n",
+    "#vis = get_video_colorizer(render_factor=render_factor).vis"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "vis.plot_transformed_image(\"test_images/poolparty.jpg\", render_factor=45)"
    ]
   },
   {
@@ -120,7 +130,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "vis.plot_transformed_image(\"test_images/PushingCart.jpg\", render_factor=30)"
+    "vis.plot_transformed_image(\"test_images/PushingCart.jpg\", render_factor=24)"
    ]
   },
   {

+ 38 - 57
VideoColorizerColab.ipynb

@@ -19,13 +19,14 @@
    "source": [
     "#◢ DeOldify - Not just for photos!\n",
     "\n",
-    "##Use this Colab notebook to colorize black & white videos in three simple steps.\n",
+    "##Use this Colab notebook to colorize black & white videos in four simple steps.\n",
     "1. Specify media URL - YouTube, Twitter, Imgur, etc.\n",
-    "2. Run DeOldify to extract single images from your video or gif. Behind the scenes, the code does the following:\n",
+    "2. Select 'Render Factor'.  Generally, older and lower quality videos will render better with lower render factors (14-21 range) while higher quality videos will do better on higher render factors.\n",
+    "3. Run DeOldify to extract single images from your video or gif. Behind the scenes, the code does the following:\n",
     "    * Extracts single images from the specified media file.\n",
     "    * Processes the images with [DeOldify](https://github.com/jantic/DeOldify).\n",
     "    * Rebuilds the video from **<font color='#CC0000'>c</font><font color='#CC8800'>o</font><font color='#FFBB00'>l</font><font color='#00DD00'>o</font><font color='#0000FF'>r</font>ized** images.\n",
-    "3. Download the video to your device to view! \n",
+    "4. Download the video to your device to view! \n",
     "\n",
     "_FYI: This notebook is intended as a tool to colorize gifs and short videos, if you are trying to convert longer video you may hit the limit on processing space. Running the Jupyter notebook on your own machine is recommended (and faster) for larger video sizes._\n",
     "\n",
@@ -33,13 +34,9 @@
     "\n",
     "####**Credits:**\n",
     "\n",
-    "Use of this tool is thanks to,\n",
+    "Big special thanks to:\n",
     "\n",
-    "[@citnaj](https://twitter.com/citnaj) for creating DeOldify.\n",
-    "\n",
-    "[@tradica](https://twitter.com/tradica) for initial video and CoLab work.\n",
-    "\n",
-    "Matt Robinson for his [notebook](https://colab.research.google.com/github/jantic/DeOldify/blob/master/DeOldify_colab.ipynb) which helped make DeOldify approachable.\n",
+    "Robert Bell for all his work on the video Colab notebook, and paving the way to video in DeOldify!\n",
     "\n",
     "Dana Kelley for doing things, breaking stuff & having an opinion on everything."
    ]
@@ -76,9 +73,7 @@
    "outputs": [],
    "source": [
     "from os import path\n",
-    "import torch\n",
-    "print(torch.__version__)\n",
-    "print(torch.cuda.is_available())"
+    "import torch"
    ]
   },
   {
@@ -153,8 +148,6 @@
     "import fastai\n",
     "from fasterai.visualize import *\n",
     "from pathlib import Path\n",
-    "from google.colab import drive\n",
-    "from google.colab import files\n",
     "torch.backends.cudnn.benchmark=True"
    ]
   },
@@ -169,12 +162,16 @@
    ]
   },
   {
-   "cell_type": "markdown",
+   "cell_type": "code",
+   "execution_count": null,
    "metadata": {},
+   "outputs": [],
    "source": [
-    "#◢ Render Factor\n",
-    "\n",
-    "The default value of 21 has been carefully chosen and should work for most scenarios. This determines resolution at which video is rendered. Lower resolution will render faster, and colors also tend to look more vibrant.  Older and lower quality film in particular will generally benefit by lowering the render factor. Higher render factors are often better for higher quality videos and inconsistencies (flashy render) will generally be reduced, but the colors may get slightly washed out. "
+    "file_name = 'video.mp4'\n",
+    "source_dir = './video/source/'\n",
+    "source_path = source_dir + file_name\n",
+    "dest_dir = './video/result/'\n",
+    "dest_path = dest_dir + file_name"
    ]
   },
   {
@@ -183,28 +180,20 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "render_factor = 21  #@param {type: \"slider\", min: 5, max: 45}"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {
-    "colab_type": "text",
-    "id": "z5rSDjZbTntY"
-   },
-   "source": [
-    "#◢ Specify URL\n",
-    "\n",
-    "YouTube, Imgur, Twitter, Reddit ... files of type .gif, .gifv and .mp4 work.  NOTE: If you want to use your own source material, upload it first to a site like Imgur. "
+    "!mkdir file_dir"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {},
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "tzHVnegp21hC"
+   },
    "outputs": [],
    "source": [
-    "source_url = '' #@param {type:\"string\"}"
+    "colorizer = get_video_colorizer()"
    ]
   },
   {
@@ -214,9 +203,9 @@
     "id": "z5rSDjZbTntY"
    },
    "source": [
-    "#◢ Additional Parameters\n",
+    "#◢ Video URL\n",
     "\n",
-    "It's not necessary to change the following, just run them as-is."
+    "YouTube, Imgur, Twitter, Reddit ... files of type .gif, .gifv and .mp4 work.  NOTE: If you want to use your own video, upload it first to a site like Imgur. "
    ]
   },
   {
@@ -225,11 +214,16 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "file_name = 'video.mp4'\n",
-    "source_dir = './video/source/'\n",
-    "source_path = source_dir + file_name\n",
-    "dest_dir = './video/result/'\n",
-    "dest_path = dest_dir + file_name"
+    "source_url = '' #@param {type:\"string\"}"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#◢ Render Factor\n",
+    "\n",
+    "The default value of 21 has been carefully chosen and should work -ok- for most scenarios (but probably won't be the -best-). This determines resolution at which video is rendered. Lower resolution will render faster, and colors also tend to look more vibrant.  Older and lower quality film in particular will generally benefit by lowering the render factor. Higher render factors are often better for higher quality videos and inconsistencies (flashy render) will generally be reduced, but the colors may get slightly washed out.  "
    ]
   },
   {
@@ -238,7 +232,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "!mkdir file_dir"
+    "render_factor = 21  #@param {type: \"slider\", min: 7, max: 45}"
    ]
   },
   {
@@ -248,20 +242,7 @@
     "id": "sUQrbSYipiJn"
    },
    "source": [
-    "#◢ Run DeOldify"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "colab": {},
-    "colab_type": "code",
-    "id": "tzHVnegp21hC"
-   },
-   "outputs": [],
-   "source": [
-    "colorizer = get_video_colorizer(render_factor=render_factor)"
+    "#◢ Colorize Video"
    ]
   },
   {
@@ -271,9 +252,9 @@
    "outputs": [],
    "source": [
     "if source_url is not None and source_url !='':\n",
-    "    colorizer.colorize_from_url(source_url, file_name)\n",
+    "    colorizer.colorize_from_url(source_url, file_name, render_factor)\n",
     "else:\n",
-    "    print('Provide a source url and try again.')"
+    "    print('Provide a video url and try again.')"
    ]
   },
   {

+ 1 - 1
fasterai/dataset.py

@@ -11,7 +11,7 @@ def get_colorize_data(sz:int, bs:int, crappy_path:Path, good_path:Path, random_s
 
     src = (ImageImageList.from_folder(crappy_path)
         .use_partial_data(sample_pct=keep_pct, seed=random_seed)
-        .random_split_by_pct(0.1, seed=random_seed))
+        .split_by_rand_pct(0.1, seed=random_seed))
 
     data = (src.label_from_func(lambda x: good_path/x.relative_to(crappy_path))
         .transform(get_transforms(max_zoom=1.2, max_lighting=0.5, max_warp=0.25, xtra_tfms=xtra_tfms), size=sz, tfm_y=True)

+ 21 - 0
fasterai/save.py

@@ -0,0 +1,21 @@
+from fastai.torch_core import *
+from fastai.basic_data import DataBunch
+from fastai.callback import *
+from fastai.basic_train import Learner, LearnerCallback
+from fastai.vision.gan import GANLearner
+
+class GANSaveCallback(LearnerCallback):
+    "A `LearnerCallback` that saves history of metrics while training `learn` into CSV `filename`."
+    def __init__(self, learn:GANLearner, learn_gen:Learner, filename:str, save_iters:int=1000): 
+        super().__init__(learn)
+        self.learn_gen, self.filename, self.save_iters = learn_gen, filename, save_iters
+
+
+    def on_batch_end(self, iteration:int, epoch:int, **kwargs)->None:
+        if iteration == 0: return
+        if iteration % self.save_iters == 0: 
+            self._save_gen_learner(iteration=iteration, epoch=epoch)
+
+    def _save_gen_learner(self, iteration:int, epoch:int):
+        fn = self.filename + '_' + str(epoch) + '_' + str(iteration)
+        self.learn_gen.save(fn)

+ 21 - 13
fasterai/visualize.py

@@ -12,6 +12,8 @@ from PIL import Image
 import ffmpeg
 import youtube_dl
 import gc
+import requests
+from io import BytesIO
 
 
 class ModelImageVisualizer():
@@ -27,6 +29,12 @@ class ModelImageVisualizer():
     def _open_pil_image(self, path:Path)->Image:
         return PIL.Image.open(path).convert('RGB')
 
+    def plot_transformed_image_from_url(self, path:str, url:str, figsize:(int,int)=(20,20), render_factor:int=None)->Image:
+        response = requests.get(url)
+        img = Image.open(BytesIO(response.content))
+        img.save(path)
+        return self.plot_transformed_image(path=path, figsize=figsize, render_factor=render_factor)
+
     def plot_transformed_image(self, path:str, figsize:(int,int)=(20,20), render_factor:int=None)->Image:
         path = Path(path)
         result = self.get_transformed_image(path, render_factor)
@@ -101,7 +109,7 @@ class VideoColorizer():
         ffmpeg.input(str(source_path)).output(str(bwframe_path_template), format='image2', vcodec='mjpeg', qscale=0).run(capture_stdout=True)
 
 
-    def _colorize_raw_frames(self, source_path:Path):
+    def _colorize_raw_frames(self, source_path:Path, render_factor:int=None):
         colorframes_folder = self.colorframes_root/(source_path.stem)
         colorframes_folder.mkdir(parents=True, exist_ok=True)
         self._purge_images(colorframes_folder)
@@ -110,7 +118,7 @@ class VideoColorizer():
         for img in progress_bar(os.listdir(str(bwframes_folder))):
             img_path = bwframes_folder/img
             if os.path.isfile(str(img_path)):
-                color_image = self.vis.get_transformed_image(str(img_path))
+                color_image = self.vis.get_transformed_image(str(img_path), render_factor=render_factor)
                 color_image.save(str(colorframes_folder/img))
     
     def _build_video(self, source_path:Path):
@@ -127,49 +135,49 @@ class VideoColorizer():
         
         print('Video created here: ' + str(result_path))
 
-    def colorize_from_url(self, source_url, file_name:str):    
+    def colorize_from_url(self, source_url, file_name:str, render_factor:int=None): 
         source_path =  self.source_folder/file_name
         self._download_video_from_url(source_url, source_path)
-        self._colorize_from_path(source_path)
+        self._colorize_from_path(source_path, render_factor=render_factor)
 
-    def colorize_from_file_name(self, file_name:str):
+    def colorize_from_file_name(self, file_name:str, render_factor:int=None):
         source_path =  self.source_folder/file_name
-        self._colorize_from_path(source_path)
+        self._colorize_from_path(source_path, render_factor=render_factor)
 
-    def _colorize_from_path(self, source_path:Path):
+    def _colorize_from_path(self, source_path:Path, render_factor:int=None):
         if not source_path.exists():
             raise Exception('Video at path specfied, ' + str(source_path) + ' could not be found.')
 
         self._extract_raw_frames(source_path)
-        self._colorize_raw_frames(source_path)
+        self._colorize_raw_frames(source_path, render_factor=render_factor)
         self._build_video(source_path)
 
 
-def get_video_colorizer(render_factor:int=36)->VideoColorizer:
+def get_video_colorizer(render_factor:int=21)->VideoColorizer:
     return get_stable_video_colorizer(render_factor=render_factor)
 
 def get_stable_video_colorizer(root_folder:Path=Path('./'), weights_name:str='ColorizeVideo_gen', 
-        results_dir='result_images', render_factor:int=36)->VideoColorizer:
+        results_dir='result_images', render_factor:int=21)->VideoColorizer:
     learn = gen_inference_wide(root_folder=root_folder, weights_name=weights_name)
     filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
     vis = ModelImageVisualizer(filtr, results_dir=results_dir)
     return VideoColorizer(vis)
 
-def get_image_colorizer(render_factor:int=36, artistic:bool=True)->ModelImageVisualizer:
+def get_image_colorizer(render_factor:int=35, artistic:bool=True)->ModelImageVisualizer:
     if artistic:
         return get_artistic_image_colorizer(render_factor=render_factor)
     else:
         return get_stable_image_colorizer(render_factor=render_factor)
 
 def get_stable_image_colorizer(root_folder:Path=Path('./'), weights_name:str='ColorizeStable_gen', 
-        results_dir='result_images', render_factor:int=36)->ModelImageVisualizer:
+        results_dir='result_images', render_factor:int=35)->ModelImageVisualizer:
     learn = gen_inference_wide(root_folder=root_folder, weights_name=weights_name)
     filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
     vis = ModelImageVisualizer(filtr, results_dir=results_dir)
     return vis
 
 def get_artistic_image_colorizer(root_folder:Path=Path('./'), weights_name:str='ColorizeArtistic_gen', 
-        results_dir='result_images', render_factor:int=36)->ModelImageVisualizer:
+        results_dir='result_images', render_factor:int=35)->ModelImageVisualizer:
     learn = gen_inference_deep(root_folder=root_folder, weights_name=weights_name)
     filtr = MasterFilter([ColorizerFilter(learn=learn)], render_factor=render_factor)
     vis = ModelImageVisualizer(filtr, results_dir=results_dir)

+ 3 - 0
test_images/poolparty.jpg

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d8f067cfc8cdf737d500a193f9f93137c1575121829572fcc3cddff0fd63d4d9
+size 74485