app.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. # import the necessary packages
  2. import os
  3. import sys
  4. import requests
  5. import ssl
  6. from flask import Flask
  7. from flask import request
  8. from flask import jsonify
  9. from flask import send_file
  10. from uuid import uuid4
  11. from os import path
  12. import torch
  13. import fastai
  14. from fasterai.visualize import *
  15. from pathlib import Path
  16. import traceback
  # Module-level setup: device selection, shared colorizer models, Flask app.
# Restrict the process to GPU 0. CUDA reads CUDA_VISIBLE_DEVICES when it is
# first initialised, so this has to run before the colorizers are built
# (they presumably load their weights onto the GPU). Previously this line ran
# after both models were created, where it most likely had no effect.
# NOTE(review): if importing fasterai/torch already initialises CUDA, this
# must move above those imports — confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# Let cuDNN benchmark convolution algorithms for the input sizes it sees.
torch.backends.cudnn.benchmark = True
# Both colorizers are constructed once at import time and shared by every
# request handler below.
image_colorizer = get_image_colorizer(artistic=True)
video_colorizer = get_video_colorizer()
app = Flask(__name__)
  # Image colorization endpoint.
@app.route("/process_image", methods=["POST"])
def process_image():
    """Colorize the image at a remote URL and return the result.

    Expects a JSON body with ``source_url`` (the image URL) and
    ``render_factor`` (any value ``int()`` accepts). The image is fetched by
    ``image_colorizer.plot_transformed_image_from_url`` into ``upload/`` under
    a random ``.png`` name, and the colorized output is read back from
    ``result_images/`` under the same name. Both files are deleted before
    the handler returns.

    Returns:
        The colorized image (``image/png``) on success, otherwise
        ``({'message': 'input error'}, 400)``; the traceback is printed.
    """
    import io  # local import keeps this handler self-contained

    upload_directory = 'upload'
    # Chosen before anything can fail so the cleanup below never sees an
    # unbound name.
    random_filename = str(uuid4()) + '.png'
    upload_path = os.path.join(upload_directory, random_filename)
    result_path = os.path.join("result_images", random_filename)
    try:
        source_url = request.json["source_url"]
        render_factor = int(request.json["render_factor"])
        # exist_ok avoids the check-then-create race between concurrent requests.
        os.makedirs(upload_directory, exist_ok=True)
        image_colorizer.plot_transformed_image_from_url(
            url=source_url,
            path=upload_path,
            figsize=(20, 20),
            render_factor=render_factor,
            display_render_factor=True,
            compare=False,
        )
        # Buffer the result in memory so deleting the file in ``finally``
        # cannot interfere with sending the response body.
        with open(result_path, 'rb') as result_file:
            payload = io.BytesIO(result_file.read())
        # The result keeps the ``.png`` name chosen above, so label it as PNG.
        return send_file(payload, mimetype='image/png')
    except Exception:
        traceback.print_exc()
        # Quoted key: a bare ``message`` would be an undefined name.
        return {'message': 'input error'}, 400
    finally:
        # Either file may be missing when the request failed early.
        for leftover in (result_path, upload_path):
            try:
                os.remove(leftover)
            except FileNotFoundError:
                pass
  # Video colorization endpoint.
@app.route("/process_video", methods=["POST"])
def process_video():
    """Colorize the video at a remote URL and return the result.

    Expects a JSON body with ``source_url`` (the video URL) and
    ``render_factor`` (any value ``int()`` accepts). The video is handed to
    ``video_colorizer.colorize_from_url`` under a random ``.mp4`` name, and
    the output is read from ``video/result/`` under that same name and
    deleted before the handler returns.

    Returns:
        The colorized video as ``application/octet-stream`` on success,
        otherwise ``({'message': 'input error'}, 400)``; the traceback is
        printed.
    """
    upload_directory = 'upload'
    # Chosen before anything can fail so the cleanup below never sees an
    # unbound name.
    random_filename = str(uuid4()) + '.mp4'
    result_path = os.path.join("video/result/", random_filename)
    try:
        source_url = request.json["source_url"]
        render_factor = int(request.json["render_factor"])
        # NOTE(review): nothing in this handler writes to upload/; the
        # directory is still created to keep the original side effect.
        os.makedirs(upload_directory, exist_ok=True)
        # The returned path was never used; the result is read from
        # video/result/ under the same file name, as before.
        video_colorizer.colorize_from_url(source_url, random_filename, render_factor)
        # NOTE(review): the file is deleted in ``finally`` right after this
        # returns. That relies on send_file() having already opened it and on
        # POSIX unlink semantics. Videos are streamed instead of buffered to
        # avoid holding them in memory — confirm for the deployment target.
        return send_file(result_path, mimetype='application/octet-stream')
    except Exception:
        traceback.print_exc()
        # Quoted key: a bare ``message`` would be an undefined name.
        return {'message': 'input error'}, 400
    finally:
        # The result is missing when colorization failed or never started.
        try:
            os.remove(result_path)
        except FileNotFoundError:
            pass
  # Script entry point: Flask's built-in server on every interface, port 5000.
if __name__ == '__main__':
    # One thread per request; all threads share the module-level colorizers.
    host, port = '0.0.0.0', 5000
    app.run(host=host, port=port, threaded=True)