app.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. # import the necessary packages
  2. import os
  3. import sys
  4. import requests
  5. import ssl
  6. from flask import Flask
  7. from flask import request
  8. from flask import jsonify
  9. from flask import send_file
  10. from uuid import uuid4
  11. from os import path
  12. import torch
  13. import fastai
  14. from fasterai.visualize import *
  15. from pathlib import Path
  16. torch.backends.cudnn.benchmark=True
  17. image_colorizer = get_image_colorizer(artistic=True)
  18. video_colorizer = get_video_colorizer()
  19. os.environ['CUDA_VISIBLE_DEVICES']='0'
  20. app = Flask(__name__)
  21. # define a predict function as an endpoint
  22. @app.route("/process_image", methods=["POST"])
  23. def process_image():
  24. source_url = request.json["source_url"]
  25. render_factor = int(request.json["render_factor"])
  26. upload_directory = 'upload'
  27. if not os.path.exists(upload_directory):
  28. os.mkdir(upload_directory)
  29. random_filename = str(uuid4()) + '.png'
  30. image_colorizer.plot_transformed_image_from_url(url=source_url, path=os.path.join(upload_directory, random_filename), figsize=(20,20),
  31. render_factor=render_factor, display_render_factor=True, compare=False)
  32. callback = send_file(os.path.join("result_images", random_filename), mimetype='image/jpeg')
  33. os.remove(os.path.join("result_images", random_filename))
  34. os.remove(os.path.join("upload", random_filename))
  35. return callback
  36. @app.route("/process_video", methods=["POST"])
  37. def process_video():
  38. source_url = request.json["source_url"]
  39. render_factor = int(request.json["render_factor"])
  40. upload_directory = 'upload'
  41. if not os.path.exists(upload_directory):
  42. os.mkdir(upload_directory)
  43. random_filename = str(uuid4()) + '.mp4'
  44. video_path = video_colorizer.colorize_from_url(source_url, random_filename, render_factor)
  45. callback = send_file(os.path.join("video/result/", random_filename), mimetype='application/octet-stream')
  46. os.remove(os.path.join("video/result/", random_filename))
  47. return callback
  48. if __name__ == '__main__':
  49. port = 5000
  50. host = '0.0.0.0'
  51. app.run(host=host, port=port, threaded=True)