server.rs 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. // Copyright 2019-2021 Tauri Programme within The Commons Conservancy
  2. // SPDX-License-Identifier: Apache-2.0
  3. // SPDX-License-Identifier: MIT
  4. use crate::cli::Args;
  5. use anyhow::Error;
  6. use futures::TryFutureExt;
  7. use hyper::header::CONTENT_LENGTH;
  8. use hyper::http::uri::Authority;
  9. use hyper::service::{make_service_fn, service_fn};
  10. use hyper::{Body, Client, Method, Request, Response, Server};
  11. use serde::Deserialize;
  12. use serde_json::{json, Map, Value};
  13. use std::convert::Infallible;
  14. use std::path::PathBuf;
  15. use std::process::Child;
  16. type HttpClient = Client<hyper::client::HttpConnector>;
  17. const TAURI_OPTIONS: &str = "tauri:options";
  18. #[derive(Debug, Deserialize)]
  19. struct TauriOptions {
  20. application: PathBuf,
  21. }
  22. impl TauriOptions {
  23. #[cfg(target_os = "linux")]
  24. fn into_native_object(self) -> Map<String, Value> {
  25. let mut map = Map::new();
  26. map.insert(
  27. "webkitgtk:browserOptions".into(),
  28. json!({"binary": self.application}),
  29. );
  30. map
  31. }
  32. #[cfg(target_os = "windows")]
  33. fn into_native_object(self) -> Map<String, Value> {
  34. let mut map = Map::new();
  35. map.insert("ms:edgeChromium".into(), json!(true));
  36. map.insert("browserName".into(), json!("webview2"));
  37. map.insert("ms:edgeOptions".into(), json!({"binary": self.application}));
  38. map
  39. }
  40. }
  41. async fn handle(
  42. client: HttpClient,
  43. mut req: Request<Body>,
  44. args: Args,
  45. ) -> Result<Response<Body>, Error> {
  46. // manipulate a new session to convert options to the native driver format
  47. if let (&Method::POST, "/session") = (req.method(), req.uri().path()) {
  48. let (mut parts, body) = req.into_parts();
  49. // get the body from the future stream and parse it as json
  50. let body = hyper::body::to_bytes(body).await?;
  51. let json: Value = serde_json::from_slice(&body)?;
  52. // manipulate the json to convert from tauri option to native driver options
  53. let json = map_capabilities(json);
  54. // serialize json and update the content-length header to be accurate
  55. let bytes = serde_json::to_vec(&json)?;
  56. parts.headers.insert(CONTENT_LENGTH, bytes.len().into());
  57. req = Request::from_parts(parts, bytes.into());
  58. }
  59. client
  60. .request(forward_to_native_driver(req, args)?)
  61. .err_into()
  62. .await
  63. }
  64. /// Transform the request to a request for the native webdriver server.
  65. fn forward_to_native_driver(mut req: Request<Body>, args: Args) -> Result<Request<Body>, Error> {
  66. let host: Authority = {
  67. let headers = req.headers_mut();
  68. headers.remove("host").expect("hyper request has host")
  69. }
  70. .to_str()?
  71. .parse()?;
  72. let path = req
  73. .uri()
  74. .path_and_query()
  75. .expect("hyper request has uri")
  76. .clone();
  77. let uri = format!(
  78. "http://{}:{}{}",
  79. host.host(),
  80. args.native_port,
  81. path.as_str()
  82. );
  83. let (mut parts, body) = req.into_parts();
  84. parts.uri = uri.parse()?;
  85. Ok(Request::from_parts(parts, body))
  86. }
  87. /// only happy path for now, no errors
  88. fn map_capabilities(mut json: Value) -> Value {
  89. let mut native = None;
  90. if let Some(capabilities) = json.get_mut("capabilities") {
  91. if let Some(always_match) = capabilities.get_mut("alwaysMatch") {
  92. if let Some(always_match) = always_match.as_object_mut() {
  93. if let Some(tauri_options) = always_match.remove(TAURI_OPTIONS) {
  94. if let Ok(options) = serde_json::from_value::<TauriOptions>(tauri_options) {
  95. native = Some(options.into_native_object());
  96. }
  97. }
  98. if let Some(native) = native.clone() {
  99. always_match.extend(native);
  100. }
  101. }
  102. }
  103. }
  104. if let Some(native) = native {
  105. if let Some(desired) = json.get_mut("desiredCapabilities") {
  106. if let Some(desired) = desired.as_object_mut() {
  107. desired.remove(TAURI_OPTIONS);
  108. desired.extend(native);
  109. }
  110. }
  111. }
  112. json
  113. }
  114. #[tokio::main(flavor = "current_thread")]
  115. pub async fn run(args: Args, mut _driver: Child) -> Result<(), Error> {
  116. #[cfg(unix)]
  117. let (signals_handle, signals_task) = {
  118. use futures::StreamExt;
  119. use signal_hook::consts::signal::*;
  120. let signals = signal_hook_tokio::Signals::new(&[SIGTERM, SIGINT, SIGQUIT])?;
  121. let signals_handle = signals.handle();
  122. let signals_task = tokio::spawn(async move {
  123. let mut signals = signals.fuse();
  124. while let Some(signal) = signals.next().await {
  125. match signal {
  126. SIGTERM | SIGINT | SIGQUIT => {
  127. _driver
  128. .kill()
  129. .expect("unable to kill native webdriver server");
  130. std::process::exit(0);
  131. }
  132. _ => unreachable!(),
  133. }
  134. }
  135. });
  136. (signals_handle, signals_task)
  137. };
  138. let address = std::net::SocketAddr::from(([127, 0, 0, 1], args.port));
  139. // the client we use to proxy requests to the native webdriver
  140. let client = Client::builder()
  141. .http1_preserve_header_case(true)
  142. .http1_title_case_headers(true)
  143. .retry_canceled_requests(false)
  144. .build_http();
  145. // pass a copy of the client to the http request handler
  146. let service = make_service_fn(move |_| {
  147. let client = client.clone();
  148. let args = args.clone();
  149. async move {
  150. Ok::<_, Infallible>(service_fn(move |request| {
  151. handle(client.clone(), request, args.clone())
  152. }))
  153. }
  154. });
  155. // set up a http1 server that uses the service we just created
  156. Server::bind(&address)
  157. .http1_title_case_headers(true)
  158. .http1_preserve_header_case(true)
  159. .http1_only(true)
  160. .serve(service)
  161. .await?;
  162. #[cfg(unix)]
  163. {
  164. signals_handle.close();
  165. signals_task.await?;
  166. }
  167. Ok(())
  168. }