Преглед на файлове

fix(core): rewrite `asset` protocol streaming, closes #6375 (#6390)

Co-authored-by: Lucas Nogueira <lucas@tauri.studio>
Amr Bashir преди 2 години
родител
ревизия
45330e3819
променени са 5 файла, в които са добавени 358 реда и са изтрити 252 реда
  1. 5 0
      .changes/core-asset-protocol-streaming-crash.md
  2. 217 0
      core/tauri/src/asset_protocol.rs
  3. 2 1
      core/tauri/src/lib.rs
  4. 5 195
      core/tauri/src/manager.rs
  5. 129 56
      examples/streaming/main.rs

+ 5 - 0
.changes/core-asset-protocol-streaming-crash.md

@@ -0,0 +1,5 @@
+---
+'tauri': 'patch:enhance'
+---
+
+Enhance the `asset` protocol to support streaming of large files.

+ 217 - 0
core/tauri/src/asset_protocol.rs

@@ -0,0 +1,217 @@
+// Copyright 2019-2023 Tauri Programme within The Commons Conservancy
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-License-Identifier: MIT
+
+#![cfg(protocol_asset)]
+
+use crate::api::file::SafePathBuf;
+use crate::scope::FsScope;
+use rand::RngCore;
+use std::io::SeekFrom;
+use tauri_runtime::http::HttpRange;
+use tauri_runtime::http::{
+  header::*, status::StatusCode, MimeType, Request, Response, ResponseBuilder,
+};
+use tauri_utils::debug_eprintln;
+use tokio::fs::File;
+use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt};
+use url::Position;
+use url::Url;
+
+/// Serves a file from disk in response to a request made to the `asset` protocol.
+///
+/// The file path is everything after `asset://localhost/` in the request URI
+/// (query and fragment excluded), percent-decoded.
+///
+/// * A `403` response with an empty body is returned when the path is rejected
+///   by [`SafePathBuf`] or is not allowed by `scope`.
+/// * When the request carries a `range` header, the requested bytes are returned
+///   with a `206 Partial Content` status. A single range is returned as a plain
+///   body, several ranges as a `multipart/byteranges` body. Every range is capped
+///   to `MAX_LEN` bytes. A range header that cannot be parsed or satisfied yields
+///   `416 Range Not Satisfiable`.
+/// * Without a `range` header the whole file is returned.
+///
+/// `window_origin` is sent back as the `Access-Control-Allow-Origin` header of
+/// the file responses.
+///
+/// # Errors
+///
+/// Returns an error when the request URI cannot be parsed, when opening, seeking
+/// or reading the file fails, or when the response cannot be built.
+pub fn asset_protocol_handler(
+  request: &Request,
+  scope: FsScope,
+  window_origin: String,
+) -> Result<Response, Box<dyn std::error::Error>> {
+  let parsed_path = Url::parse(request.uri())?;
+  // drop the query string and the fragment, keep everything up to the path
+  let filtered_path = &parsed_path[..Position::AfterPath];
+  let path = filtered_path
+    .strip_prefix("asset://localhost/")
+    // the `strip_prefix` only returns None when a request is made to `https://tauri.$P` on Windows
+    // where `$P` is not `localhost/*`
+    .unwrap_or("");
+  let path = percent_encoding::percent_decode(path.as_bytes())
+    .decode_utf8_lossy()
+    .to_string();
+
+  if let Err(e) = SafePathBuf::new(path.clone().into()) {
+    debug_eprintln!("asset protocol path \"{}\" is not valid: {}", path, e);
+    return ResponseBuilder::new().status(403).body(Vec::new());
+  }
+
+  if !scope.is_allowed(&path) {
+    debug_eprintln!("asset protocol not configured to allow the path: {}", path);
+    return ResponseBuilder::new().status(403).body(Vec::new());
+  }
+
+  let mut resp = ResponseBuilder::new().header("Access-Control-Allow-Origin", &window_origin);
+
+  crate::async_runtime::block_on(async move {
+    let mut file = File::open(&path).await?;
+
+    // get file length by seeking to the end, then restore the previous position
+    let len = {
+      let old_pos = file.stream_position().await?;
+      let len = file.seek(SeekFrom::End(0)).await?;
+      file.seek(SeekFrom::Start(old_pos)).await?;
+      len
+    };
+
+    // get file mime type from (at most) the first 8192 bytes of the file
+    let (mime_type, read_bytes) = {
+      let nbytes = len.min(8192);
+      let mut magic_buf = Vec::with_capacity(nbytes as usize);
+      let old_pos = file.stream_position().await?;
+      (&mut file).take(nbytes).read_to_end(&mut magic_buf).await?;
+      file.seek(SeekFrom::Start(old_pos)).await?;
+      (
+        MimeType::parse(&magic_buf, &path),
+        // return the `magic_bytes` if we read the whole file
+        // to avoid reading it again later if this is not a range request
+        if len < 8192 { Some(magic_buf) } else { None },
+      )
+    };
+
+    resp = resp.header(CONTENT_TYPE, &mime_type);
+
+    // handle 206 (partial range) http requests
+    let response = if let Some(range_header) = request
+      .headers()
+      .get("range")
+      .and_then(|r| r.to_str().map(|r| r.to_string()).ok())
+    {
+      resp = resp.header(ACCEPT_RANGES, "bytes");
+
+      // builds the 416 response, advertising the total length of the file
+      let not_satisfiable = || {
+        ResponseBuilder::new()
+          .status(StatusCode::RANGE_NOT_SATISFIABLE)
+          .header(CONTENT_RANGE, format!("bytes */{len}"))
+          .body(vec![])
+      };
+
+      // parse range header
+      let ranges = if let Ok(ranges) = HttpRange::parse(&range_header, len) {
+        ranges
+          .iter()
+          // map the output to spec range <start-end>, example: 0-499
+          .map(|r| (r.start, r.start + r.length - 1))
+          .collect::<Vec<_>>()
+      } else {
+        return not_satisfiable();
+      };
+
+      /// The Maximum bytes we send in one range
+      const MAX_LEN: u64 = 1000 * 1024;
+
+      // single-part range header
+      if ranges.len() == 1 {
+        let &(start, mut end) = ranges.first().unwrap();
+
+        // check if a range is not satisfiable
+        //
+        // this should be already taken care of by the range parsing library
+        // but checking here again for extra assurance
+        if start >= len || end >= len || end < start {
+          return not_satisfiable();
+        }
+
+        // adjust end byte for MAX_LEN
+        end = start + (end - start).min(len - start).min(MAX_LEN - 1);
+
+        // calculate number of bytes needed to be read
+        let nbytes = end + 1 - start;
+
+        let mut buf = Vec::with_capacity(nbytes as usize);
+        file.seek(SeekFrom::Start(start)).await?;
+        file.take(nbytes).read_to_end(&mut buf).await?;
+
+        resp = resp.header(CONTENT_RANGE, format!("bytes {start}-{end}/{len}"));
+        resp = resp.header(CONTENT_LENGTH, end + 1 - start);
+        resp = resp.status(StatusCode::PARTIAL_CONTENT);
+        resp.body(buf)
+      } else {
+        // multi-part range header
+        let mut buf = Vec::new();
+        let ranges = ranges
+          .iter()
+          .filter_map(|&(start, mut end)| {
+            // filter out unsatisfiable ranges
+            //
+            // this should be already taken care of by the range parsing library
+            // but checking here again for extra assurance
+            if start >= len || end >= len || end < start {
+              None
+            } else {
+              // adjust end byte for MAX_LEN
+              end = start + (end - start).min(len - start).min(MAX_LEN - 1);
+              Some((start, end))
+            }
+          })
+          .collect::<Vec<_>>();
+
+        let boundary = random_boundary();
+        let boundary_sep = format!("\r\n--{boundary}\r\n");
+        // the closing delimiter carries a trailing `--` to mark the end of the multipart body
+        let boundary_closer = format!("\r\n--{boundary}--\r\n");
+
+        // NOTE(review): CONTENT_TYPE was already set to the file mime type above;
+        // confirm whether `ResponseBuilder::header` replaces or appends the value.
+        resp = resp.header(
+          CONTENT_TYPE,
+          format!("multipart/byteranges; boundary={boundary}"),
+        );
+        // a multipart/byteranges payload is a partial response as well
+        resp = resp.status(StatusCode::PARTIAL_CONTENT);
+
+        // the tuples are `(start, end)`, both inclusive
+        for (start, end) in ranges {
+          // a new range is being written, write the range boundary
+          buf.write_all(boundary_sep.as_bytes()).await?;
+
+          // write the needed headers `Content-Type` and `Content-Range`
+          buf
+            .write_all(format!("{CONTENT_TYPE}: {mime_type}\r\n").as_bytes())
+            .await?;
+          buf
+            .write_all(format!("{CONTENT_RANGE}: bytes {start}-{end}/{len}\r\n").as_bytes())
+            .await?;
+
+          // write the separator to indicate the start of the range body
+          buf.write_all("\r\n".as_bytes()).await?;
+
+          // calculate number of bytes needed to be read
+          let nbytes = end + 1 - start;
+
+          let mut local_buf = Vec::with_capacity(nbytes as usize);
+          file.seek(SeekFrom::Start(start)).await?;
+          (&mut file).take(nbytes).read_to_end(&mut local_buf).await?;
+          buf.extend_from_slice(&local_buf);
+        }
+        // all ranges have been written, write the closing boundary
+        buf.write_all(boundary_closer.as_bytes()).await?;
+
+        resp.body(buf)
+      }
+    } else {
+      // avoid reading the file if we already read it
+      // as part of mime type detection
+      let buf = if let Some(b) = read_bytes {
+        b
+      } else {
+        let mut local_buf = Vec::with_capacity(len as usize);
+        file.read_to_end(&mut local_buf).await?;
+        local_buf
+      };
+      resp = resp.header(CONTENT_LENGTH, len);
+      resp.body(buf)
+    };
+
+    response
+  })
+}
+
+/// Generates a random string used as the boundary of a `multipart/byteranges` body.
+///
+/// 30 random bytes are drawn from `rand::thread_rng()` and each of them is
+/// appended as lowercase hexadecimal, without zero padding.
+fn random_boundary() -> String {
+  let mut bytes = [0_u8; 30];
+  rand::thread_rng().fill_bytes(&mut bytes);
+  bytes
+    .iter()
+    .map(|byte| format!("{byte:x}"))
+    .collect::<String>()
+}

+ 2 - 1
core/tauri/src/lib.rs

@@ -184,13 +184,14 @@ mod pattern;
 pub mod plugin;
 pub mod window;
 use tauri_runtime as runtime;
+#[cfg(protocol_asset)]
+mod asset_protocol;
 /// The allowlist scopes.
 pub mod scope;
 mod state;
 #[cfg(updater)]
 #[cfg_attr(doc_cfg, doc(cfg(feature = "updater")))]
 pub mod updater;
-
 pub use tauri_utils as utils;
 
 /// A Tauri [`Runtime`] wrapper around wry.

+ 5 - 195
core/tauri/src/manager.rs

@@ -498,203 +498,13 @@ impl<R: Runtime> WindowManager<R> {
 
     #[cfg(protocol_asset)]
     if !registered_scheme_protocols.contains(&"asset".into()) {
-      use crate::api::file::SafePathBuf;
-      use tokio::io::{AsyncReadExt, AsyncSeekExt};
-      use url::Position;
       let asset_scope = self.state().get::<crate::Scopes>().asset_protocol.clone();
       pending.register_uri_scheme_protocol("asset", move |request| {
-        let parsed_path = Url::parse(request.uri())?;
-        let filtered_path = &parsed_path[..Position::AfterPath];
-        let path = filtered_path
-          .strip_prefix("asset://localhost/")
-          // the `strip_prefix` only returns None when a request is made to `https://tauri.$P` on Windows
-          // where `$P` is not `localhost/*`
-          .unwrap_or("");
-        let path = percent_encoding::percent_decode(path.as_bytes())
-          .decode_utf8_lossy()
-          .to_string();
-
-        if let Err(e) = SafePathBuf::new(path.clone().into()) {
-          debug_eprintln!("asset protocol path \"{}\" is not valid: {}", path, e);
-          return HttpResponseBuilder::new().status(403).body(Vec::new());
-        }
-
-        if !asset_scope.is_allowed(&path) {
-          debug_eprintln!("asset protocol not configured to allow the path: {}", path);
-          return HttpResponseBuilder::new().status(403).body(Vec::new());
-        }
-
-        let path_ = path.clone();
-
-        let mut response =
-          HttpResponseBuilder::new().header("Access-Control-Allow-Origin", &window_origin);
-
-        // handle 206 (partial range) http request
-        if let Some(range) = request
-          .headers()
-          .get("range")
-          .and_then(|r| r.to_str().map(|r| r.to_string()).ok())
-        {
-          #[derive(Default)]
-          struct RangeMetadata {
-            file: Option<tokio::fs::File>,
-            range: Option<crate::runtime::http::HttpRange>,
-            metadata: Option<std::fs::Metadata>,
-            headers: HashMap<&'static str, String>,
-            status_code: u16,
-            body: Vec<u8>,
-          }
-
-          let mut range_metadata = crate::async_runtime::safe_block_on(async move {
-            let mut data = RangeMetadata::default();
-            // open the file
-            let mut file = match tokio::fs::File::open(path_.clone()).await {
-              Ok(file) => file,
-              Err(e) => {
-                debug_eprintln!("Failed to open asset: {}", e);
-                data.status_code = 404;
-                return data;
-              }
-            };
-            // Get the file size
-            let file_size = match file.metadata().await {
-              Ok(metadata) => {
-                let len = metadata.len();
-                data.metadata.replace(metadata);
-                len
-              }
-              Err(e) => {
-                debug_eprintln!("Failed to read asset metadata: {}", e);
-                data.file.replace(file);
-                data.status_code = 404;
-                return data;
-              }
-            };
-            // parse the range
-            let range = match crate::runtime::http::HttpRange::parse(
-              &if range.ends_with("-*") {
-                range.chars().take(range.len() - 1).collect::<String>()
-              } else {
-                range.clone()
-              },
-              file_size,
-            ) {
-              Ok(r) => r,
-              Err(e) => {
-                debug_eprintln!("Failed to parse range {}: {:?}", range, e);
-                data.file.replace(file);
-                data.status_code = 400;
-                return data;
-              }
-            };
-
-            // FIXME: Support multiple ranges
-            // let support only 1 range for now
-            if let Some(range) = range.first() {
-              data.range.replace(*range);
-              let mut real_length = range.length;
-              // prevent max_length;
-              // specially on webview2
-              if range.length > file_size / 3 {
-                // max size sent (400ko / request)
-                // as it's local file system we can afford to read more often
-                real_length = std::cmp::min(file_size - range.start, 1024 * 400);
-              }
-
-              // last byte we are reading, the length of the range include the last byte
-              // who should be skipped on the header
-              let last_byte = range.start + real_length - 1;
-
-              data.headers.insert("Connection", "Keep-Alive".into());
-              data.headers.insert("Accept-Ranges", "bytes".into());
-              data
-                .headers
-                .insert("Content-Length", real_length.to_string());
-              data.headers.insert(
-                "Content-Range",
-                format!("bytes {}-{last_byte}/{file_size}", range.start),
-              );
-
-              if let Err(e) = file.seek(std::io::SeekFrom::Start(range.start)).await {
-                debug_eprintln!("Failed to seek file to {}: {}", range.start, e);
-                data.file.replace(file);
-                data.status_code = 422;
-                return data;
-              }
-
-              let mut f = file.take(real_length);
-              let r = f.read_to_end(&mut data.body).await;
-              file = f.into_inner();
-              data.file.replace(file);
-
-              if let Err(e) = r {
-                debug_eprintln!("Failed read file: {}", e);
-                data.status_code = 422;
-                return data;
-              }
-              // partial content
-              data.status_code = 206;
-            } else {
-              data.status_code = 200;
-            }
-
-            data
-          });
-
-          for (k, v) in range_metadata.headers {
-            response = response.header(k, v);
-          }
-
-          let mime_type = if let (Some(mut file), Some(metadata), Some(range)) = (
-            range_metadata.file,
-            range_metadata.metadata,
-            range_metadata.range,
-          ) {
-            // if we're already reading the beginning of the file, we do not need to re-read it
-            if range.start == 0 {
-              MimeType::parse(&range_metadata.body, &path)
-            } else {
-              let (status, bytes) = crate::async_runtime::safe_block_on(async move {
-                let mut status = None;
-                if let Err(e) = file.rewind().await {
-                  debug_eprintln!("Failed to rewind file: {}", e);
-                  status.replace(422);
-                  (status, Vec::with_capacity(0))
-                } else {
-                  // taken from https://docs.rs/infer/0.9.0/src/infer/lib.rs.html#240-251
-                  let limit = std::cmp::min(metadata.len(), 8192) as usize + 1;
-                  let mut bytes = Vec::with_capacity(limit);
-                  if let Err(e) = file.take(8192).read_to_end(&mut bytes).await {
-                    debug_eprintln!("Failed read file: {}", e);
-                    status.replace(422);
-                  }
-                  (status, bytes)
-                }
-              });
-              if let Some(s) = status {
-                range_metadata.status_code = s;
-              }
-              MimeType::parse(&bytes, &path)
-            }
-          } else {
-            MimeType::parse(&range_metadata.body, &path)
-          };
-          response
-            .mimetype(&mime_type)
-            .status(range_metadata.status_code)
-            .body(range_metadata.body)
-        } else {
-          match crate::async_runtime::safe_block_on(async move { tokio::fs::read(path_).await }) {
-            Ok(data) => {
-              let mime_type = MimeType::parse(&data, &path);
-              response.mimetype(&mime_type).body(data)
-            }
-            Err(e) => {
-              debug_eprintln!("Failed to read file: {}", e);
-              response.status(404).body(Vec::new())
-            }
-          }
-        }
+        crate::asset_protocol::asset_protocol_handler(
+          request,
+          asset_scope.clone(),
+          window_origin.clone(),
+        )
       });
     }
 

+ 129 - 56
examples/streaming/main.rs

@@ -4,14 +4,15 @@
 
 #![cfg_attr(not(debug_assertions), windows_subsystem = "windows")]
 
+use std::sync::{Arc, Mutex};
+
 fn main() {
   use std::{
-    cmp::min,
-    io::{Read, Seek, SeekFrom},
+    io::{Read, Seek, SeekFrom, Write},
     path::PathBuf,
     process::{Command, Stdio},
   };
-  use tauri::http::{HttpRange, ResponseBuilder};
+  use tauri::http::{header::*, status::StatusCode, HttpRange, ResponseBuilder};
 
   let video_file = PathBuf::from("test_video.mp4");
   let video_url =
@@ -35,77 +36,149 @@ fn main() {
     assert!(video_file.exists());
   }
 
+  // NOTE: for production use `rand` crate to generate a random boundary
+  let boundary_id = Arc::new(Mutex::new(0));
+
   tauri::Builder::default()
     .invoke_handler(tauri::generate_handler![video_uri])
     .register_uri_scheme_protocol("stream", move |_app, request| {
-      // prepare our response
-      let mut response = ResponseBuilder::new();
       // get the file path
       let path = request.uri().strip_prefix("stream://localhost/").unwrap();
       let path = percent_encoding::percent_decode(path.as_bytes())
         .decode_utf8_lossy()
         .to_string();
 
-      if path != "example/test_video.mp4" {
-        // return error 404 if it's not out video
-        return response.mimetype("text/plain").status(404).body(Vec::new());
+      if path != "test_video.mp4" {
+        // return error 404 if it's not our video
+        return ResponseBuilder::new().status(404).body(Vec::new());
       }
 
-      // read our file
-      let mut content = std::fs::File::open(&video_file)?;
-      let mut buf = Vec::new();
+      let mut file = std::fs::File::open(&path)?;
 
-      // default status code
-      let mut status_code = 200;
+      // get file length
+      let len = {
+        let old_pos = file.stream_position()?;
+        let len = file.seek(SeekFrom::End(0))?;
+        file.seek(SeekFrom::Start(old_pos))?;
+        len
+      };
+
+      let mut resp = ResponseBuilder::new().header(CONTENT_TYPE, "video/mp4");
 
       // if the webview sent a range header, we need to send a 206 in return
       // Actually only macOS and Windows are supported. Linux will ALWAYS return empty headers.
-      if let Some(range) = request.headers().get("range") {
-        // Get the file size
-        let file_size = content.metadata().unwrap().len();
-
-        // we parse the range header with tauri helper
-        let range = HttpRange::parse(range.to_str().unwrap(), file_size).unwrap();
-        // let support only 1 range for now
-        let first_range = range.first();
-        if let Some(range) = first_range {
-          let mut real_length = range.length;
-
-          // prevent max_length;
-          // specially on webview2
-          if range.length > file_size / 3 {
-            // max size sent (400ko / request)
-            // as it's local file system we can afford to read more often
-            real_length = min(file_size - range.start, 1024 * 400);
+      let response = if let Some(range_header) = request.headers().get("range") {
+        let not_satisfiable = || {
+          ResponseBuilder::new()
+            .status(StatusCode::RANGE_NOT_SATISFIABLE)
+            .header(CONTENT_RANGE, format!("bytes */{len}"))
+            .body(vec![])
+        };
+
+        // parse range header
+        let ranges = if let Ok(ranges) = HttpRange::parse(range_header.to_str()?, len) {
+          ranges
+            .iter()
+            // map the output back to spec range <start-end>, example: 0-499
+            .map(|r| (r.start, r.start + r.length - 1))
+            .collect::<Vec<_>>()
+        } else {
+          return not_satisfiable();
+        };
+
+        /// The Maximum bytes we send in one range
+        const MAX_LEN: u64 = 1000 * 1024;
+
+        if ranges.len() == 1 {
+          let &(start, mut end) = ranges.first().unwrap();
+
+          // check if a range is not satisfiable
+          //
+          // this should be already taken care of by HttpRange::parse
+          // but checking here again for extra assurance
+          if start >= len || end >= len || end < start {
+            return not_satisfiable();
           }
 
-          // last byte we are reading, the length of the range include the last byte
-          // who should be skipped on the header
-          let last_byte = range.start + real_length - 1;
-          // partial content
-          status_code = 206;
-
-          // Only macOS and Windows are supported, if you set headers in linux they are ignored
-          response = response
-            .header("Connection", "Keep-Alive")
-            .header("Accept-Ranges", "bytes")
-            .header("Content-Length", real_length)
-            .header(
-              "Content-Range",
-              format!("bytes {}-{}/{}", range.start, last_byte, file_size),
-            );
-
-          // FIXME: Add ETag support (caching on the webview)
-
-          // seek our file bytes
-          content.seek(SeekFrom::Start(range.start))?;
-          content.take(real_length).read_to_end(&mut buf)?;
+          // adjust end byte for MAX_LEN
+          end = start + (end - start).min(len - start).min(MAX_LEN - 1);
+
+          // calculate number of bytes needed to be read
+          let bytes_to_read = end + 1 - start;
+
+          // allocate a buf with a suitable capacity
+          let mut buf = Vec::with_capacity(bytes_to_read as usize);
+          // seek the file to the starting byte
+          file.seek(SeekFrom::Start(start))?;
+          // read the needed bytes
+          file.take(bytes_to_read).read_to_end(&mut buf)?;
+
+          resp = resp.header(CONTENT_RANGE, format!("bytes {start}-{end}/{len}"));
+          resp = resp.header(CONTENT_LENGTH, end + 1 - start);
+          resp = resp.status(StatusCode::PARTIAL_CONTENT);
+          resp.body(buf)
         } else {
-          content.read_to_end(&mut buf)?;
-        }
-      }
+          let mut buf = Vec::new();
+          let ranges = ranges
+            .iter()
+            .filter_map(|&(start, mut end)| {
+              // filter out unsatisfiable ranges
+              //
+              // this should be already taken care of by HttpRange::parse
+              // but checking here again for extra assurance
+              if start >= len || end >= len || end < start {
+                None
+              } else {
+                // adjust end byte for MAX_LEN
+                end = start + (end - start).min(len - start).min(MAX_LEN - 1);
+                Some((start, end))
+              }
+            })
+            .collect::<Vec<_>>();
+
+          let mut id = boundary_id.lock().unwrap();
+          *id += 1;
+          let boundary = format!("sadasq2e{id}");
+          let boundary_sep = format!("\r\n--{boundary}\r\n");
+          let boundary_closer = format!("\r\n--{boundary}\r\n");
+
+          resp = resp.header(
+            CONTENT_TYPE,
+            format!("multipart/byteranges; boundary={boundary}"),
+          );
+
+          for (end, start) in ranges {
+            // a new range is being written, write the range boundary
+            buf.write_all(boundary_sep.as_bytes())?;
+
+            // write the needed headers `Content-Type` and `Content-Range`
+            buf.write_all(format!("{CONTENT_TYPE}: video/mp4\r\n").as_bytes())?;
+            buf.write_all(format!("{CONTENT_RANGE}: bytes {start}-{end}/{len}\r\n").as_bytes())?;
+
+            // write the separator to indicate the start of the range body
+            buf.write_all("\r\n".as_bytes())?;
+
+            // calculate number of bytes needed to be read
+            let bytes_to_read = end + 1 - start;
+
+            let mut local_buf = vec![0_u8; bytes_to_read as usize];
+            file.seek(SeekFrom::Start(start))?;
+            file.read_exact(&mut local_buf)?;
+            buf.extend_from_slice(&local_buf);
+          }
+          // all ranges have been written, write the closing boundary
+          buf.write_all(boundary_closer.as_bytes())?;
 
-      response.mimetype("video/mp4").status(status_code).body(buf)
+          resp.body(buf)
+        }
+      } else {
+        resp = resp.header(CONTENT_LENGTH, len);
+        let mut buf = Vec::with_capacity(len as usize);
+        file.read_to_end(&mut buf)?;
+        resp.body(buf)
+      };
+
+      response
     })
     .run(tauri::generate_context!(
       "../../examples/streaming/tauri.conf.json"
@@ -125,5 +198,5 @@ fn video_uri() -> (&'static str, std::path::PathBuf) {
   }
 
   #[cfg(not(feature = "protocol-asset"))]
-  ("stream", "example/test_video.mp4".into())
+  ("stream", "test_video.mp4".into())
 }