server.rs 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. // Copyright 2019-2022 Tauri Programme within The Commons Conservancy
  2. // SPDX-License-Identifier: Apache-2.0
  3. // SPDX-License-Identifier: MIT
  4. use crate::cli::Args;
  5. use anyhow::Error;
  6. use futures_util::TryFutureExt;
  7. use hyper::header::CONTENT_LENGTH;
  8. use hyper::http::uri::Authority;
  9. use hyper::service::{make_service_fn, service_fn};
  10. use hyper::{Body, Client, Method, Request, Response, Server};
  11. use serde::Deserialize;
  12. use serde_json::{json, Map, Value};
  13. use std::convert::Infallible;
  14. use std::path::PathBuf;
  15. use std::process::Child;
  16. type HttpClient = Client<hyper::client::HttpConnector>;
  17. const TAURI_OPTIONS: &str = "tauri:options";
  18. #[derive(Debug, Deserialize)]
  19. struct TauriOptions {
  20. application: PathBuf,
  21. #[serde(default)]
  22. args: Vec<String>,
  23. }
  24. impl TauriOptions {
  25. #[cfg(target_os = "linux")]
  26. fn into_native_object(self) -> Map<String, Value> {
  27. let mut map = Map::new();
  28. map.insert(
  29. "webkitgtk:browserOptions".into(),
  30. json!({"binary": self.application, "args": self.args}),
  31. );
  32. map
  33. }
  34. #[cfg(target_os = "windows")]
  35. fn into_native_object(self) -> Map<String, Value> {
  36. let mut map = Map::new();
  37. map.insert("ms:edgeChromium".into(), json!(true));
  38. map.insert("browserName".into(), json!("webview2"));
  39. map.insert(
  40. "ms:edgeOptions".into(),
  41. json!({"binary": self.application, "args": self.args}),
  42. );
  43. map
  44. }
  45. }
  46. async fn handle(
  47. client: HttpClient,
  48. mut req: Request<Body>,
  49. args: Args,
  50. ) -> Result<Response<Body>, Error> {
  51. // manipulate a new session to convert options to the native driver format
  52. if let (&Method::POST, "/session") = (req.method(), req.uri().path()) {
  53. let (mut parts, body) = req.into_parts();
  54. // get the body from the future stream and parse it as json
  55. let body = hyper::body::to_bytes(body).await?;
  56. let json: Value = serde_json::from_slice(&body)?;
  57. // manipulate the json to convert from tauri option to native driver options
  58. let json = map_capabilities(json);
  59. // serialize json and update the content-length header to be accurate
  60. let bytes = serde_json::to_vec(&json)?;
  61. parts.headers.insert(CONTENT_LENGTH, bytes.len().into());
  62. req = Request::from_parts(parts, bytes.into());
  63. }
  64. client
  65. .request(forward_to_native_driver(req, args)?)
  66. .err_into()
  67. .await
  68. }
  69. /// Transform the request to a request for the native webdriver server.
  70. fn forward_to_native_driver(mut req: Request<Body>, args: Args) -> Result<Request<Body>, Error> {
  71. let host: Authority = {
  72. let headers = req.headers_mut();
  73. headers.remove("host").expect("hyper request has host")
  74. }
  75. .to_str()?
  76. .parse()?;
  77. let path = req
  78. .uri()
  79. .path_and_query()
  80. .expect("hyper request has uri")
  81. .clone();
  82. let uri = format!(
  83. "http://{}:{}{}",
  84. host.host(),
  85. args.native_port,
  86. path.as_str()
  87. );
  88. let (mut parts, body) = req.into_parts();
  89. parts.uri = uri.parse()?;
  90. Ok(Request::from_parts(parts, body))
  91. }
  92. /// only happy path for now, no errors
  93. fn map_capabilities(mut json: Value) -> Value {
  94. let mut native = None;
  95. if let Some(capabilities) = json.get_mut("capabilities") {
  96. if let Some(always_match) = capabilities.get_mut("alwaysMatch") {
  97. if let Some(always_match) = always_match.as_object_mut() {
  98. if let Some(tauri_options) = always_match.remove(TAURI_OPTIONS) {
  99. if let Ok(options) = serde_json::from_value::<TauriOptions>(tauri_options) {
  100. native = Some(options.into_native_object());
  101. }
  102. }
  103. if let Some(native) = native.clone() {
  104. always_match.extend(native);
  105. }
  106. }
  107. }
  108. }
  109. if let Some(native) = native {
  110. if let Some(desired) = json.get_mut("desiredCapabilities") {
  111. if let Some(desired) = desired.as_object_mut() {
  112. desired.remove(TAURI_OPTIONS);
  113. desired.extend(native);
  114. }
  115. }
  116. }
  117. json
  118. }
  119. #[tokio::main(flavor = "current_thread")]
  120. pub async fn run(args: Args, mut _driver: Child) -> Result<(), Error> {
  121. #[cfg(unix)]
  122. let (signals_handle, signals_task) = {
  123. use futures_util::StreamExt;
  124. use signal_hook::consts::signal::*;
  125. let signals = signal_hook_tokio::Signals::new(&[SIGTERM, SIGINT, SIGQUIT])?;
  126. let signals_handle = signals.handle();
  127. let signals_task = tokio::spawn(async move {
  128. let mut signals = signals.fuse();
  129. while let Some(signal) = signals.next().await {
  130. match signal {
  131. SIGTERM | SIGINT | SIGQUIT => {
  132. _driver
  133. .kill()
  134. .expect("unable to kill native webdriver server");
  135. std::process::exit(0);
  136. }
  137. _ => unreachable!(),
  138. }
  139. }
  140. });
  141. (signals_handle, signals_task)
  142. };
  143. let address = std::net::SocketAddr::from(([127, 0, 0, 1], args.port));
  144. // the client we use to proxy requests to the native webdriver
  145. let client = Client::builder()
  146. .http1_preserve_header_case(true)
  147. .http1_title_case_headers(true)
  148. .retry_canceled_requests(false)
  149. .build_http();
  150. // pass a copy of the client to the http request handler
  151. let service = make_service_fn(move |_| {
  152. let client = client.clone();
  153. let args = args.clone();
  154. async move {
  155. Ok::<_, Infallible>(service_fn(move |request| {
  156. handle(client.clone(), request, args.clone())
  157. }))
  158. }
  159. });
  160. // set up a http1 server that uses the service we just created
  161. Server::bind(&address)
  162. .http1_title_case_headers(true)
  163. .http1_preserve_header_case(true)
  164. .http1_only(true)
  165. .serve(service)
  166. .await?;
  167. #[cfg(unix)]
  168. {
  169. signals_handle.close();
  170. signals_task.await?;
  171. }
  172. Ok(())
  173. }