app.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. # import the necessary packages
  2. import os
  3. import sys
  4. import requests
  5. import ssl
  6. from flask import Flask
  7. from flask import request
  8. from flask import jsonify
  9. from flask import send_file
  10. from app_utils import download
  11. from app_utils import generate_random_filename
  12. from app_utils import clean_me
  13. from app_utils import clean_all
  14. from app_utils import create_directory
  15. from app_utils import get_model_bin
  16. from app_utils import convertToJPG
  17. from os import path
  18. import torch
  19. import fastai
  20. from fasterai.visualize import *
  21. from pathlib import Path
  22. import traceback
  23. torch.backends.cudnn.benchmark=True
  24. os.environ['CUDA_VISIBLE_DEVICES']='0'
  25. app = Flask(__name__)
  26. # define a predict function as an endpoint
  27. @app.route("/process", methods=["POST"])
  28. def process_image():
  29. input_path = generate_random_filename(upload_directory,"jpeg")
  30. output_path = os.path.join(results_img_directory, os.path.basename(input_path))
  31. try:
  32. url = request.json["source_url"]
  33. render_factor = int(request.json["render_factor"])
  34. download(url, input_path)
  35. try:
  36. image_colorizer.plot_transformed_image(path=input_path, figsize=(20,20),
  37. render_factor=render_factor, display_render_factor=True, compare=False)
  38. except:
  39. convertToJPG(input_path)
  40. image_colorizer.plot_transformed_image(path=input_path, figsize=(20,20),
  41. render_factor=render_factor, display_render_factor=True, compare=False)
  42. callback = send_file(output_path, mimetype='image/jpeg')
  43. return callback, 200
  44. except:
  45. traceback.print_exc()
  46. return {'message': 'input error'}, 400
  47. finally:
  48. pass
  49. clean_all([
  50. input_path,
  51. output_path
  52. ])
  53. if __name__ == '__main__':
  54. global upload_directory
  55. global results_img_directory
  56. global image_colorizer
  57. upload_directory = '/data/upload/'
  58. create_directory(upload_directory)
  59. results_img_directory = '/data/result_images/'
  60. create_directory(results_img_directory)
  61. model_directory = '/data/models/'
  62. create_directory(model_directory)
  63. artistic_model_url = 'https://www.dropbox.com/s/zkehq1uwahhbc2o/ColorizeArtistic_gen.pth?dl=0'
  64. get_model_bin(artistic_model_url, os.path.join(model_directory, 'ColorizeArtistic_gen.pth'))
  65. image_colorizer = get_image_colorizer(artistic=True)
  66. port = 5000
  67. host = '0.0.0.0'
  68. app.run(host=host, port=port, threaded=False)