Adding FID benchmark notebook

Jason Antic 5 years ago
parent
commit
e69d3bbda5
5 changed files with 1082 additions and 3 deletions
  1. .gitignore (+1 -3)
  2. ColorFIDBenchmarkArtistic.ipynb (+266 -0)
  3. fid/LICENSE (+201 -0)
  4. fid/fid_score.py (+299 -0)
  5. fid/inception.py (+315 -0)

+ 1 - 3
.gitignore

@@ -111,9 +111,7 @@ ColorizeTraining*[0-9]*.ipynb
 *Colorizer[0-9]*.ipynb
 lesson7-superres*.ipynb
 test.py
-result_images/*.jpg
-result_images/*.jpeg
-result_images/*.png
+result_images
 
 deoldify/fastai
 

+ 266 - 0
ColorFIDBenchmarkArtistic.ipynb

@@ -0,0 +1,266 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Color FID Benchmark (HQ)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "os.environ['CUDA_VISIBLE_DEVICES']='3'\n",
+    "os.environ['OMP_NUM_THREADS']='1'"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import statistics\n",
+    "from fastai import *\n",
+    "from deoldify.visualize import *\n",
+    "import cv2\n",
+    "from fid.fid_score import *\n",
+    "from fid.inception import *\n",
+    "import imageio\n",
+    "plt.style.use('dark_background')\n",
+    "torch.backends.cudnn.benchmark=True\n",
+    "import warnings\n",
+    "warnings.filterwarnings(\"ignore\", category=UserWarning, module=\"torch.nn.functional\")\n",
+    "warnings.filterwarnings(\"ignore\", category=UserWarning, message='.*?retrieve source code for container of type.*?')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Setup"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#NOTE:  Data should come from here:  'https://datasets.figure-eight.com/figure_eight_datasets/open-images/test_challenge.zip'\n",
+    "#NOTE:  Minimum recommmended number of samples is 10K.  Source:  https://github.com/bioinf-jku/TTUR\n",
+    "\n",
+    "path = Path('data/ColorBenchmark')\n",
+    "path_hr = path/'source'\n",
+    "path_lr = path/'bandw'\n",
+    "path_results = Path('./result_images/ColorBenchmarkFID/artistic')\n",
+    "path_rendered = path_results/'rendered'\n",
+    "\n",
+    "#path = Path('data/DeOldifyColor')\n",
+    "#path_hr = path\n",
+    "#path_lr = path/'bandw'\n",
+    "#path_results = Path('./result_images/ColorBenchmark/edge')\n",
+    "#path_rendered = path_results/'rendered'\n",
+    "\n",
+    "#num_images = 2048\n",
+    "num_images = 15000\n",
+    "#num_images = 50000\n",
+    "render_factor=35\n",
+    "fid_batch_size = 4\n",
+    "eval_size=299"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def inception_model(dims:int):\n",
+    "    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]\n",
+    "    model = InceptionV3([block_idx])\n",
+    "    model.cuda()\n",
+    "    return model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def create_before_images(fn,i):\n",
+    "    dest = path_lr/fn.relative_to(path_hr)\n",
+    "    dest.parent.mkdir(parents=True, exist_ok=True)\n",
+    "    img = PIL.Image.open(fn).convert('LA').convert('RGB')\n",
+    "    img.save(dest)  "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def render_images(colorizer, source_dir:Path, filtered_dir:Path, target_dir:Path, render_factor:int, num_images:int)->[(Path, Path, Path)]:\n",
+    "    results = []\n",
+    "    bandw_list = ImageList.from_folder(path_lr)\n",
+    "    bandw_list = bandw_list[:num_images]\n",
+    "\n",
+    "    if len(bandw_list.items) == 0: return results\n",
+    "\n",
+    "    results = []\n",
+    "    img_iterator = progress_bar(bandw_list.items)\n",
+    "\n",
+    "    for bandw_path in img_iterator:\n",
+    "        target_path = target_dir/bandw_path.relative_to(source_dir)\n",
+    "\n",
+    "        try:\n",
+    "            result_image = colorizer.get_transformed_image(path=bandw_path, render_factor=render_factor)\n",
+    "            result_path = Path(str(path_results) + '/' + bandw_path.parent.name + '/' + bandw_path.name)\n",
+    "            if not result_path.parent.exists():\n",
+    "                result_path.parent.mkdir(parents=True, exist_ok=True)\n",
+    "            result_image.save(result_path)\n",
+    "            results.append((result_path, bandw_path, target_path))\n",
+    "        except Exception as err:\n",
+    "            print('Failed to render image.  Skipping.  Details: {0}'.format(err))\n",
+    "    \n",
+    "    return results "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def calculate_fid_score(render_results, bs:int, eval_size:int):\n",
+    "    dims = 2048\n",
+    "    cuda = True\n",
+    "    model = inception_model(dims=dims)\n",
+    "    rendered_paths = []\n",
+    "    target_paths = []\n",
+    "    \n",
+    "    for render_result in render_results:\n",
+    "        rendered_path, _, target_path = render_result\n",
+    "        rendered_paths.append(str(rendered_path))\n",
+    "        target_paths.append(str(target_path))\n",
+    "        \n",
+    "    rendered_m, rendered_s = calculate_activation_statistics(files=rendered_paths, model=model, batch_size=bs, dims=dims, cuda=cuda)\n",
+    "    target_m, target_s = calculate_activation_statistics(files=target_paths, model=model, batch_size=bs, dims=dims, cuda=cuda)\n",
+    "    fid_score = calculate_frechet_distance(rendered_m, rendered_s, target_m, target_s)\n",
+    "    del model\n",
+    "    return fid_score"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Create black and whites source images"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Only runs if the directory isn't already created."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "if not path_lr.exists():\n",
+    "    il = ImageList.from_folder(path_hr)\n",
+    "    parallel(create_before_images, il.items)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "path_results.parent.mkdir(parents=True, exist_ok=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Rendering"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "colorizer = get_image_colorizer(artistic=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "render_results = render_images(colorizer=colorizer, source_dir=path_lr, target_dir=path_hr, filtered_dir=path_results, render_factor=render_factor, num_images=num_images)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Colorizaton Scoring"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "fid_score = calculate_fid_score(render_results, bs=fid_batch_size, eval_size=eval_size)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print('FID Score: ' + str(fid_score))"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.0"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
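
The notebook above wires the benchmark together interactively: convert the source set to black and white, render each image with the artistic colorizer, then score the rendered set against the originals. Below is a minimal sketch of the same flow as a plain script, assuming the DeOldify repo layout, flat directories of .jpg inputs, and the fid package added in this commit (the directory names are illustrative, not the notebook's exact paths):

from pathlib import Path

from deoldify.visualize import get_image_colorizer
from fid.fid_score import calculate_fid_given_paths

bandw_dir = Path('data/ColorBenchmark/bandw')      # black-and-white inputs
source_dir = Path('data/ColorBenchmark/source')    # original color images
rendered_dir = Path('result_images/fid_sketch')    # colorized outputs
rendered_dir.mkdir(parents=True, exist_ok=True)

colorizer = get_image_colorizer(artistic=True)

# Render every black-and-white image at the notebook's render_factor.
for bandw_path in sorted(bandw_dir.glob('*.jpg')):
    result = colorizer.get_transformed_image(path=bandw_path, render_factor=35)
    result.save(rendered_dir / bandw_path.name)

# Compare the rendered outputs against the original color images.
fid_value = calculate_fid_given_paths(
    [str(rendered_dir), str(source_dir)],
    batch_size=4,
    cuda=True,
    dims=2048,
)
print('FID:', fid_value)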

+ 201 - 0
fid/LICENSE

@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!)  The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.

+ 299 - 0
fid/fid_score.py

@@ -0,0 +1,299 @@
+# Code adapted and modified from https://github.com/mseitzer/pytorch-fid.  Licensing
+# and description duplicated below.
+
+#!/usr/bin/env python3
+"""Calculates the Frechet Inception Distance (FID) to evalulate GANs
+
+The FID metric calculates the distance between two distributions of images.
+Typically, we have summary statistics (mean & covariance matrix) of one
+of these distributions, while the 2nd distribution is given by a GAN.
+
+When run as a stand-alone program, it compares the distribution of
+images that are stored as PNG/JPEG at a specified location with a
+distribution given by summary statistics (in pickle format).
+
+The FID is calculated by assuming that X_1 and X_2 are the activations of
+the pool_3 layer of the inception net for generated samples and real world
+samples respectively.
+
+See --help to see further details.
+
+Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
+of Tensorflow
+
+Copyright 2018 Institute of Bioinformatics, JKU Linz
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import os
+import pathlib
+from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
+
+import numpy as np
+import torch
+from scipy import linalg
+from torch.nn.functional import adaptive_avg_pool2d
+import cv2
+import imageio
+
+try:
+    from tqdm import tqdm
+except ImportError:
+    # If tqdm is not available, provide a mock version of it
+    def tqdm(x):
+        return x
+
+
+from .inception import InceptionV3
+
+parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
+parser.add_argument(
+    'path',
+    type=str,
+    nargs=2,
+    help=('Path to the generated images or ' 'to .npz statistic files'),
+)
+parser.add_argument('--batch-size', type=int, default=50, help='Batch size to use')
+parser.add_argument(
+    '--dims',
+    type=int,
+    default=2048,
+    choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
+    help=(
+        'Dimensionality of Inception features to use. '
+        'By default, uses pool3 features'
+    ),
+)
+parser.add_argument(
+    '-c', '--gpu', default='', type=str, help='GPU to use (leave blank for CPU only)'
+)
+
+
+def load_image_resized(fn, sz):
+    return cv2.resize(
+        imageio.imread(str(fn)), dsize=(sz, sz), interpolation=cv2.INTER_CUBIC
+    ).astype(np.float32)
+
+
+def get_activations(
+    files,
+    model,
+    batch_size=50,
+    dims=2048,
+    cuda=False,
+    verbose=False,
+    eval_size: int = 299,
+):
+    """Calculates the activations of the pool_3 layer for all images.
+
+    Params:
+    -- files       : List of image files paths
+    -- model       : Instance of inception model
+    -- batch_size  : Batch size of images for the model to process at once.
+                     Make sure that the number of samples is a multiple of
+                     the batch size, otherwise some samples are ignored. This
+                     behavior is retained to match the original FID score
+                     implementation.
+    -- dims        : Dimensionality of features returned by Inception
+    -- cuda        : If set to True, use GPU
+    -- verbose     : If set to True and parameter out_step is given, the number
+                     of calculated batches is reported.
+    Returns:
+    -- A numpy array of dimension (num images, dims) that contains the
+       activations of the given tensor when feeding inception with the
+       query tensor.
+    """
+    model.eval()
+
+    if len(files) % batch_size != 0:
+        print(
+            (
+                'Warning: number of images is not a multiple of the '
+                'batch size. Some samples are going to be ignored.'
+            )
+        )
+    if batch_size > len(files):
+        print(
+            (
+                'Warning: batch size is bigger than the data size. '
+                'Setting batch size to data size'
+            )
+        )
+        batch_size = len(files)
+
+    n_batches = len(files) // batch_size
+    n_used_imgs = n_batches * batch_size
+
+    pred_arr = np.empty((n_used_imgs, dims))
+
+    for i in tqdm(range(n_batches)):
+        if verbose:
+            print('\rPropagating batch %d/%d' % (i + 1, n_batches), end='', flush=True)
+        start = i * batch_size
+        end = start + batch_size
+
+        images = np.array(
+            [load_image_resized(fn, eval_size) for fn in files[start:end]]
+        )
+        # images = np.array([imageio.imread(str(f)).astype(np.float32)
+        # for f in files[start:end]])
+
+        # Reshape to (n_images, 3, height, width)
+        images = images.transpose((0, 3, 1, 2))
+        images /= 255
+
+        batch = torch.from_numpy(images).type(torch.FloatTensor)
+        if cuda:
+            batch = batch.cuda()
+
+        pred = model(batch)[0]
+
+        # If model output is not scalar, apply global spatial average pooling.
+        # This happens if you choose a dimensionality not equal 2048.
+        if pred.shape[2] != 1 or pred.shape[3] != 1:
+            pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
+
+        pred_arr[start:end] = pred.cpu().data.numpy().reshape(batch_size, -1)
+
+    if verbose:
+        print(' done')
+
+    return pred_arr
+
+
+def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
+    """Numpy implementation of the Frechet Distance.
+    The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
+    and X_2 ~ N(mu_2, C_2) is
+            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
+
+    Stable version by Dougal J. Sutherland.
+
+    Params:
+    -- mu1   : Numpy array containing the activations of a layer of the
+               inception net (as returned by the function 'get_predictions')
+               for generated samples.
+    -- mu2   : The sample mean over activations, precalculated on a
+               representative data set.
+    -- sigma1: The covariance matrix over activations for generated samples.
+    -- sigma2: The covariance matrix over activations, precalculated on a
+               representative data set.
+
+    Returns:
+    --   : The Frechet Distance.
+    """
+
+    mu1 = np.atleast_1d(mu1)
+    mu2 = np.atleast_1d(mu2)
+
+    sigma1 = np.atleast_2d(sigma1)
+    sigma2 = np.atleast_2d(sigma2)
+
+    assert (
+        mu1.shape == mu2.shape
+    ), 'Training and test mean vectors have different lengths'
+    assert (
+        sigma1.shape == sigma2.shape
+    ), 'Training and test covariances have different dimensions'
+
+    diff = mu1 - mu2
+
+    # Product might be almost singular
+    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
+    if not np.isfinite(covmean).all():
+        msg = (
+            'fid calculation produces singular product; '
+            'adding %s to diagonal of cov estimates'
+        ) % eps
+        print(msg)
+        offset = np.eye(sigma1.shape[0]) * eps
+        covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
+
+    # Numerical error might give slight imaginary component
+    if np.iscomplexobj(covmean):
+        if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
+            m = np.max(np.abs(covmean.imag))
+            raise ValueError('Imaginary component {}'.format(m))
+        covmean = covmean.real
+
+    tr_covmean = np.trace(covmean)
+
+    return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
+
+
+def calculate_activation_statistics(
+    files, model, batch_size=50, dims=2048, cuda=False, verbose=False
+):
+    """Calculation of the statistics used by the FID.
+    Params:
+    -- files       : List of image files paths
+    -- model       : Instance of inception model
+    -- batch_size  : The images numpy array is split into batches with
+                     batch size batch_size. A reasonable batch size
+                     depends on the hardware.
+    -- dims        : Dimensionality of features returned by Inception
+    -- cuda        : If set to True, use GPU
+    -- verbose     : If set to True and parameter out_step is given, the
+                     number of calculated batches is reported.
+    Returns:
+    -- mu    : The mean over samples of the activations of the pool_3 layer of
+               the inception model.
+    -- sigma : The covariance matrix of the activations of the pool_3 layer of
+               the inception model.
+    """
+    act = get_activations(files, model, batch_size, dims, cuda, verbose)
+    mu = np.mean(act, axis=0)
+    sigma = np.cov(act, rowvar=False)
+    return mu, sigma
+
+
+def _compute_statistics_of_path(path, model, batch_size, dims, cuda):
+    if path.endswith('.npz'):
+        f = np.load(path)
+        m, s = f['mu'][:], f['sigma'][:]
+        f.close()
+    else:
+        path = pathlib.Path(path)
+        files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
+        m, s = calculate_activation_statistics(files, model, batch_size, dims, cuda)
+
+    return m, s
+
+
+def calculate_fid_given_paths(paths, batch_size, cuda, dims):
+    """Calculates the FID of two paths"""
+    for p in paths:
+        if not os.path.exists(p):
+            raise RuntimeError('Invalid path: %s' % p)
+
+    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
+
+    model = InceptionV3([block_idx])
+    if cuda:
+        model.cuda()
+
+    m1, s1 = _compute_statistics_of_path(paths[0], model, batch_size, dims, cuda)
+    m2, s2 = _compute_statistics_of_path(paths[1], model, batch_size, dims, cuda)
+    fid_value = calculate_frechet_distance(m1, s1, m2, s2)
+
+    return fid_value
+
+
+if __name__ == '__main__':
+    args = parser.parse_args()
+    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
+
+    fid_value = calculate_fid_given_paths(
+        args.path, args.batch_size, args.gpu != '', args.dims
+    )
+    print('FID: ', fid_value)
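
fid_score.py can also be used outside the notebook. Its docstring describes stand-alone use (roughly: python -m fid.fid_score <rendered_dir> <reference_dir> -c 0, run from the repo root), and calculate_fid_given_paths accepts either an image directory or a cached .npz of summary statistics for each side. A hedged sketch of caching the reference statistics once, so repeated benchmark runs only need to re-score the rendered images (paths and file names are illustrative):

import pathlib

import numpy as np

from fid.fid_score import calculate_activation_statistics, calculate_fid_given_paths
from fid.inception import InceptionV3

dims = 2048
model = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[dims]]).cuda()

# One-off pass over the ground-truth color images; cache mu/sigma for reuse.
source_files = sorted(
    str(p) for p in pathlib.Path('data/ColorBenchmark/source').glob('*.jpg')
)
mu, sigma = calculate_activation_statistics(
    files=source_files, model=model, batch_size=4, dims=dims, cuda=True
)
np.savez('source_stats.npz', mu=mu, sigma=sigma)

# Later runs can pass the cached .npz in place of the reference image directory.
fid_value = calculate_fid_given_paths(
    ['result_images/rendered_flat', 'source_stats.npz'],
    batch_size=4, cuda=True, dims=dims,
)
print('FID:', fid_value)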

+ 315 - 0
fid/inception.py

@@ -0,0 +1,315 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision import models
+
+try:
+    from torchvision.models.utils import load_state_dict_from_url
+except ImportError:
+    from torch.utils.model_zoo import load_url as load_state_dict_from_url
+
+# Inception weights ported to Pytorch from
+# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
+FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
+
+
+class InceptionV3(nn.Module):
+    """Pretrained InceptionV3 network returning feature maps"""
+
+    # Index of default block of inception to return,
+    # corresponds to output of final average pooling
+    DEFAULT_BLOCK_INDEX = 3
+
+    # Maps feature dimensionality to their output blocks indices
+    BLOCK_INDEX_BY_DIM = {
+        64: 0,  # First max pooling features
+        192: 1,  # Second max pooling features
+        768: 2,  # Pre-aux classifier features
+        2048: 3,  # Final average pooling features
+    }
+
+    def __init__(
+        self,
+        output_blocks=[DEFAULT_BLOCK_INDEX],
+        resize_input=True,
+        normalize_input=True,
+        requires_grad=False,
+        use_fid_inception=True,
+    ):
+        """Build pretrained InceptionV3
+
+        Parameters
+        ----------
+        output_blocks : list of int
+            Indices of blocks to return features of. Possible values are:
+                - 0: corresponds to output of first max pooling
+                - 1: corresponds to output of second max pooling
+                - 2: corresponds to output which is fed to aux classifier
+                - 3: corresponds to output of final average pooling
+        resize_input : bool
+            If true, bilinearly resizes input to width and height 299 before
+            feeding input to model. As the network without fully connected
+            layers is fully convolutional, it should be able to handle inputs
+            of arbitrary size, so resizing might not be strictly needed
+        normalize_input : bool
+            If true, scales the input from range (0, 1) to the range the
+            pretrained Inception network expects, namely (-1, 1)
+        requires_grad : bool
+            If true, parameters of the model require gradients. Possibly useful
+            for finetuning the network
+        use_fid_inception : bool
+            If true, uses the pretrained Inception model used in Tensorflow's
+            FID implementation. If false, uses the pretrained Inception model
+            available in torchvision. The FID Inception model has different
+            weights and a slightly different structure from torchvision's
+            Inception model. If you want to compute FID scores, you are
+            strongly advised to set this parameter to true to get comparable
+            results.
+        """
+        super(InceptionV3, self).__init__()
+
+        self.resize_input = resize_input
+        self.normalize_input = normalize_input
+        self.output_blocks = sorted(output_blocks)
+        self.last_needed_block = max(output_blocks)
+
+        assert self.last_needed_block <= 3, 'Last possible output block index is 3'
+
+        self.blocks = nn.ModuleList()
+
+        if use_fid_inception:
+            inception = fid_inception_v3()
+        else:
+            inception = models.inception_v3(pretrained=True)
+
+        # Block 0: input to maxpool1
+        block0 = [
+            inception.Conv2d_1a_3x3,
+            inception.Conv2d_2a_3x3,
+            inception.Conv2d_2b_3x3,
+            nn.MaxPool2d(kernel_size=3, stride=2),
+        ]
+        self.blocks.append(nn.Sequential(*block0))
+
+        # Block 1: maxpool1 to maxpool2
+        if self.last_needed_block >= 1:
+            block1 = [
+                inception.Conv2d_3b_1x1,
+                inception.Conv2d_4a_3x3,
+                nn.MaxPool2d(kernel_size=3, stride=2),
+            ]
+            self.blocks.append(nn.Sequential(*block1))
+
+        # Block 2: maxpool2 to aux classifier
+        if self.last_needed_block >= 2:
+            block2 = [
+                inception.Mixed_5b,
+                inception.Mixed_5c,
+                inception.Mixed_5d,
+                inception.Mixed_6a,
+                inception.Mixed_6b,
+                inception.Mixed_6c,
+                inception.Mixed_6d,
+                inception.Mixed_6e,
+            ]
+            self.blocks.append(nn.Sequential(*block2))
+
+        # Block 3: aux classifier to final avgpool
+        if self.last_needed_block >= 3:
+            block3 = [
+                inception.Mixed_7a,
+                inception.Mixed_7b,
+                inception.Mixed_7c,
+                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
+            ]
+            self.blocks.append(nn.Sequential(*block3))
+
+        for param in self.parameters():
+            param.requires_grad = requires_grad
+
+    def forward(self, inp):
+        """Get Inception feature maps
+
+        Parameters
+        ----------
+        inp : torch.autograd.Variable
+            Input tensor of shape Bx3xHxW. Values are expected to be in
+            range (0, 1)
+
+        Returns
+        -------
+        List of torch.autograd.Variable, corresponding to the selected output
+        block, sorted ascending by index
+        """
+        outp = []
+        x = inp
+
+        if self.resize_input:
+            x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)
+
+        if self.normalize_input:
+            x = 2 * x - 1  # Scale from range (0, 1) to range (-1, 1)
+
+        for idx, block in enumerate(self.blocks):
+            x = block(x)
+            if idx in self.output_blocks:
+                outp.append(x)
+
+            if idx == self.last_needed_block:
+                break
+
+        return outp
+
+
+def fid_inception_v3():
+    """Build pretrained Inception model for FID computation
+
+    The Inception model for FID computation uses a different set of weights
+    and has a slightly different structure than torchvision's Inception.
+
+    This method first constructs torchvision's Inception and then patches the
+    necessary parts that are different in the FID Inception model.
+    """
+    inception = models.inception_v3(
+        num_classes=1008, aux_logits=False, pretrained=False
+    )
+    inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
+    inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
+    inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
+    inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
+    inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
+    inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
+    inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
+    inception.Mixed_7b = FIDInceptionE_1(1280)
+    inception.Mixed_7c = FIDInceptionE_2(2048)
+
+    state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
+    inception.load_state_dict(state_dict)
+    return inception
+
+
+class FIDInceptionA(models.inception.InceptionA):
+    """InceptionA block patched for FID computation"""
+
+    def __init__(self, in_channels, pool_features):
+        super(FIDInceptionA, self).__init__(in_channels, pool_features)
+
+    def forward(self, x):
+        branch1x1 = self.branch1x1(x)
+
+        branch5x5 = self.branch5x5_1(x)
+        branch5x5 = self.branch5x5_2(branch5x5)
+
+        branch3x3dbl = self.branch3x3dbl_1(x)
+        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
+
+        # Patch: Tensorflow's average pool does not use the padded zeros in
+        # its average calculation
+        branch_pool = F.avg_pool2d(
+            x, kernel_size=3, stride=1, padding=1, count_include_pad=False
+        )
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
+        return torch.cat(outputs, 1)
+
+
+class FIDInceptionC(models.inception.InceptionC):
+    """InceptionC block patched for FID computation"""
+
+    def __init__(self, in_channels, channels_7x7):
+        super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
+
+    def forward(self, x):
+        branch1x1 = self.branch1x1(x)
+
+        branch7x7 = self.branch7x7_1(x)
+        branch7x7 = self.branch7x7_2(branch7x7)
+        branch7x7 = self.branch7x7_3(branch7x7)
+
+        branch7x7dbl = self.branch7x7dbl_1(x)
+        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
+        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
+        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
+        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
+
+        # Patch: Tensorflow's average pool does not use the padded zeros in
+        # its average calculation
+        branch_pool = F.avg_pool2d(
+            x, kernel_size=3, stride=1, padding=1, count_include_pad=False
+        )
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
+        return torch.cat(outputs, 1)
+
+
+class FIDInceptionE_1(models.inception.InceptionE):
+    """First InceptionE block patched for FID computation"""
+
+    def __init__(self, in_channels):
+        super(FIDInceptionE_1, self).__init__(in_channels)
+
+    def forward(self, x):
+        branch1x1 = self.branch1x1(x)
+
+        branch3x3 = self.branch3x3_1(x)
+        branch3x3 = [
+            self.branch3x3_2a(branch3x3),
+            self.branch3x3_2b(branch3x3),
+        ]
+        branch3x3 = torch.cat(branch3x3, 1)
+
+        branch3x3dbl = self.branch3x3dbl_1(x)
+        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+        branch3x3dbl = [
+            self.branch3x3dbl_3a(branch3x3dbl),
+            self.branch3x3dbl_3b(branch3x3dbl),
+        ]
+        branch3x3dbl = torch.cat(branch3x3dbl, 1)
+
+        # Patch: Tensorflow's average pool does not use the padded zeros in
+        # its average calculation
+        branch_pool = F.avg_pool2d(
+            x, kernel_size=3, stride=1, padding=1, count_include_pad=False
+        )
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
+        return torch.cat(outputs, 1)
+
+
+class FIDInceptionE_2(models.inception.InceptionE):
+    """Second InceptionE block patched for FID computation"""
+
+    def __init__(self, in_channels):
+        super(FIDInceptionE_2, self).__init__(in_channels)
+
+    def forward(self, x):
+        branch1x1 = self.branch1x1(x)
+
+        branch3x3 = self.branch3x3_1(x)
+        branch3x3 = [
+            self.branch3x3_2a(branch3x3),
+            self.branch3x3_2b(branch3x3),
+        ]
+        branch3x3 = torch.cat(branch3x3, 1)
+
+        branch3x3dbl = self.branch3x3dbl_1(x)
+        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+        branch3x3dbl = [
+            self.branch3x3dbl_3a(branch3x3dbl),
+            self.branch3x3dbl_3b(branch3x3dbl),
+        ]
+        branch3x3dbl = torch.cat(branch3x3dbl, 1)
+
+        # Patch: The FID Inception model uses max pooling instead of average
+        # pooling. This is likely an error in this specific Inception
+        # implementation, as other Inception models use average pooling here
+        # (which matches the description in the paper).
+        branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
+        return torch.cat(outputs, 1)
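
The InceptionV3 wrapper's docstring spells out the block indices and input conventions: inputs in (0, 1), optional bilinear resize to 299x299, and the FID-specific ported weights. A minimal sketch of using it directly as a feature extractor, independent of fid_score.py (the random batch and shapes are illustrative; constructing the model downloads the ported weights on first use):

import torch

from fid.inception import InceptionV3

dims = 2048
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]   # 3 -> final average pooling
model = InceptionV3([block_idx])
model.eval()

batch = torch.rand(4, 3, 256, 256)                 # values in (0, 1), any HxW
with torch.no_grad():
    features = model(batch)[0]                     # shape (4, 2048, 1, 1)

activations = features.reshape(batch.shape[0], -1).numpy()  # shape (4, 2048)
print(activations.shape)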