protocol.rs 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447
  1. // Copyright 2019-2024 Tauri Programme within The Commons Conservancy
  2. // SPDX-License-Identifier: Apache-2.0
  3. // SPDX-License-Identifier: MIT
  4. use std::{borrow::Cow, sync::Arc};
  5. use crate::{
  6. manager::AppManager,
  7. webview::{InvokeRequest, UriSchemeProtocolHandler},
  8. Runtime,
  9. };
  10. use http::{
  11. header::{ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_ORIGIN, CONTENT_TYPE},
  12. HeaderValue, Method, StatusCode,
  13. };
  14. use super::{CallbackFn, InvokeResponse, InvokeResponseBody};
  15. const TAURI_CALLBACK_HEADER_NAME: &str = "Tauri-Callback";
  16. const TAURI_ERROR_HEADER_NAME: &str = "Tauri-Error";
  17. pub fn message_handler<R: Runtime>(
  18. manager: Arc<AppManager<R>>,
  19. ) -> crate::runtime::webview::WebviewIpcHandler<crate::EventLoopMessage, R> {
  20. Box::new(move |webview, request| handle_ipc_message(request, &manager, &webview.label))
  21. }
  22. pub fn get<R: Runtime>(manager: Arc<AppManager<R>>, label: String) -> UriSchemeProtocolHandler {
  23. Box::new(move |request, responder| {
  24. #[cfg(feature = "tracing")]
  25. let span = tracing::trace_span!(
  26. "ipc::request",
  27. kind = "custom-protocol",
  28. request = tracing::field::Empty
  29. )
  30. .entered();
  31. let manager = manager.clone();
  32. let label = label.clone();
  33. let respond = move |mut response: http::Response<Cow<'static, [u8]>>| {
  34. response
  35. .headers_mut()
  36. .insert(ACCESS_CONTROL_ALLOW_ORIGIN, HeaderValue::from_static("*"));
  37. responder.respond(response);
  38. };
  39. match *request.method() {
  40. Method::POST => {
  41. if let Some(webview) = manager.get_webview(&label) {
  42. match parse_invoke_request(&manager, request) {
  43. Ok(request) => {
  44. #[cfg(feature = "tracing")]
  45. span.record(
  46. "request",
  47. match &request.body {
  48. InvokeBody::Json(j) => serde_json::to_string(j).unwrap(),
  49. InvokeBody::Raw(b) => serde_json::to_string(b).unwrap(),
  50. },
  51. );
  52. #[cfg(feature = "tracing")]
  53. let request_span = tracing::trace_span!("ipc::request::handle", cmd = request.cmd);
  54. webview.on_message(
  55. request,
  56. Box::new(move |_webview, _cmd, response, _callback, _error| {
  57. #[cfg(feature = "tracing")]
  58. let _respond_span = tracing::trace_span!(
  59. parent: &request_span,
  60. "ipc::request::respond"
  61. )
  62. .entered();
  63. #[cfg(feature = "tracing")]
  64. let response_span = tracing::trace_span!(
  65. "ipc::request::response",
  66. response = serde_json::to_string(&response).unwrap(),
  67. mime_type = tracing::field::Empty
  68. )
  69. .entered();
  70. let (mut response, mime_type) = match response {
  71. InvokeResponse::Ok(InvokeResponseBody::Json(v)) => (
  72. http::Response::new(serde_json::to_vec(&v).unwrap().into()),
  73. mime::APPLICATION_JSON,
  74. ),
  75. InvokeResponse::Ok(InvokeResponseBody::Raw(v)) => (
  76. http::Response::new(v.into()),
  77. mime::APPLICATION_OCTET_STREAM,
  78. ),
  79. InvokeResponse::Err(e) => {
  80. let mut response =
  81. http::Response::new(serde_json::to_vec(&e.0).unwrap().into());
  82. *response.status_mut() = StatusCode::BAD_REQUEST;
  83. (response, mime::APPLICATION_JSON)
  84. }
  85. };
  86. #[cfg(feature = "tracing")]
  87. response_span.record("mime_type", mime_type.essence_str());
  88. response.headers_mut().insert(
  89. CONTENT_TYPE,
  90. HeaderValue::from_str(mime_type.essence_str()).unwrap(),
  91. );
  92. respond(response);
  93. }),
  94. );
  95. }
  96. Err(e) => {
  97. respond(
  98. http::Response::builder()
  99. .status(StatusCode::BAD_REQUEST)
  100. .header(CONTENT_TYPE, mime::TEXT_PLAIN.essence_str())
  101. .body(e.as_bytes().to_vec().into())
  102. .unwrap(),
  103. );
  104. }
  105. }
  106. } else {
  107. respond(
  108. http::Response::builder()
  109. .status(StatusCode::BAD_REQUEST)
  110. .header(CONTENT_TYPE, mime::TEXT_PLAIN.essence_str())
  111. .body(
  112. "failed to acquire webview reference"
  113. .as_bytes()
  114. .to_vec()
  115. .into(),
  116. )
  117. .unwrap(),
  118. );
  119. }
  120. }
  121. Method::OPTIONS => {
  122. let mut r = http::Response::new(Vec::new().into());
  123. r.headers_mut().insert(
  124. ACCESS_CONTROL_ALLOW_HEADERS,
  125. HeaderValue::from_static("Content-Type, Tauri-Callback, Tauri-Error, Tauri-Channel-Id"),
  126. );
  127. respond(r);
  128. }
  129. _ => {
  130. let mut r = http::Response::new(
  131. "only POST and OPTIONS are allowed"
  132. .as_bytes()
  133. .to_vec()
  134. .into(),
  135. );
  136. *r.status_mut() = StatusCode::METHOD_NOT_ALLOWED;
  137. r.headers_mut().insert(
  138. CONTENT_TYPE,
  139. HeaderValue::from_str(mime::TEXT_PLAIN.essence_str()).unwrap(),
  140. );
  141. respond(r);
  142. }
  143. }
  144. })
  145. }
/// Handles a raw IPC message string received from the webview labelled `label`.
///
/// The message is deserialized from JSON into `{ cmd, callback, error, payload, options }`.
/// When the `isolation` feature is enabled and the isolation pattern is active, `payload`
/// is first decrypted with the pattern's crypto keys. The resulting [`InvokeRequest`] is
/// dispatched through `webview.on_message`. Its response goes back to the page either by
/// evaluating a callback script or through a `Channel`.
///
/// If deserialization fails, the error is logged to the page with `console.error`.
/// Messages addressed to a label with no matching webview are silently ignored.
fn handle_ipc_message<R: Runtime>(message: String, manager: &AppManager<R>, label: &str) {
  if let Some(webview) = manager.get_webview(label) {
    #[cfg(feature = "tracing")]
    let _span =
      tracing::trace_span!("ipc::request", kind = "post-message", request = message).entered();

    use serde::{Deserialize, Deserializer};

    // Newtype over `http::HeaderMap` so request headers can be deserialized from a JSON
    // object of string pairs. Any invalid header name or value fails deserialization.
    pub(crate) struct HeaderMap(http::HeaderMap);

    impl<'de> Deserialize<'de> for HeaderMap {
      fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
      where
        D: Deserializer<'de>,
      {
        let map = std::collections::HashMap::<String, String>::deserialize(deserializer)?;
        let mut headers = http::HeaderMap::default();
        for (key, value) in map {
          if let (Ok(key), Ok(value)) = (
            http::header::HeaderName::from_bytes(key.as_bytes()),
            http::HeaderValue::from_str(&value),
          ) {
            headers.insert(key, value);
          } else {
            // `key`/`value` here are the original strings (the pattern bindings are
            // not in scope in the `else` branch).
            return Err(serde::de::Error::custom(format!(
              "invalid header `{key}` `{value}`"
            )));
          }
        }
        Ok(Self(headers))
      }
    }

    #[derive(Deserialize)]
    struct RequestOptions {
      headers: HeaderMap,
    }

    // Wire format of a posted IPC message.
    #[derive(Deserialize)]
    struct Message {
      cmd: String,
      callback: CallbackFn,
      error: CallbackFn,
      payload: serde_json::Value,
      options: Option<RequestOptions>,
    }

    // Filled only by the isolation branch below; `mut` is unused when the `isolation`
    // feature is disabled, hence the allow.
    #[allow(unused_mut)]
    let mut invoke_message: Option<crate::Result<Message>> = None;

    #[cfg(feature = "isolation")]
    {
      // Same shape as `Message`, but `payload` is still encrypted.
      #[derive(serde::Deserialize)]
      struct IsolationMessage<'a> {
        cmd: String,
        callback: CallbackFn,
        error: CallbackFn,
        payload: crate::utils::pattern::isolation::RawIsolationPayload<'a>,
        options: Option<RequestOptions>,
      }

      if let crate::Pattern::Isolation { crypto_keys, .. } = &*manager.pattern {
        #[cfg(feature = "tracing")]
        let _span = tracing::trace_span!("ipc::request::decrypt_isolation_payload").entered();

        // Decrypt the payload, then parse the plaintext as JSON; either failure
        // becomes the `Err` of the stored result.
        invoke_message.replace(
          serde_json::from_str::<IsolationMessage<'_>>(&message)
            .map_err(Into::into)
            .and_then(|message| {
              Ok(Message {
                cmd: message.cmd,
                callback: message.callback,
                error: message.error,
                payload: serde_json::from_slice(&crypto_keys.decrypt(message.payload)?)?,
                options: message.options,
              })
            }),
        );
      }
    }

    // Without an isolation result, deserialize the message as plain JSON.
    let message = invoke_message.unwrap_or_else(|| {
      #[cfg(feature = "tracing")]
      let _span = tracing::trace_span!("ipc::request::deserialize").entered();
      serde_json::from_str::<Message>(&message).map_err(Into::into)
    });

    match message {
      Ok(message) => {
        let request = InvokeRequest {
          cmd: message.cmd,
          callback: message.callback,
          error: message.error,
          body: message.payload.into(),
          // Missing `options` means no custom headers.
          headers: message.options.map(|o| o.headers.0).unwrap_or_default(),
        };

        #[cfg(feature = "tracing")]
        let request_span = tracing::trace_span!("ipc::request::handle", cmd = request.cmd);

        webview.on_message(
          request,
          Box::new(move |webview, cmd, response, callback, error| {
            use crate::ipc::{
              format_callback::{
                format as format_callback, format_result as format_callback_result,
              },
              Channel,
            };
            use crate::sealed::ManagerBase;

            #[cfg(feature = "tracing")]
            let _respond_span = tracing::trace_span!(
              parent: &request_span,
              "ipc::request::respond"
            )
            .entered();

            // the channel data command is the only command that uses a custom protocol on Linux
            // NOTE(review): when an `invoke_responder` is configured, this closure does
            // nothing; presumably that responder delivers the response. TODO confirm.
            if webview.manager().webview.invoke_responder.is_none()
              && cmd != crate::ipc::channel::FETCH_CHANNEL_DATA_COMMAND
            {
              // Evaluates `js` in the webview. If building the script failed, evaluates
              // the error callback with the error message instead. Eval errors are ignored.
              fn responder_eval<R: Runtime>(
                webview: &crate::Webview<R>,
                js: crate::Result<String>,
                error: CallbackFn,
              ) {
                let eval_js = match js {
                  Ok(js) => js,
                  Err(e) => format_callback(error, &e.to_string())
                    .expect("unable to serialize response error string to json"),
                };

                let _ = webview.eval(&eval_js);
              }

              #[cfg(feature = "tracing")]
              let _response_span = tracing::trace_span!(
                "ipc::request::response",
                response = serde_json::to_string(&response).unwrap(),
                mime_type = match &response {
                  InvokeResponse::Ok(InvokeResponseBody::Json(_)) => mime::APPLICATION_JSON,
                  InvokeResponse::Ok(InvokeResponseBody::Raw(_)) => mime::APPLICATION_OCTET_STREAM,
                  InvokeResponse::Err(_) => mime::APPLICATION_JSON,
                }
                .essence_str()
              )
              .entered();

              // Delivery strategy:
              // - JSON: through a channel off macOS/iOS when longer than 4000 bytes,
              //   otherwise via an evaluated callback script.
              // - Raw bytes: via an evaluated callback script on macOS/iOS, through a
              //   channel elsewhere.
              // - Errors: always via the evaluated error callback.
              // NOTE(review): `v` in the JSON arm is string-like (see `v.len()`), presumably
              // already-serialized JSON. Confirm that `Channel::send(&v)` and
              // `format_callback_result` do not encode it a second time as a JSON string.
              match response {
                InvokeResponse::Ok(InvokeResponseBody::Json(v)) => {
                  if !(cfg!(target_os = "macos") || cfg!(target_os = "ios")) && v.len() > 4000 {
                    let _ = Channel::from_callback_fn(webview, callback).send(&v);
                  } else {
                    responder_eval(
                      &webview,
                      format_callback_result(Result::<_, ()>::Ok(v), callback, error),
                      error,
                    )
                  }
                }
                InvokeResponse::Ok(InvokeResponseBody::Raw(v)) => {
                  if cfg!(target_os = "macos") || cfg!(target_os = "ios") {
                    responder_eval(
                      &webview,
                      format_callback_result(Result::<_, ()>::Ok(v), callback, error),
                      error,
                    );
                  } else {
                    let _ = Channel::from_callback_fn(webview, callback).send(v);
                  }
                }
                InvokeResponse::Err(e) => responder_eval(
                  &webview,
                  format_callback_result(Result::<(), _>::Err(&e.0), callback, error),
                  error,
                ),
              }
            }
          }),
        );
      }
      Err(e) => {
        #[cfg(feature = "tracing")]
        tracing::trace!("ipc.request.error {}", e);
        // Surface the deserialization error in the page's console; the message is
        // embedded as a JSON string literal.
        let _ = webview.eval(&format!(
          r#"console.error({})"#,
          serde_json::Value::String(e.to_string())
        ));
      }
    }
  }
}
  321. fn parse_invoke_request<R: Runtime>(
  322. #[allow(unused_variables)] manager: &AppManager<R>,
  323. request: http::Request<Vec<u8>>,
  324. ) -> std::result::Result<InvokeRequest, String> {
  325. #[allow(unused_mut)]
  326. let (parts, mut body) = request.into_parts();
  327. // skip leading `/`
  328. let cmd = percent_encoding::percent_decode(parts.uri.path()[1..].as_bytes())
  329. .decode_utf8_lossy()
  330. .to_string();
  331. // on Android and on Linux (without the linux-ipc-protocol Cargo feature) we cannot read the request body
  332. // so we must ignore it because some commands use the IPC for faster response
  333. let has_payload = !body.is_empty();
  334. #[cfg(feature = "isolation")]
  335. if let crate::Pattern::Isolation { crypto_keys, .. } = &*manager.pattern {
  336. // if the platform does not support request body, we ignore it
  337. if has_payload {
  338. #[cfg(feature = "tracing")]
  339. let _span = tracing::trace_span!("ipc::request::decrypt_isolation_payload").entered();
  340. body = crate::utils::pattern::isolation::RawIsolationPayload::try_from(&body)
  341. .and_then(|raw| crypto_keys.decrypt(raw))
  342. .map_err(|e| e.to_string())?;
  343. }
  344. }
  345. let callback = CallbackFn(
  346. parts
  347. .headers
  348. .get(TAURI_CALLBACK_HEADER_NAME)
  349. .ok_or("missing Tauri-Callback header")?
  350. .to_str()
  351. .map_err(|_| "Tauri callback header value must be a string")?
  352. .parse()
  353. .map_err(|_| "Tauri callback header value must be a numeric string")?,
  354. );
  355. let error = CallbackFn(
  356. parts
  357. .headers
  358. .get(TAURI_ERROR_HEADER_NAME)
  359. .ok_or("missing Tauri-Error header")?
  360. .to_str()
  361. .map_err(|_| "Tauri error header value must be a string")?
  362. .parse()
  363. .map_err(|_| "Tauri error header value must be a numeric string")?,
  364. );
  365. let content_type = parts
  366. .headers
  367. .get(reqwest::header::CONTENT_TYPE)
  368. .and_then(|h| h.to_str().ok())
  369. .map(|mime| mime.parse())
  370. .unwrap_or(Ok(mime::APPLICATION_OCTET_STREAM))
  371. .map_err(|_| "unknown content type")?;
  372. #[cfg(feature = "tracing")]
  373. let span = tracing::trace_span!("ipc::request::deserialize").entered();
  374. let body = if content_type == mime::APPLICATION_OCTET_STREAM {
  375. body.into()
  376. } else if content_type == mime::APPLICATION_JSON {
  377. // if the platform does not support request body, we ignore it
  378. if has_payload {
  379. serde_json::from_slice::<serde_json::Value>(&body)
  380. .map_err(|e| e.to_string())?
  381. .into()
  382. } else {
  383. serde_json::Value::Object(Default::default()).into()
  384. }
  385. } else {
  386. return Err(format!("content type {content_type} is not implemented"));
  387. };
  388. #[cfg(feature = "tracing")]
  389. drop(span);
  390. let payload = InvokeRequest {
  391. cmd,
  392. callback,
  393. error,
  394. body,
  395. headers: parts.headers,
  396. };
  397. Ok(payload)
  398. }