use axum::{ body::Body, extract::{DefaultBodyLimit, Multipart, Path, Request, State}, http::{header, StatusCode}, middleware::{self, Next}, response::{Html, IntoResponse, Redirect, Response}, routing::get, Router, }; use base64::{Engine, engine::general_purpose::STANDARD as B64}; use clap::Parser; use hyper_util::{rt::TokioIo, server::conn::auto::Builder, service::TowerToHyperService}; use rustls::ServerConfig; use std::{io::Write, path::{Path as FsPath, PathBuf}, sync::Arc, time::Instant}; use tokio::{fs, io::AsyncWriteExt, net::TcpListener, signal}; use tokio_rustls::TlsAcceptor; type St = Arc; struct AppState { root: PathBuf, pass: Option } #[derive(Parser)] struct Args { #[arg(short, long, default_value_t = 8000)] port: u16, #[arg(default_value = ".")] directory: PathBuf, #[arg(long)] pass: Option, } #[tokio::main] async fn main() { let args = Args::parse(); let root = std::fs::canonicalize(&args.directory) .unwrap_or_else(|_| panic!("Not found: {}", args.directory.display())); let st: St = Arc::new(AppState { root: root.clone(), pass: args.pass }); let app = Router::new() .route("/", get(get_h).post(upload_h)) .route("/{*path}", get(get_h).post(upload_h)) .layer(DefaultBodyLimit::max(100 * 1024 * 1024 * 1024)) .layer(middleware::from_fn_with_state(st.clone(), auth)) .with_state(st); let acceptor = self_signed_tls(); let addr = format!("[::]:{}", args.port); let listener = TcpListener::bind(&addr).await.unwrap(); println!("Serving {} on https://{addr}", root.display()); let shutdown = async { signal::ctrl_c().await.ok(); println!("\nshutting down..."); }; tokio::pin!(shutdown); loop { tokio::select! { _ = &mut shutdown => break, res = listener.accept() => { let (stream, _peer) = match res { Ok(c) => c, Err(e) => { eprintln!("accept: {e}"); continue } }; let (acc, app) = (acceptor.clone(), app.clone()); tokio::spawn(async move { let Ok(tls) = acc.accept(stream).await.inspect_err(|e| eprintln!("tls: {e}")) else { return }; let svc = TowerToHyperService::new(app); let _ = Builder::new(hyper_util::rt::TokioExecutor::new()) .serve_connection(TokioIo::new(tls), svc).await; }); } } } } fn self_signed_tls() -> TlsAcceptor { use std::os::unix::fs::MetadataExt; let dir = std::path::Path::new("/tmp/fileserver"); let uid = unsafe { libc::getuid() }; if dir.exists() { let meta = std::fs::metadata(dir).expect("can't stat /tmp/fileserver"); if meta.uid() != uid { eprintln!("/tmp/fileserver owned by uid {}, expected {uid} — refusing", meta.uid()); std::process::exit(1); } } else { std::fs::create_dir(dir).expect("can't create /tmp/fileserver"); } let lock = std::fs::File::create("/tmp/fileserver/.lock").unwrap(); use std::os::unix::io::AsRawFd; if unsafe { libc::flock(lock.as_raw_fd(), libc::LOCK_EX | libc::LOCK_NB) } != 0 { eprintln!("another fileserver is already running"); std::process::exit(1); } std::mem::forget(lock); for entry in std::fs::read_dir(dir).unwrap() { let p = entry.unwrap().path(); if p.file_name().unwrap() != ".lock" { let _ = std::fs::remove_file(&p); } } let c = rcgen::generate_simple_self_signed(vec!["localhost".into(), "0.0.0.0".into(), "127.0.0.1".into()]) .expect("cert gen failed"); let (cp, kp) = (c.cert.pem(), c.signing_key.serialize_pem()); std::fs::write("/tmp/fileserver/cert.pem", &cp).unwrap(); std::fs::write("/tmp/fileserver/key.pem", &kp).unwrap(); println!("cert: /tmp/fileserver/cert.pem\nkey: /tmp/fileserver/key.pem"); let certs: Vec<_> = rustls_pemfile::certs(&mut cp.as_bytes()).collect::>().unwrap(); let key = rustls_pemfile::private_key(&mut kp.as_bytes()).unwrap().unwrap(); TlsAcceptor::from(Arc::new(ServerConfig::builder().with_no_client_auth().with_single_cert(certs, key).unwrap())) } async fn auth(State(st): State, req: Request, next: Next) -> Response { let method = req.method().clone(); let uri = req.uri().path().to_string(); let resp = auth_inner(&st, req, next).await; println!("{method} {uri} {}", resp.status().as_u16()); resp } async fn auth_inner(st: &AppState, req: Request, next: Next) -> Response { let Some(pw) = &st.pass else { return next.run(req).await }; let ok = req.headers().get(header::AUTHORIZATION) .and_then(|v| v.to_str().ok()) .and_then(|v| v.strip_prefix("Basic ")) .and_then(|b| B64.decode(b).ok()) .and_then(|b| String::from_utf8(b).ok()) .is_some_and(|c| c.ends_with(&format!(":{pw}"))); if ok { return next.run(req).await } Response::builder().status(StatusCode::UNAUTHORIZED) .header(header::WWW_AUTHENTICATE, r#"Basic realm="fileserver""#) .body(Body::from("Unauthorized")).unwrap() } fn upath(p: &Option>) -> &str { p.as_ref().map(|p| p.as_str()).unwrap_or("") } fn resolve(root: &FsPath, p: &str) -> Option { let d = percent_encoding::percent_decode_str(p).decode_utf8_lossy(); let c = d.trim_start_matches('/'); let f = if c.is_empty() { root.to_path_buf() } else { root.join(c) }; f.canonicalize().ok().filter(|p| p.starts_with(root)) } async fn get_h(State(st): State, path: Option>) -> Response { let up = upath(&path); let Some(p) = resolve(&st.root, up) else { return StatusCode::FORBIDDEN.into_response() }; if p.is_dir() { dir_list(&p, up).await } else if p.is_file() { serve_file(&p).await } else { StatusCode::NOT_FOUND.into_response() } } async fn upload_h(State(st): State, path: Option>, mut mp: Multipart) -> Response { let up = upath(&path); let Some(dir) = resolve(&st.root, up) else { return StatusCode::FORBIDDEN.into_response() }; let dir = if dir.is_dir() { dir } else { dir.parent().unwrap_or(&st.root).to_path_buf() }; while let Ok(Some(mut field)) = mp.next_field().await { if field.name() != Some("file") { continue } let name = match field.file_name() { Some(n) => sanitize(n), None => continue }; if name.is_empty() { continue } let dest = dir.join(&name); if !dest.starts_with(&st.root) { return StatusCode::FORBIDDEN.into_response() } println!("receiving: {}", dest.display()); let t = Instant::now(); let Ok(mut f) = fs::File::create(&dest).await else { return StatusCode::INTERNAL_SERVER_ERROR.into_response() }; let (mut sz, mut last_print): (u64, u64) = (0, 0); while let Ok(Some(ch)) = field.chunk().await { if f.write_all(&ch).await.is_err() { let _ = fs::remove_file(&dest).await; return StatusCode::INTERNAL_SERVER_ERROR.into_response(); } sz += ch.len() as u64; if sz - last_print >= 10 * 1024 * 1024 { let el = t.elapsed().as_secs_f64(); let spd = if el > 0.0 { (sz as f64 / el) as u64 } else { 0 }; print!("\r {} received, {}/s", hsz(sz), hsz(spd)); std::io::stdout().flush().ok(); last_print = sz; } } if last_print > 0 { println!() } if sz == 0 { let _ = fs::remove_file(&dest).await; continue } let el = t.elapsed().as_secs_f64(); println!("done: {} ({} @ {}/s)", dest.display(), hsz(sz), hsz(if el > 0.0 { (sz as f64 / el) as u64 } else { 0 })); } Redirect::to(&if up.is_empty() { "/".into() } else { format!("/{up}") }).into_response() } async fn serve_file(p: &FsPath) -> Response { let Ok(b) = fs::read(p).await else { return StatusCode::INTERNAL_SERVER_ERROR.into_response() }; let mime = mime_guess::from_path(p).first_or_octet_stream().to_string(); let name = p.file_name().unwrap().to_string_lossy(); let inline = mime.starts_with("text/") || mime.starts_with("image/") || mime == "application/pdf"; let d = format!("{}; filename=\"{name}\"", if inline { "inline" } else { "attachment" }); Response::builder().header(header::CONTENT_TYPE, mime).header(header::CONTENT_DISPOSITION, d) .body(Body::from(b)).unwrap() } async fn dir_list(dir: &FsPath, url_path: &str) -> Response { let Ok(mut rd) = fs::read_dir(dir).await else { return StatusCode::FORBIDDEN.into_response() }; let (mut ds, mut fs_): (Vec<(String, String)>, Vec<(String, String, u64)>) = (vec![], vec![]); while let Ok(Some(e)) = rd.next_entry().await { let n = e.file_name().to_string_lossy().into_owned(); let Ok(m) = e.metadata().await else { continue }; let enc = percent_encoding::utf8_percent_encode(&n, percent_encoding::NON_ALPHANUMERIC).to_string(); if m.is_dir() { ds.push((n, enc)) } else { fs_.push((n, enc, m.len())) } } ds.sort_by(|a, b| a.0.to_lowercase().cmp(&b.0.to_lowercase())); fs_.sort_by(|a, b| a.0.to_lowercase().cmp(&b.0.to_lowercase())); let dp = if url_path.is_empty() { "/" } else { url_path }; let de = esc(dp); let mut r = String::new(); if !url_path.is_empty() { r.push_str("../\n") } for (n, e) in &ds { r.push_str(&format!("{}/\n", esc(n))) } for (n, e, s) in &fs_ { let (nm, sz) = (esc(n), hsz(*s)); let pad = 60usize.saturating_sub(nm.len()).max(2); r.push_str(&format!("{nm}{:>w$}\n", sz, w = pad + sz.len())); } Html(format!(r#"{de}

{de}

{r}
drop files |
"#)).into_response() } fn sanitize(n: &str) -> String { FsPath::new(n).file_name().map(|n| n.to_string_lossy().replace(['/', '\\', '\0'], "_")).unwrap_or_default() } fn esc(s: &str) -> String { s.replace('&', "&").replace('<', "<").replace('>', ">").replace('"', """) } fn hsz(b: u64) -> String { let mut s = b as f64; for u in ["B","KB","MB","GB","TB"] { if s < 1024.0 { return if u=="B" { format!("{s:.0} {u}") } else { format!("{s:.1} {u}") } } s /= 1024.0 } format!("{s:.1} PB") }