abstio/
download.rs

1use std::io::Write;
2
3use anyhow::{Context, Result};
4use futures_channel::mpsc;
5
6use abstutil::prettyprint_usize;
7
8/// Downloads bytes from a URL. This must be called with a tokio runtime somewhere. The caller
9/// creates an mpsc channel pair and provides the sender. Progress will be described through it.
10pub async fn download_bytes<I: AsRef<str>>(
11    url: I,
12    post_body: Option<String>,
13    progress: &mut mpsc::Sender<String>,
14) -> Result<Vec<u8>> {
15    let url = url.as_ref();
16    info!("Downloading {}", url);
17    let mut resp = if let Some(body) = post_body {
18        reqwest::Client::new()
19            .post(url)
20            .body(body)
21            .send()
22            .await
23            .unwrap()
24    } else {
25        reqwest::get(url).await.unwrap()
26    };
27    resp.error_for_status_ref()
28        .with_context(|| url.to_string())?;
29
30    let total_size = resp.content_length().map(|x| x as usize);
31    let mut bytes = Vec::new();
32    while let Some(chunk) = resp.chunk().await.unwrap() {
33        // TODO Throttle?
34        let msg = if let Some(n) = total_size {
35            format!(
36                "{:.2}% ({} / {} bytes)",
37                (bytes.len() as f64) / (n as f64) * 100.0,
38                prettyprint_usize(bytes.len()),
39                prettyprint_usize(n)
40            )
41        } else {
42            // One example where the HTTP response won't say the response size is the Overpass API
43            format!(
44                "{} bytes (unknown total size)",
45                prettyprint_usize(bytes.len())
46            )
47        };
48        if let Err(err) = progress.try_send(msg) {
49            warn!("Couldn't send download progress message: {}", err);
50        }
51
52        bytes.write_all(&chunk).unwrap();
53    }
54    println!();
55    Ok(bytes)
56}
57
58/// Download a file from a URL. This must be called with a tokio runtime somewhere. Progress will
59/// be printed to STDOUT.
60pub async fn download_to_file<I1: AsRef<str>, I2: AsRef<str>>(
61    url: I1,
62    post_body: Option<String>,
63    path: I2,
64) -> Result<()> {
65    let (mut tx, rx) = futures_channel::mpsc::channel(1000);
66    print_download_progress(rx);
67    let bytes = download_bytes(url, post_body, &mut tx).await?;
68    let path = path.as_ref();
69    || -> Result<()> {
70        fs_err::create_dir_all(std::path::Path::new(path).parent().unwrap())?;
71        let mut file = fs_err::File::create(path)?;
72        file.write_all(&bytes)?;
73        Ok(())
74    }()
75    .with_context(|| path.to_string())
76}
77
78/// Print download progress to STDOUT. Pass this the receiver, then call download_to_file or
79/// download_bytes with the sender.
80pub fn print_download_progress(mut progress: mpsc::Receiver<String>) {
81    tokio::task::spawn_blocking(move || loop {
82        match progress.try_next() {
83            Ok(Some(msg)) => {
84                abstutil::clear_current_line();
85                print!("{}", msg);
86                std::io::stdout().flush().unwrap();
87            }
88            Ok(None) => break,
89            // Per
90            // https://docs.rs/futures-channel/0.3.14/futures_channel/mpsc/struct.Receiver.html#method.try_next,
91            // this means no messages are available yet
92            Err(_) => {}
93        }
94    });
95}