server.rs 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. // Copyright 2019-2024 Tauri Programme within The Commons Conservancy
  2. // SPDX-License-Identifier: Apache-2.0
  3. // SPDX-License-Identifier: MIT
  4. use crate::cli::Args;
  5. use anyhow::Error;
  6. use futures_util::TryFutureExt;
  7. use hyper::header::CONTENT_LENGTH;
  8. use hyper::http::uri::Authority;
  9. use hyper::service::{make_service_fn, service_fn};
  10. use hyper::{Body, Client, Method, Request, Response, Server};
  11. use serde::Deserialize;
  12. use serde_json::{json, Map, Value};
  13. use std::convert::Infallible;
  14. use std::path::PathBuf;
  15. use std::process::Child;
  // HTTP client over hyper's plain `HttpConnector`; this is the client type
  // that `handle` receives and uses to send the proxied request.
  16. type HttpClient = Client<hyper::client::HttpConnector>;
  // Capability key under which clients pass Tauri-specific options; entries
  // with this key are removed and replaced in `map_capabilities`.
  17. const TAURI_OPTIONS: &str = "tauri:options";
  18. #[derive(Debug, Deserialize)]
  19. #[serde(rename_all = "camelCase")]
  20. struct TauriOptions {
  21. application: PathBuf,
  22. #[serde(default)]
  23. args: Vec<String>,
  24. #[cfg(target_os = "windows")]
  25. #[serde(default)]
  26. webview_options: Option<Value>,
  27. }
  28. impl TauriOptions {
  29. #[cfg(target_os = "linux")]
  30. fn into_native_object(self) -> Map<String, Value> {
  31. let mut map = Map::new();
  32. map.insert(
  33. "webkitgtk:browserOptions".into(),
  34. json!({"binary": self.application, "args": self.args}),
  35. );
  36. map
  37. }
  38. #[cfg(target_os = "windows")]
  39. fn into_native_object(self) -> Map<String, Value> {
  40. let mut map = Map::new();
  41. map.insert("ms:edgeChromium".into(), json!(true));
  42. map.insert("browserName".into(), json!("webview2"));
  43. map.insert(
  44. "ms:edgeOptions".into(),
  45. json!({"binary": self.application, "args": self.args, "webviewOptions": self.webview_options}),
  46. );
  47. map
  48. }
  49. }
  50. async fn handle(
  51. client: HttpClient,
  52. mut req: Request<Body>,
  53. args: Args,
  54. ) -> Result<Response<Body>, Error> {
  55. // manipulate a new session to convert options to the native driver format
  56. if let (&Method::POST, "/session") = (req.method(), req.uri().path()) {
  57. let (mut parts, body) = req.into_parts();
  58. // get the body from the future stream and parse it as json
  59. let body = hyper::body::to_bytes(body).await?;
  60. let json: Value = serde_json::from_slice(&body)?;
  61. // manipulate the json to convert from tauri option to native driver options
  62. let json = map_capabilities(json);
  63. // serialize json and update the content-length header to be accurate
  64. let bytes = serde_json::to_vec(&json)?;
  65. parts.headers.insert(CONTENT_LENGTH, bytes.len().into());
  66. req = Request::from_parts(parts, bytes.into());
  67. }
  68. client
  69. .request(forward_to_native_driver(req, args)?)
  70. .err_into()
  71. .await
  72. }
  73. /// Transform the request to a request for the native webdriver server.
  74. fn forward_to_native_driver(mut req: Request<Body>, args: Args) -> Result<Request<Body>, Error> {
  75. let host: Authority = {
  76. let headers = req.headers_mut();
  77. headers.remove("host").expect("hyper request has host")
  78. }
  79. .to_str()?
  80. .parse()?;
  81. let path = req
  82. .uri()
  83. .path_and_query()
  84. .expect("hyper request has uri")
  85. .clone();
  86. let uri = format!(
  87. "http://{}:{}{}",
  88. host.host(),
  89. args.native_port,
  90. path.as_str()
  91. );
  92. let (mut parts, body) = req.into_parts();
  93. parts.uri = uri.parse()?;
  94. Ok(Request::from_parts(parts, body))
  95. }
  96. /// only happy path for now, no errors
  97. fn map_capabilities(mut json: Value) -> Value {
  98. let mut native = None;
  99. if let Some(capabilities) = json.get_mut("capabilities") {
  100. if let Some(always_match) = capabilities.get_mut("alwaysMatch") {
  101. if let Some(always_match) = always_match.as_object_mut() {
  102. if let Some(tauri_options) = always_match.remove(TAURI_OPTIONS) {
  103. if let Ok(options) = serde_json::from_value::<TauriOptions>(tauri_options) {
  104. native = Some(options.into_native_object());
  105. }
  106. }
  107. if let Some(native) = native.clone() {
  108. always_match.extend(native);
  109. }
  110. }
  111. }
  112. }
  113. if let Some(native) = native {
  114. if let Some(desired) = json.get_mut("desiredCapabilities") {
  115. if let Some(desired) = desired.as_object_mut() {
  116. desired.remove(TAURI_OPTIONS);
  117. desired.extend(native);
  118. }
  119. }
  120. }
  121. json
  122. }
  123. #[tokio::main(flavor = "current_thread")]
  124. pub async fn run(args: Args, mut _driver: Child) -> Result<(), Error> {
  125. #[cfg(unix)]
  126. let (signals_handle, signals_task) = {
  127. use futures_util::StreamExt;
  128. use signal_hook::consts::signal::*;
  129. let signals = signal_hook_tokio::Signals::new(&[SIGTERM, SIGINT, SIGQUIT])?;
  130. let signals_handle = signals.handle();
  131. let signals_task = tokio::spawn(async move {
  132. let mut signals = signals.fuse();
  133. while let Some(signal) = signals.next().await {
  134. match signal {
  135. SIGTERM | SIGINT | SIGQUIT => {
  136. _driver
  137. .kill()
  138. .expect("unable to kill native webdriver server");
  139. std::process::exit(0);
  140. }
  141. _ => unreachable!(),
  142. }
  143. }
  144. });
  145. (signals_handle, signals_task)
  146. };
  147. let address = std::net::SocketAddr::from(([127, 0, 0, 1], args.port));
  148. // the client we use to proxy requests to the native webdriver
  149. let client = Client::builder()
  150. .http1_preserve_header_case(true)
  151. .http1_title_case_headers(true)
  152. .retry_canceled_requests(false)
  153. .build_http();
  154. // pass a copy of the client to the http request handler
  155. let service = make_service_fn(move |_| {
  156. let client = client.clone();
  157. let args = args.clone();
  158. async move {
  159. Ok::<_, Infallible>(service_fn(move |request| {
  160. handle(client.clone(), request, args.clone())
  161. }))
  162. }
  163. });
  164. // set up a http1 server that uses the service we just created
  165. Server::bind(&address)
  166. .http1_title_case_headers(true)
  167. .http1_preserve_header_case(true)
  168. .http1_only(true)
  169. .serve(service)
  170. .await?;
  171. #[cfg(unix)]
  172. {
  173. signals_handle.close();
  174. signals_task.await?;
  175. }
  176. Ok(())
  177. }