diff --git a/HotKeet/Cargo.lock b/HotKeet/Cargo.lock index 561dfcd..04e59d3 100644 --- a/HotKeet/Cargo.lock +++ b/HotKeet/Cargo.lock @@ -147,7 +147,7 @@ dependencies = [ "enumflags2", "futures-channel", "futures-util", - "rand", + "rand 0.8.5", "serde", "serde_repr", "url", @@ -319,6 +319,12 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + [[package]] name = "bit-set" version = "0.6.0" @@ -1262,6 +1268,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d" dependencies = [ "futures-core", + "futures-sink", ] [[package]] @@ -1355,8 +1362,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi", + "wasm-bindgen", ] [[package]] @@ -1366,9 +1375,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" dependencies = [ "cfg-if", + "js-sys", "libc", "r-efi", "wasip2", + "wasm-bindgen", ] [[package]] @@ -1602,6 +1613,7 @@ dependencies = [ "hound", "raw-window-handle", "rdev", + "reqwest", "rfd", "serde", "serde_json", @@ -1618,6 +1630,106 @@ version = "3.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "62adaabb884c94955b19907d60019f4e145d091c75345379e70d1ee696f7854f" +[[package]] +name = "http" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" +dependencies = [ + "bytes", + "itoa", +] + +[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http", +] + +[[package]] +name = "http-body-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" +dependencies = [ + "bytes", + "futures-core", + "http", + "http-body", + "pin-project-lite", +] + +[[package]] +name = "httparse" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" + +[[package]] +name = "hyper" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ab2d4f250c3d7b1c9fcdff1cece94ea4e2dfbec68614f7b87cb205f24ca9d11" +dependencies = [ + "atomic-waker", + "bytes", + "futures-channel", + "futures-core", + "http", + "http-body", + "httparse", + "itoa", + "pin-project-lite", + "pin-utils", + "smallvec", + "tokio", + "want", +] + +[[package]] +name = "hyper-rustls" +version = "0.27.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" +dependencies = [ + "http", + "hyper", + "hyper-util", + "rustls", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tower-service", + "webpki-roots", +] + +[[package]] +name = "hyper-util" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96547c2556ec9d12fb1578c4eaf448b04993e7fb79cbaad930a656880a6bdfa0" +dependencies = [ + "base64", + "bytes", + "futures-channel", + "futures-util", + "http", + "http-body", + "hyper", + "ipnet", + "libc", + "percent-encoding", + "pin-project-lite", + "socket2", + "tokio", + "tower-service", + "tracing", +] + [[package]] name = "iana-time-zone" version = "0.1.65" @@ -1798,6 +1910,22 @@ dependencies = [ "libc", ] +[[package]] +name = "ipnet" +version = "2.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2" + +[[package]] +name = "iri-string" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c91338f0783edbd6195decb37bae672fd3b165faffb89bf7b9e6942f8b1a731a" +dependencies = [ + "memchr", + "serde", +] + [[package]] name = "itoa" version = "1.0.17" @@ -1936,6 +2064,12 @@ version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" +[[package]] +name = "lru-slab" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" + [[package]] name = "mach2" version = "0.4.3" @@ -2611,6 +2745,12 @@ version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + [[package]] name = "piper" version = "0.2.5" @@ -2736,6 +2876,61 @@ dependencies = [ "memchr", ] +[[package]] +name = "quinn" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20" +dependencies = [ + "bytes", + "cfg_aliases 0.2.1", + "pin-project-lite", + "quinn-proto", + "quinn-udp", + "rustc-hash 2.1.1", + "rustls", + "socket2", + "thiserror 2.0.18", + "tokio", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-proto" +version = "0.11.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31" +dependencies = [ + "bytes", + "getrandom 0.3.4", + "lru-slab", + "rand 0.9.2", + "ring", + "rustc-hash 2.1.1", + "rustls", + "rustls-pki-types", + "slab", + "thiserror 2.0.18", + "tinyvec", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-udp" +version = "0.5.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd" +dependencies = [ + "cfg_aliases 0.2.1", + "libc", + "once_cell", + "socket2", + "tracing", + "windows-sys 0.60.2", +] + [[package]] name = "quote" version = "1.0.45" @@ -2758,8 +2953,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ "libc", - "rand_chacha", - "rand_core", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + +[[package]] +name = "rand" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +dependencies = [ + "rand_chacha 0.9.0", + "rand_core 0.9.5", ] [[package]] @@ -2769,7 +2974,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core 0.9.5", ] [[package]] @@ -2781,6 +2996,15 @@ dependencies = [ "getrandom 0.2.17", ] +[[package]] +name = "rand_core" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" +dependencies = [ + "getrandom 0.3.4", +] + [[package]] name = "raw-window-handle" version = "0.6.2" @@ -2844,6 +3068,46 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19b30a45b0cd0bcca8037f3d0dc3421eaf95327a17cad11964fb8179b4fc4832" +[[package]] +name = "reqwest" +version = "0.12.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eddd3ca559203180a307f12d114c268abf583f59b03cb906fd0b3ff8646c1147" +dependencies = [ + "base64", + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-rustls", + "hyper-util", + "js-sys", + "log", + "percent-encoding", + "pin-project-lite", + "quinn", + "rustls", + "rustls-pki-types", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tokio-rustls", + "tower", + "tower-http", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "webpki-roots", +] + [[package]] name = "rfd" version = "0.14.1" @@ -2867,6 +3131,20 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "ring" +version = "0.17.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" +dependencies = [ + "cc", + "cfg-if", + "getrandom 0.2.17", + "libc", + "untrusted", + "windows-sys 0.52.0", +] + [[package]] name = "rustc-hash" version = "1.1.0" @@ -2905,12 +3183,53 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "rustls" +version = "0.23.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "758025cb5fccfd3bc2fd74708fd4682be41d99e5dff73c377c0646c6012c73a4" +dependencies = [ + "once_cell", + "ring", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-pki-types" +version = "1.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd" +dependencies = [ + "web-time", + "zeroize", +] + +[[package]] +name = "rustls-webpki" +version = "0.103.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7df23109aa6c1567d1c575b9952556388da57401e4ace1d15f79eedad0d8f53" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + [[package]] name = "rustversion" version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" +[[package]] +name = "ryu" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" + [[package]] name = "same-file" version = "1.0.6" @@ -2995,6 +3314,18 @@ dependencies = [ "serde_core", ] +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", +] + [[package]] name = "sha1" version = "0.10.6" @@ -3171,6 +3502,12 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + [[package]] name = "syn" version = "1.0.109" @@ -3193,6 +3530,15 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "sync_wrapper" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bf256ce5efdfa370213c1dabab5935a12e49f2c58d15e9eac2870d3b4f27263" +dependencies = [ + "futures-core", +] + [[package]] name = "synstructure" version = "0.13.2" @@ -3290,6 +3636,21 @@ dependencies = [ "zerovec", ] +[[package]] +name = "tinyvec" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + [[package]] name = "tokio" version = "1.50.0" @@ -3316,6 +3677,16 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "tokio-rustls" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" +dependencies = [ + "rustls", + "tokio", +] + [[package]] name = "toml" version = "0.9.12+spec-1.1.0" @@ -3376,6 +3747,51 @@ version = "1.0.6+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab16f14aed21ee8bfd8ec22513f7287cd4a91aa92e44edfe2c17ddd004e92607" +[[package]] +name = "tower" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper", + "tokio", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower-http" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" +dependencies = [ + "bitflags 2.11.0", + "bytes", + "futures-util", + "http", + "http-body", + "iri-string", + "pin-project-lite", + "tower", + "tower-layer", + "tower-service", +] + +[[package]] +name = "tower-layer" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + +[[package]] +name = "tower-service" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" + [[package]] name = "tracing" version = "0.1.44" @@ -3424,6 +3840,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "try-lock" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" + [[package]] name = "ttf-parser" version = "0.25.1" @@ -3480,6 +3902,12 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + [[package]] name = "url" version = "2.5.8" @@ -3521,6 +3949,15 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "want" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa7760aed19e106de2c7c0b581b509f2f25d3dacaf737cb82ac61bc6d760b0e" +dependencies = [ + "try-lock", +] + [[package]] name = "wasi" version = "0.11.1+wasi-snapshot-preview1" @@ -3766,6 +4203,15 @@ dependencies = [ "web-sys", ] +[[package]] +name = "webpki-roots" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22cfaf3c063993ff62e73cb4311efde4db1efb31ab78a3e5c457939ad5cc0bed" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "weezl" version = "0.1.12" @@ -4530,7 +4976,7 @@ dependencies = [ "hex", "nix", "ordered-stream", - "rand", + "rand 0.8.5", "serde", "serde_repr", "sha1", @@ -4609,6 +5055,12 @@ dependencies = [ "synstructure", ] +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" + [[package]] name = "zerotrie" version = "0.2.3" diff --git a/HotKeet/Cargo.toml b/HotKeet/Cargo.toml index 0101acf..9f5497d 100644 --- a/HotKeet/Cargo.toml +++ b/HotKeet/Cargo.toml @@ -27,6 +27,7 @@ arboard = "3.2" chrono = "0.4" raw-window-handle = "0.6" rfd = "0.14" +reqwest = { version = "0.12", default-features = false, features = ["blocking", "rustls-tls"] } [build-dependencies] winresource = "0.1" diff --git a/HotKeet/src/config.rs b/HotKeet/src/config.rs index dda44c6..90f36d4 100644 --- a/HotKeet/src/config.rs +++ b/HotKeet/src/config.rs @@ -174,6 +174,9 @@ impl DictateConfig { if !p.is_dir() { return false; } + if !crate::model::is_model_valid(p) { + return false; + } } true } diff --git a/HotKeet/src/main.rs b/HotKeet/src/main.rs index ac2b726..fd57142 100644 --- a/HotKeet/src/main.rs +++ b/HotKeet/src/main.rs @@ -4,6 +4,7 @@ #![cfg_attr(windows, windows_subsystem = "windows")] mod companion; +mod model; mod sound; mod win_close; @@ -112,6 +113,11 @@ struct AppState { frame_count: u32, /// Schließkreuz → minimieren (in raw_input_hook gesetzt) pending_minimize_from_close: bool, + /// Download request: path to download model into + download_request_rx: Receiver, + download_progress_tx: Sender, + download_progress_rx: Receiver, + download_progress: Arc>>, } impl eframe::App for AppState { @@ -143,6 +149,15 @@ impl eframe::App for AppState { ctx.send_viewport_cmd(egui::ViewportCommand::Visible(false)); } + // Download request: spawn model download + if let Ok(path) = self.download_request_rx.try_recv() { + model::download_model_async(path, self.download_progress_tx.clone()); + } + // Download progress updates + while let Ok(progress) = self.download_progress_rx.try_recv() { + let _ = self.download_progress.write().map(|mut w| *w = Some(progress)); + } + // Fallback falls raw_input_hook nicht greift (z.B. andere Viewports) if ctx.input(|i| i.viewport().close_requested()) { let minimize = self.config.read().map(|c| c.minimize_to_tray).unwrap_or(true); @@ -301,6 +316,14 @@ fn main() -> eframe::Result<()> { let (test_tx, test_rx): (Sender<()>, Receiver<()>) = std::sync::mpsc::channel(); let (tray_tx, tray_rx): (Sender, Receiver) = std::sync::mpsc::channel(); + let (download_request_tx, download_request_rx): ( + Sender, + Receiver, + ) = std::sync::mpsc::channel(); + let (download_progress_tx, download_progress_rx): ( + Sender, + Receiver, + ) = std::sync::mpsc::channel(); let tray_rx = { #[cfg(windows)] @@ -352,6 +375,14 @@ fn main() -> eframe::Result<()> { let _tray = tray::create_tray(tray_tx); + // Modellpfad gesetzt, Ordner existiert, aber Modell fehlt → Auto-Download starten + if !config_ui.model_path.is_empty() { + let model_path = std::path::Path::new(&config_ui.model_path); + if model_path.is_dir() && !model::is_model_valid(model_path) { + let _ = download_request_tx.send(std::path::PathBuf::from(&config_ui.model_path)); + } + } + // Bei ungültigen Parakeet-Pfaden oder Erststart (beide leer): UI öffnen let start_minimized = if config_ui.needs_initial_config() || !config_ui.has_valid_parakeet_config() { @@ -360,10 +391,13 @@ fn main() -> eframe::Result<()> { config_ui.start_minimized }; let minimize_to_tray = config_ui.minimize_to_tray; + let download_progress = Arc::new(std::sync::RwLock::new(None)); let state = AppState { settings: ui::SettingsApp::new(config_ui) .with_config_sync(config_arc.clone()) - .with_test_sender(test_tx), + .with_test_sender(test_tx) + .with_download_request(download_request_tx) + .with_download_progress(download_progress.clone()), recording_stop: None, hotkey_rx, paste_tx, @@ -378,6 +412,10 @@ fn main() -> eframe::Result<()> { start_minimized_pending: start_minimized, frame_count: 0, pending_minimize_from_close: false, + download_request_rx, + download_progress_tx, + download_progress_rx, + download_progress, }; let mut viewport = egui::ViewportBuilder::default() diff --git a/HotKeet/src/model.rs b/HotKeet/src/model.rs new file mode 100644 index 0000000..16fb80e --- /dev/null +++ b/HotKeet/src/model.rs @@ -0,0 +1,106 @@ +//! Parakeet model verification and download from Hugging Face. + +use std::io::Read; +use std::path::Path; +use std::sync::mpsc::Sender; + +const HF_BASE: &str = "https://huggingface.co/nasedkinpv/parakeet-tdt-0.6b-v3-onnx-int8/resolve/main"; + +/// Required files for Parakeet INT8 model (transcribe-rs). +const REQUIRED_FILES: &[&str] = &[ + "vocab.txt", + "encoder-int8.onnx", + "encoder-int8.onnx.data", + "decoder_joint-int8.onnx", +]; + +/// Progress message during download. +#[derive(Clone)] +pub enum DownloadProgress { + Starting(String), + Downloading { file: String, done: u64, total: Option }, + Finished, + Error(String), +} + +/// Returns true if the directory contains a valid Parakeet INT8 model. +pub fn is_model_valid(path: &Path) -> bool { + if !path.is_dir() { + return false; + } + REQUIRED_FILES + .iter() + .all(|f| path.join(f).is_file()) +} + +/// Downloads the Parakeet model into the given directory. +/// Sends progress updates via `tx`. Runs in a spawned thread. +pub fn download_model_async(path: std::path::PathBuf, tx: Sender) { + std::thread::spawn(move || { + if let Err(e) = download_model(&path, |p| { + let _ = tx.send(p); + }) { + let _ = tx.send(DownloadProgress::Error(e)); + } else { + let _ = tx.send(DownloadProgress::Finished); + } + }); +} + +fn download_model(path: &Path, mut progress: F) -> Result<(), String> +where + F: FnMut(DownloadProgress), +{ + std::fs::create_dir_all(path).map_err(|e| format!("Create dir: {}", e))?; + + let client = reqwest::blocking::Client::builder() + .timeout(std::time::Duration::from_secs(3600)) + .build() + .map_err(|e| e.to_string())?; + + for file_name in REQUIRED_FILES { + let file_path = path.join(file_name); + if file_path.is_file() { + let len = file_path.metadata().map(|m| m.len()).unwrap_or(0); + progress(DownloadProgress::Downloading { + file: file_name.to_string(), + done: len, + total: Some(len), + }); + continue; + } + + progress(DownloadProgress::Starting(file_name.to_string())); + + let url = format!("{}/{}", HF_BASE, file_name); + let mut resp = client + .get(&url) + .send() + .map_err(|e| format!("Request {}: {}", file_name, e))?; + + if !resp.status().is_success() { + return Err(format!("Download {}: HTTP {}", file_name, resp.status())); + } + + let total = resp.content_length(); + let mut out_file = std::fs::File::create(&file_path).map_err(|e| format!("Create {}: {}", file_name, e))?; + + let mut buf = [0u8; 64 * 1024]; + let mut downloaded: u64 = 0; + loop { + let n = resp.read(&mut buf).map_err(|e| format!("Read {}: {}", file_name, e))?; + if n == 0 { + break; + } + std::io::Write::write_all(&mut out_file, &buf[..n]).map_err(|e| format!("Write {}: {}", file_name, e))?; + downloaded += n as u64; + progress(DownloadProgress::Downloading { + file: file_name.to_string(), + done: downloaded, + total, + }); + } + } + + Ok(()) +} diff --git a/HotKeet/src/ui.rs b/HotKeet/src/ui.rs index 53fba12..e59f37a 100644 --- a/HotKeet/src/ui.rs +++ b/HotKeet/src/ui.rs @@ -2,8 +2,10 @@ use crate::config::DictateConfig; use crate::hotkey; +use crate::model; use crate::recording::list_input_sources; use eframe::egui; +use std::path::Path; use std::sync::mpsc::Sender; use std::sync::{Arc, RwLock}; @@ -14,6 +16,8 @@ pub struct SettingsApp { pub test_tx: Option>, /// Hotkey-Feld: true = warte auf Tastendruck pub hotkey_capturing: bool, + pub download_request_tx: Option>, + pub download_progress: Option>>>, } impl SettingsApp { @@ -24,6 +28,8 @@ impl SettingsApp { config_arc: None, test_tx: None, hotkey_capturing: false, + download_request_tx: None, + download_progress: None, } } @@ -37,6 +43,16 @@ impl SettingsApp { self } + pub fn with_download_request(mut self, tx: Sender) -> Self { + self.download_request_tx = Some(tx); + self + } + + pub fn with_download_progress(mut self, progress: Arc>>) -> Self { + self.download_progress = Some(progress); + self + } + /// Ermittelt den anzuzeigenden Quell-Index (0=Companion, 1+=Mikrofon). fn selected_source_index(&self, sources: &[(usize, String)]) -> usize { if self.config.use_companion_microphone { @@ -224,6 +240,43 @@ impl SettingsApp { } }); + let model_path = Path::new(&self.config.model_path); + let model_valid = !self.config.model_path.is_empty() && model::is_model_valid(model_path); + let downloading = self.download_progress.as_ref().and_then(|p| p.read().ok()).and_then(|g| g.clone()); + if !self.config.model_path.is_empty() && !model_valid { + ui.horizontal(|ui| { + ui.colored_label(egui::Color32::YELLOW, "Model not found in folder."); + if let (Some(ref tx), None) = (&self.download_request_tx, &downloading) { + if ui.button("Download model").clicked() { + let path = std::path::PathBuf::from(&self.config.model_path); + let _ = tx.send(path); + self.status = "Downloading…".to_string(); + } + } + }); + if let Some(ref prog) = downloading { + match prog { + model::DownloadProgress::Starting(f) => { + ui.label(format!("Starting: {}", f)); + } + model::DownloadProgress::Downloading { file, done, total } => { + let pct = total.map(|t| if t > 0 { (done * 100 / t) as u32 } else { 0 }); + ui.label(format!("{}: {} MB{}", file, done / 1_000_000, pct.map(|p| format!(" ({}%)", p)).unwrap_or_default())); + } + model::DownloadProgress::Finished => { + ui.colored_label(egui::Color32::GREEN, "Download complete."); + self.status = "Model downloaded. Remember to save.".to_string(); + if let Some(ref p) = self.download_progress { + let _ = p.write().map(|mut w| *w = None); + } + } + model::DownloadProgress::Error(e) => { + ui.colored_label(egui::Color32::RED, format!("Error: {}", e)); + } + } + } + } + ui.add_space(4.0); ui.horizontal(|ui| { diff --git a/README.md b/README.md index 1a0aef3..73bd687 100644 --- a/README.md +++ b/README.md @@ -47,7 +47,7 @@ Default hotkey: **Ctrl+Shift+D** (hold = record, release = transcribe + paste) | **Hotkey** | Global Push-to-Talk (e.g. Ctrl+Shift+D) | | **Input source** | Companion app or microphone | | **parakeet-cli path** | Empty = in PATH (default: parakeet-cli) | -| **Model path** | Empty = default path (platform-dependent) | +| **Model path** | Empty = default path. If folder is empty or invalid, use "Download model" to fetch from Hugging Face | | **Paste method** | Auto | Keyboard buffer | Clipboard | **Storage location:** `%LOCALAPPDATA%\HotKeet\settings.json` (Windows) or `~/.config/HotKeet/settings.json` (Linux/macOS)