protocol.rs 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457
  1. // Copyright 2019-2024 Tauri Programme within The Commons Conservancy
  2. // SPDX-License-Identifier: Apache-2.0
  3. // SPDX-License-Identifier: MIT
  4. use std::{borrow::Cow, sync::Arc};
  5. use crate::{
  6. manager::AppManager,
  7. webview::{InvokeRequest, UriSchemeProtocolHandler},
  8. Runtime,
  9. };
  10. use http::{
  11. header::{ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_ORIGIN, CONTENT_TYPE},
  12. HeaderValue, Method, StatusCode,
  13. };
  14. use super::{CallbackFn, InvokeResponse, InvokeResponseBody};
  /// Request header holding the numeric id of the JS error callback (`InvokeRequest::error`).
  const TAURI_ERROR_HEADER_NAME: &str = "Tauri-Error";
  /// Request header holding the numeric id of the JS success callback (`InvokeRequest::callback`).
  const TAURI_CALLBACK_HEADER_NAME: &str = "Tauri-Callback";
  17. pub fn message_handler<R: Runtime>(
  18. manager: Arc<AppManager<R>>,
  19. ) -> crate::runtime::webview::WebviewIpcHandler<crate::EventLoopMessage, R> {
  20. Box::new(move |webview, request| handle_ipc_message(request, &manager, &webview.label))
  21. }
  22. pub fn get<R: Runtime>(manager: Arc<AppManager<R>>, label: String) -> UriSchemeProtocolHandler {
  23. Box::new(move |request, responder| {
  24. #[cfg(feature = "tracing")]
  25. let span = tracing::trace_span!(
  26. "ipc::request",
  27. kind = "custom-protocol",
  28. request = tracing::field::Empty
  29. )
  30. .entered();
  31. let manager = manager.clone();
  32. let label = label.clone();
  33. let respond = move |mut response: http::Response<Cow<'static, [u8]>>| {
  34. response
  35. .headers_mut()
  36. .insert(ACCESS_CONTROL_ALLOW_ORIGIN, HeaderValue::from_static("*"));
  37. responder.respond(response);
  38. };
  39. match *request.method() {
  40. Method::POST => {
  41. if let Some(webview) = manager.get_webview(&label) {
  42. #[cfg(target_os = "linux")]
  43. let respond = {
  44. let webview_ = webview.clone();
  45. move |response| {
  46. let _ = webview_.run_on_main_thread(move || {
  47. respond(response);
  48. });
  49. }
  50. };
  51. match parse_invoke_request(&manager, request) {
  52. Ok(request) => {
  53. #[cfg(feature = "tracing")]
  54. span.record(
  55. "request",
  56. match &request.body {
  57. InvokeBody::Json(j) => serde_json::to_string(j).unwrap(),
  58. InvokeBody::Raw(b) => serde_json::to_string(b).unwrap(),
  59. },
  60. );
  61. #[cfg(feature = "tracing")]
  62. let request_span = tracing::trace_span!("ipc::request::handle", cmd = request.cmd);
  63. webview.on_message(
  64. request,
  65. Box::new(move |_webview, _cmd, response, _callback, _error| {
  66. #[cfg(feature = "tracing")]
  67. let _respond_span = tracing::trace_span!(
  68. parent: &request_span,
  69. "ipc::request::respond"
  70. )
  71. .entered();
  72. #[cfg(feature = "tracing")]
  73. let response_span = tracing::trace_span!(
  74. "ipc::request::response",
  75. response = serde_json::to_string(&response).unwrap(),
  76. mime_type = tracing::field::Empty
  77. )
  78. .entered();
  79. let (mut response, mime_type) = match response {
  80. InvokeResponse::Ok(InvokeResponseBody::Json(v)) => (
  81. http::Response::new(serde_json::to_vec(&v).unwrap().into()),
  82. mime::APPLICATION_JSON,
  83. ),
  84. InvokeResponse::Ok(InvokeResponseBody::Raw(v)) => (
  85. http::Response::new(v.into()),
  86. mime::APPLICATION_OCTET_STREAM,
  87. ),
  88. InvokeResponse::Err(e) => {
  89. let mut response =
  90. http::Response::new(serde_json::to_vec(&e.0).unwrap().into());
  91. *response.status_mut() = StatusCode::BAD_REQUEST;
  92. (response, mime::APPLICATION_JSON)
  93. }
  94. };
  95. #[cfg(feature = "tracing")]
  96. response_span.record("mime_type", mime_type.essence_str());
  97. response.headers_mut().insert(
  98. CONTENT_TYPE,
  99. HeaderValue::from_str(mime_type.essence_str()).unwrap(),
  100. );
  101. respond(response);
  102. }),
  103. );
  104. }
  105. Err(e) => {
  106. respond(
  107. http::Response::builder()
  108. .status(StatusCode::BAD_REQUEST)
  109. .header(CONTENT_TYPE, mime::TEXT_PLAIN.essence_str())
  110. .body(e.as_bytes().to_vec().into())
  111. .unwrap(),
  112. );
  113. }
  114. }
  115. } else {
  116. respond(
  117. http::Response::builder()
  118. .status(StatusCode::BAD_REQUEST)
  119. .header(CONTENT_TYPE, mime::TEXT_PLAIN.essence_str())
  120. .body(
  121. "failed to acquire webview reference"
  122. .as_bytes()
  123. .to_vec()
  124. .into(),
  125. )
  126. .unwrap(),
  127. );
  128. }
  129. }
  130. Method::OPTIONS => {
  131. let mut r = http::Response::new(Vec::new().into());
  132. r.headers_mut().insert(
  133. ACCESS_CONTROL_ALLOW_HEADERS,
  134. HeaderValue::from_static("Content-Type, Tauri-Callback, Tauri-Error, Tauri-Channel-Id"),
  135. );
  136. respond(r);
  137. }
  138. _ => {
  139. let mut r = http::Response::new(
  140. "only POST and OPTIONS are allowed"
  141. .as_bytes()
  142. .to_vec()
  143. .into(),
  144. );
  145. *r.status_mut() = StatusCode::METHOD_NOT_ALLOWED;
  146. r.headers_mut().insert(
  147. CONTENT_TYPE,
  148. HeaderValue::from_str(mime::TEXT_PLAIN.essence_str()).unwrap(),
  149. );
  150. respond(r);
  151. }
  152. }
  153. })
  154. }
  /// Handles a raw IPC message string posted by the webview identified by `label`.
  ///
  /// The message is deserialized as JSON with `cmd`, `callback`, `error`, `payload` and an
  /// optional `options.headers` object. With the `isolation` feature and the isolation pattern
  /// active, the payload is decrypted first. The resulting `InvokeRequest` is dispatched with
  /// `Webview::on_message`.
  ///
  /// The response is written back to JS here only when the manager has no `invoke_responder`
  /// and the command is not the channel-data fetch command. In that case it is delivered either
  /// by evaluating a callback script or through a `Channel` built from the callback id.
  ///
  /// Deserialization/decryption failures are reported with `console.error` in the webview.
  /// Messages whose `label` has no matching webview are silently ignored.
  fn handle_ipc_message<R: Runtime>(message: String, manager: &AppManager<R>, label: &str) {
    if let Some(webview) = manager.get_webview(label) {
      #[cfg(feature = "tracing")]
      let _span =
        tracing::trace_span!("ipc::request", kind = "post-message", request = message).entered();

      use serde::{Deserialize, Deserializer};

      // Request headers, deserialized from a JSON object of string keys to string values.
      // Any key that is not a valid header name, or value that is not a valid header value,
      // fails the whole deserialization.
      pub(crate) struct HeaderMap(http::HeaderMap);

      impl<'de> Deserialize<'de> for HeaderMap {
        fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
        where
          D: Deserializer<'de>,
        {
          let map = std::collections::HashMap::<String, String>::deserialize(deserializer)?;
          let mut headers = http::HeaderMap::default();
          for (key, value) in map {
            if let (Ok(key), Ok(value)) = (
              http::header::HeaderName::from_bytes(key.as_bytes()),
              http::HeaderValue::from_str(&value),
            ) {
              // NOTE(review): `insert` replaces an earlier entry with the same (case-insensitive)
              // name; which one wins is unspecified since `HashMap` iteration order is arbitrary.
              headers.insert(key, value);
            } else {
              // `key`/`value` here are the original strings, not the failed parse results
              return Err(serde::de::Error::custom(format!(
                "invalid header `{key}` `{value}`"
              )));
            }
          }
          Ok(Self(headers))
        }
      }

      // Optional per-message options; currently only carries request headers.
      #[derive(Deserialize)]
      struct RequestOptions {
        headers: HeaderMap,
      }

      // Wire format of a (non-isolation) IPC message.
      #[derive(Deserialize)]
      struct Message {
        cmd: String,
        callback: CallbackFn,
        error: CallbackFn,
        payload: serde_json::Value,
        options: Option<RequestOptions>,
      }

      // Only assigned by the isolation branch below; `mut` is unused without that feature.
      #[allow(unused_mut)]
      let mut invoke_message: Option<crate::Result<Message>> = None;

      #[cfg(feature = "isolation")]
      {
        // Same shape as `Message`, but the payload is still encrypted.
        #[derive(serde::Deserialize)]
        struct IsolationMessage<'a> {
          cmd: String,
          callback: CallbackFn,
          error: CallbackFn,
          payload: crate::utils::pattern::isolation::RawIsolationPayload<'a>,
          options: Option<RequestOptions>,
        }

        if let crate::Pattern::Isolation { crypto_keys, .. } = &*manager.pattern {
          #[cfg(feature = "tracing")]
          let _span = tracing::trace_span!("ipc::request::decrypt_isolation_payload").entered();

          // decrypt the payload and parse the plaintext as JSON to build a regular `Message`
          invoke_message.replace(
            serde_json::from_str::<IsolationMessage<'_>>(&message)
              .map_err(Into::into)
              .and_then(|message| {
                Ok(Message {
                  cmd: message.cmd,
                  callback: message.callback,
                  error: message.error,
                  payload: serde_json::from_slice(&crypto_keys.decrypt(message.payload)?)?,
                  options: message.options,
                })
              }),
          );
        }
      }

      // without an isolation result, parse the raw string as a plain `Message`
      let message = invoke_message.unwrap_or_else(|| {
        #[cfg(feature = "tracing")]
        let _span = tracing::trace_span!("ipc::request::deserialize").entered();
        serde_json::from_str::<Message>(&message).map_err(Into::into)
      });

      match message {
        Ok(message) => {
          let request = InvokeRequest {
            cmd: message.cmd,
            callback: message.callback,
            error: message.error,
            body: message.payload.into(),
            // no `options` means no request headers
            headers: message.options.map(|o| o.headers.0).unwrap_or_default(),
          };

          #[cfg(feature = "tracing")]
          let request_span = tracing::trace_span!("ipc::request::handle", cmd = request.cmd);

          webview.on_message(
            request,
            Box::new(move |webview, cmd, response, callback, error| {
              use crate::ipc::{
                format_callback::{
                  format as format_callback, format_result as format_callback_result,
                },
                Channel,
              };
              use crate::sealed::ManagerBase;

              #[cfg(feature = "tracing")]
              let _respond_span = tracing::trace_span!(
                parent: &request_span,
                "ipc::request::respond"
              )
              .entered();

              // the channel data command is the only command that uses a custom protocol on Linux
              if webview.manager().webview.invoke_responder.is_none()
                && cmd != crate::ipc::channel::FETCH_CHANNEL_DATA_COMMAND
              {
                // Evaluates `js` in the webview. If building the script failed, evaluates a call
                // to the `error` callback with the error message instead. Eval errors are ignored.
                fn responder_eval<R: Runtime>(
                  webview: &crate::Webview<R>,
                  js: crate::Result<String>,
                  error: CallbackFn,
                ) {
                  let eval_js = match js {
                    Ok(js) => js,
                    Err(e) => format_callback(error, &e.to_string())
                      .expect("unable to serialize response error string to json"),
                  };
                  let _ = webview.eval(&eval_js);
                }

                #[cfg(feature = "tracing")]
                let _response_span = tracing::trace_span!(
                  "ipc::request::response",
                  response = serde_json::to_string(&response).unwrap(),
                  mime_type = match &response {
                    InvokeResponse::Ok(InvokeResponseBody::Json(_)) => mime::APPLICATION_JSON,
                    InvokeResponse::Ok(InvokeResponseBody::Raw(_)) => mime::APPLICATION_OCTET_STREAM,
                    InvokeResponse::Err(_) => mime::APPLICATION_JSON,
                  }
                  .essence_str()
                )
                .entered();

                match response {
                  InvokeResponse::Ok(InvokeResponseBody::Json(v)) => {
                    // outside macOS/iOS, JSON longer than 4000 bytes goes through a channel
                    // instead of being inlined into an eval'd script
                    if !(cfg!(target_os = "macos") || cfg!(target_os = "ios")) && v.len() > 4000 {
                      let _ = Channel::from_callback_fn(webview, callback).send(&v);
                    } else {
                      responder_eval(
                        &webview,
                        format_callback_result(Result::<_, ()>::Ok(v), callback, error),
                        error,
                      )
                    }
                  }
                  InvokeResponse::Ok(InvokeResponseBody::Raw(v)) => {
                    // macOS/iOS inline raw bytes into the eval'd script; elsewhere use a channel
                    if cfg!(target_os = "macos") || cfg!(target_os = "ios") {
                      responder_eval(
                        &webview,
                        format_callback_result(Result::<_, ()>::Ok(v), callback, error),
                        error,
                      );
                    } else {
                      let _ = Channel::from_callback_fn(webview, callback).send(v);
                    }
                  }
                  InvokeResponse::Err(e) => responder_eval(
                    &webview,
                    format_callback_result(Result::<(), _>::Err(&e.0), callback, error),
                    error,
                  ),
                }
              }
            }),
          );
        }
        Err(e) => {
          #[cfg(feature = "tracing")]
          tracing::trace!("ipc.request.error {}", e);
          // The error text is JSON-encoded (`Value::String` displays as a quoted, escaped
          // string), so it is safe to embed as a JS string literal.
          let _ = webview.eval(&format!(
            r#"console.error({})"#,
            serde_json::Value::String(e.to_string())
          ));
        }
      }
    }
  }
  330. fn parse_invoke_request<R: Runtime>(
  331. #[allow(unused_variables)] manager: &AppManager<R>,
  332. request: http::Request<Vec<u8>>,
  333. ) -> std::result::Result<InvokeRequest, String> {
  334. #[allow(unused_mut)]
  335. let (parts, mut body) = request.into_parts();
  336. // skip leading `/`
  337. let cmd = percent_encoding::percent_decode(parts.uri.path()[1..].as_bytes())
  338. .decode_utf8_lossy()
  339. .to_string();
  340. // on Android and on Linux (without the linux-ipc-protocol Cargo feature) we cannot read the request body
  341. // so we must ignore it because some commands use the IPC for faster response
  342. let has_payload = !body.is_empty();
  343. #[cfg(feature = "isolation")]
  344. if let crate::Pattern::Isolation { crypto_keys, .. } = &*manager.pattern {
  345. // if the platform does not support request body, we ignore it
  346. if has_payload {
  347. #[cfg(feature = "tracing")]
  348. let _span = tracing::trace_span!("ipc::request::decrypt_isolation_payload").entered();
  349. body = crate::utils::pattern::isolation::RawIsolationPayload::try_from(&body)
  350. .and_then(|raw| crypto_keys.decrypt(raw))
  351. .map_err(|e| e.to_string())?;
  352. }
  353. }
  354. let callback = CallbackFn(
  355. parts
  356. .headers
  357. .get(TAURI_CALLBACK_HEADER_NAME)
  358. .ok_or("missing Tauri-Callback header")?
  359. .to_str()
  360. .map_err(|_| "Tauri callback header value must be a string")?
  361. .parse()
  362. .map_err(|_| "Tauri callback header value must be a numeric string")?,
  363. );
  364. let error = CallbackFn(
  365. parts
  366. .headers
  367. .get(TAURI_ERROR_HEADER_NAME)
  368. .ok_or("missing Tauri-Error header")?
  369. .to_str()
  370. .map_err(|_| "Tauri error header value must be a string")?
  371. .parse()
  372. .map_err(|_| "Tauri error header value must be a numeric string")?,
  373. );
  374. let content_type = parts
  375. .headers
  376. .get(reqwest::header::CONTENT_TYPE)
  377. .and_then(|h| h.to_str().ok())
  378. .map(|mime| mime.parse())
  379. .unwrap_or(Ok(mime::APPLICATION_OCTET_STREAM))
  380. .map_err(|_| "unknown content type")?;
  381. #[cfg(feature = "tracing")]
  382. let span = tracing::trace_span!("ipc::request::deserialize").entered();
  383. let body = if content_type == mime::APPLICATION_OCTET_STREAM {
  384. body.into()
  385. } else if content_type == mime::APPLICATION_JSON {
  386. // if the platform does not support request body, we ignore it
  387. if has_payload {
  388. serde_json::from_slice::<serde_json::Value>(&body)
  389. .map_err(|e| e.to_string())?
  390. .into()
  391. } else {
  392. serde_json::Value::Object(Default::default()).into()
  393. }
  394. } else {
  395. return Err(format!("content type {content_type} is not implemented"));
  396. };
  397. #[cfg(feature = "tracing")]
  398. drop(span);
  399. let payload = InvokeRequest {
  400. cmd,
  401. callback,
  402. error,
  403. body,
  404. headers: parts.headers,
  405. };
  406. Ok(payload)
  407. }