瀏覽代碼

fix(tests): start updater server once

Lucas Nogueira 11 月之前
父節點
當前提交
649e01f4e0
共有 1 個文件被更改,包括 96 次插入67 次删除
  1. 96 67
      core/tests/app-updater/tests/update.rs

+ 96 - 67
core/tests/app-updater/tests/update.rs

@@ -17,6 +17,7 @@ use hyper::{
   Body, Method, Request, Response, StatusCode,
 };
 use serde::Serialize;
+use tokio::sync::Mutex;
 use tokio_util::codec::{BytesCodec, FramedRead};
 
 const UPDATER_PRIVATE_KEY: &str = "dW50cnVzdGVkIGNvbW1lbnQ6IHJzaWduIGVuY3J5cHRlZCBzZWNyZXQga2V5ClJXUlRZMEl5dkpDN09RZm5GeVAzc2RuYlNzWVVJelJRQnNIV2JUcGVXZUplWXZXYXpqUUFBQkFBQUFBQUFBQUFBQUlBQUFBQTZrN2RnWGh5dURxSzZiL1ZQSDdNcktiaHRxczQwMXdQelRHbjRNcGVlY1BLMTBxR2dpa3I3dDE1UTVDRDE4MXR4WlQwa1BQaXdxKy9UU2J2QmVSNXhOQWFDeG1GSVllbUNpTGJQRkhhTnROR3I5RmdUZi90OGtvaGhJS1ZTcjdZU0NyYzhQWlQ5cGM9Cg==";
@@ -383,6 +384,13 @@ fn update_app_flow<F: FnOnce(Options<'_>) -> (PathBuf, TauriVersion)>(build_app_
     }
   };
 
+  let updater_state = UpdaterState {
+    target: Default::default(),
+    signature: Default::default(),
+    updater_path: Default::default(),
+  };
+  let (runtime, shutdown_tx) = start_updater_server(updater_state.clone());
+
   for (bundle_target, out_bundle_path) in bundle_paths(&app_root, UPDATE_APP_VERSION) {
     let mut bundle_updater_ext = out_bundle_path
       .extension()
@@ -418,70 +426,10 @@ fn update_app_flow<F: FnOnce(Options<'_>) -> (PathBuf, TauriVersion)>(build_app_
     ));
     std::fs::rename(&out_updater_path, &updater_path).expect("failed to rename bundle");
 
-    let target = target.clone();
-
-    let (tx, rx) = tokio::sync::oneshot::channel::<()>();
-
-    let runtime = tokio::runtime::Runtime::new().unwrap();
-
-    runtime.spawn(async move {
-      // create the updater server
-      let addr = "127.0.0.1:3007".parse().unwrap();
-
-      let make_service = make_service_fn(move |_| {
-        let updater_path = updater_path.clone();
-        let signature = signature.clone();
-        let target = target.clone();
-        async move {
-          Ok::<_, hyper::Error>(service_fn(move |req| {
-            let updater_path = updater_path.clone();
-            let signature = signature.clone();
-            let target = target.clone();
-            async move {
-              match (req.method(), req.uri().path()) {
-                (&Method::GET, "/") => {
-                  let mut platforms = HashMap::new();
-
-                  platforms.insert(
-                    target.clone(),
-                    PlatformUpdate {
-                      signature: signature.clone(),
-                      url: "http://localhost:3007/download",
-                      with_elevated_task: false,
-                    },
-                  );
-                  let body = serde_json::to_vec(&Update {
-                    version: UPDATE_APP_VERSION,
-                    date: time::OffsetDateTime::now_utc()
-                      .format(&time::format_description::well_known::Rfc3339)
-                      .unwrap(),
-                    platforms,
-                  })
-                  .unwrap();
-
-                  Ok(Response::new(hyper::Body::from(body)))
-                }
-                (&Method::GET, "/download") => {
-                  let file = tokio::fs::File::open(&updater_path).await.unwrap();
-                  let stream = FramedRead::new(file, BytesCodec::new());
-                  let body = Body::wrap_stream(stream);
-                  return Ok(Response::new(body));
-                }
-                _ => Response::builder()
-                  .status(StatusCode::NOT_FOUND)
-                  .body("Not Found".into()),
-              }
-            }
-          }))
-        }
-      });
-      let server = hyper::Server::bind(&addr).serve(make_service);
-
-      let graceful = server.with_graceful_shutdown(async {
-        rx.await.ok();
-      });
-
-      graceful.await.unwrap();
+    runtime.block_on(async {
+      *updater_state.target.lock().await = target.clone();
+      *updater_state.signature.lock().await = signature.clone();
+      *updater_state.updater_path.lock().await = updater_path.clone();
     });
 
     let config = Config {
@@ -575,8 +523,89 @@ fn update_app_flow<F: FnOnce(Options<'_>) -> (PathBuf, TauriVersion)>(build_app_
     // force Rust to rebuild the binary so it doesn't conflict with other test runs
     #[cfg(windows)]
     std::fs::remove_file(tauri_v1_fixture_dir.join("target/debug/app-updater.exe")).unwrap();
-
-    // graceful shutdown
-    tx.send(()).unwrap();
   }
+
+  // graceful shutdown
+  shutdown_tx.send(()).unwrap();
+}
+
+#[derive(Clone)]
+struct UpdaterState {
+  target: Arc<Mutex<String>>,
+  signature: Arc<Mutex<String>>,
+  updater_path: Arc<Mutex<PathBuf>>,
+}
+
+fn start_updater_server(
+  state: UpdaterState,
+) -> (tokio::runtime::Runtime, tokio::sync::oneshot::Sender<()>) {
+  let (tx, rx) = tokio::sync::oneshot::channel::<()>();
+
+  let runtime = tokio::runtime::Runtime::new().unwrap();
+
+  runtime.spawn(async move {
+    // create the updater server
+    let addr = "127.0.0.1:3007".parse().unwrap();
+
+    let make_service = make_service_fn(move |_| {
+      let state = state.clone();
+      async move {
+        Ok::<_, hyper::Error>(service_fn(move |req| {
+          let state = state.clone();
+          async move {
+            match (req.method(), req.uri().path()) {
+              (&Method::GET, "/") => {
+                let mut platforms = HashMap::new();
+
+                platforms.insert(
+                  state.target.lock().await.clone(),
+                  PlatformUpdate {
+                    signature: state.signature.lock().await.clone(),
+                    url: "http://localhost:3007/download",
+                    with_elevated_task: false,
+                  },
+                );
+                let body = serde_json::to_vec(&Update {
+                  version: UPDATE_APP_VERSION,
+                  date: time::OffsetDateTime::now_utc()
+                    .format(&time::format_description::well_known::Rfc3339)
+                    .unwrap(),
+                  platforms,
+                })
+                .unwrap();
+
+                Ok(Response::new(hyper::Body::from(body)))
+              }
+              (&Method::GET, "/download") => {
+                println!("downloading updater");
+                let file = tokio::fs::File::open(&*state.updater_path.lock().await)
+                  .await
+                  .unwrap();
+                println!("opened updater file");
+                let stream = FramedRead::new(file, BytesCodec::new());
+                let body = Body::wrap_stream(stream);
+                println!("sending updater response");
+                return Ok(Response::new(body));
+              }
+              _ => Response::builder()
+                .status(StatusCode::NOT_FOUND)
+                .body("Not Found".into()),
+            }
+          }
+        }))
+      }
+    });
+    let server = hyper::Server::bind(&addr).serve(make_service);
+
+    let graceful = server.with_graceful_shutdown(async {
+      println!("received shutdown");
+      rx.await.ok();
+    });
+
+    graceful.await.unwrap();
+
+    println!("done serving updates");
+  });
+
+  (runtime, tx)
 }