From 51cdf13953101a97d4296e9a79120df3aef8150d Mon Sep 17 00:00:00 2001 From: chmanie Date: Thu, 8 Jun 2023 18:55:11 +0200 Subject: [PATCH] Add stream-download package --- Cargo.lock | 39 +++- Cargo.toml | 1 + stream-download/.gitignore | 10 + stream-download/Cargo.toml | 28 +++ stream-download/LICENSE | 21 ++ stream-download/README.md | 1 + stream-download/examples/audio.rs | 18 ++ stream-download/examples/no_tokio_runtime.rs | 17 ++ stream-download/examples/stream.rs | 18 ++ stream-download/src/http.rs | 81 +++++++ stream-download/src/lib.rs | 166 ++++++++++++++ stream-download/src/source.rs | 228 +++++++++++++++++++ 12 files changed, 627 insertions(+), 1 deletion(-) create mode 100644 stream-download/.gitignore create mode 100644 stream-download/Cargo.toml create mode 100644 stream-download/LICENSE create mode 100644 stream-download/README.md create mode 100644 stream-download/examples/audio.rs create mode 100644 stream-download/examples/no_tokio_runtime.rs create mode 100644 stream-download/examples/stream.rs create mode 100644 stream-download/src/http.rs create mode 100644 stream-download/src/lib.rs create mode 100644 stream-download/src/source.rs diff --git a/Cargo.lock b/Cargo.lock index 807c604..f20375f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -392,6 +392,12 @@ dependencies = [ "libloading", ] +[[package]] +name = "claxon" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4bfbf56724aa9eca8afa4fcfadeb479e722935bb2a0900c2d37e0cc477af0688" + [[package]] name = "combine" version = "4.6.6" @@ -1015,6 +1021,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "hound" +version = "3.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d13cdbd5dbb29f9c88095bbdc2590c9cba0d0a1269b983fef6b2cdd7e9f4db1" + [[package]] name = "http" version = "0.2.9" @@ -1293,6 +1305,17 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" +[[package]] +name = "lewton" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "777b48df9aaab155475a83a7df3070395ea1ac6902f5cd062b8f2b028075c030" +dependencies = [ + "byteorder", + "ogg", + "tinyvec", +] + [[package]] name = "libc" version = "0.2.144" @@ -1666,6 +1689,15 @@ dependencies = [ "cc", ] +[[package]] +name = "ogg" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6951b4e8bf21c8193da321bcce9c9dd2e13c858fe078bf9054a288b419ae5d6e" +dependencies = [ + "byteorder", +] + [[package]] name = "once_cell" version = "1.17.1" @@ -2177,7 +2209,10 @@ version = "0.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bdf1d4dea18dff2e9eb6dca123724f8b60ef44ad74a9ad283cdfe025df7e73fa" dependencies = [ + "claxon", "cpal", + "hound", + "lewton", "symphonia", ] @@ -2503,7 +2538,6 @@ checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" [[package]] name = "stream-download" version = "0.1.0" -source = "git+https://github.com/aschey/stream-download-rs.git#f453f227d81ebee2c19001c32e8ee505dd354131" dependencies = [ "async-trait", "bytes", @@ -2512,9 +2546,12 @@ dependencies = [ "parking_lot", "rangemap", "reqwest", + "rodio", + "symphonia", "tempfile", "tokio", "tracing", + "tracing-subscriber", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index c299094..7493500 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,5 +4,6 @@ members = [ "cbd-tui", "crabidy-core", "crabidy-server", + "stream-download", "tidaldy", ] diff --git a/stream-download/.gitignore b/stream-download/.gitignore new file mode 100644 index 0000000..088ba6b --- /dev/null +++ b/stream-download/.gitignore @@ -0,0 +1,10 @@ +# Generated by Cargo +# will have compiled files and executables +/target/ + +# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries +# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html +Cargo.lock + +# These are backup files generated by rustfmt +**/*.rs.bk diff --git a/stream-download/Cargo.toml b/stream-download/Cargo.toml new file mode 100644 index 0000000..925e0f3 --- /dev/null +++ b/stream-download/Cargo.toml @@ -0,0 +1,28 @@ +[package] +edition = "2021" +name = "stream-download" +version = "0.1.0" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +async-trait = "0.1" +bytes = "1" +futures = "0.3" +futures-util = "0.3" +parking_lot = "0.12" +rangemap = "1" +reqwest = { version = "0.11", features = ["stream"], optional = true } +symphonia = "0.5" +tempfile = "3" +tokio = { version = "1", features = ["sync", "macros"] } +tracing = "0.1" + +[features] +default = ["http"] +http = ["reqwest"] + +[dev-dependencies] +rodio = "0.17.1" +tracing-subscriber = "0.3.16" +tokio = { version = "1", features = ["sync", "macros", "rt-multi-thread"] } diff --git a/stream-download/LICENSE b/stream-download/LICENSE new file mode 100644 index 0000000..cb5dd95 --- /dev/null +++ b/stream-download/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Austin Schey + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/stream-download/README.md b/stream-download/README.md new file mode 100644 index 0000000..917f8f2 --- /dev/null +++ b/stream-download/README.md @@ -0,0 +1 @@ +# stream-download-rs \ No newline at end of file diff --git a/stream-download/examples/audio.rs b/stream-download/examples/audio.rs new file mode 100644 index 0000000..b70f97f --- /dev/null +++ b/stream-download/examples/audio.rs @@ -0,0 +1,18 @@ +use stream_download::StreamDownload; + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt().init(); + let (_stream, handle) = rodio::OutputStream::try_default().unwrap(); + let sink = rodio::Sink::try_new(&handle).unwrap(); + + let reader = StreamDownload::new_http( + "https://dl.espressif.com/dl/audio/ff-16b-2c-44100hz.flac" + .parse() + .unwrap(), + ); + + sink.append(rodio::Decoder::new(reader).unwrap()); + + sink.sleep_until_end(); +} diff --git a/stream-download/examples/no_tokio_runtime.rs b/stream-download/examples/no_tokio_runtime.rs new file mode 100644 index 0000000..75bc90d --- /dev/null +++ b/stream-download/examples/no_tokio_runtime.rs @@ -0,0 +1,17 @@ +use stream_download::StreamDownload; + +fn main() { + tracing_subscriber::fmt().init(); + let (_stream, handle) = rodio::OutputStream::try_default().unwrap(); + let sink = rodio::Sink::try_new(&handle).unwrap(); + + let reader = StreamDownload::new_http( + "https://dl.espressif.com/dl/audio/ff-16b-2c-44100hz.flac" + .parse() + .unwrap(), + ); + + sink.append(rodio::Decoder::new(reader).unwrap()); + + sink.sleep_until_end(); +} diff --git a/stream-download/examples/stream.rs b/stream-download/examples/stream.rs new file mode 100644 index 0000000..cb32450 --- /dev/null +++ b/stream-download/examples/stream.rs @@ -0,0 +1,18 @@ +use stream_download::StreamDownload; + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt().init(); + let (_stream, handle) = rodio::OutputStream::try_default().unwrap(); + let sink = rodio::Sink::try_new(&handle).unwrap(); + + let reader = StreamDownload::new_http( + "https://uk1.internet-radio.com/proxy/pinknoise?mp=/stream" + .parse() + .unwrap(), + ); + + sink.append(rodio::Decoder::new(reader).unwrap()); + + sink.sleep_until_end(); +} diff --git a/stream-download/src/http.rs b/stream-download/src/http.rs new file mode 100644 index 0000000..3d7cf0f --- /dev/null +++ b/stream-download/src/http.rs @@ -0,0 +1,81 @@ +use async_trait::async_trait; +use bytes::Bytes; +use futures::Stream; +use reqwest::Client; +use std::{ + pin::Pin, + str::FromStr, + task::{self, Poll}, +}; +use tracing::{info, warn}; + +use crate::source::SourceStream; + +pub struct HttpStream { + stream: Box> + Unpin + Send + Sync>, + client: Client, + content_length: Option, + url: reqwest::Url, +} + +impl Stream for HttpStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { + Pin::new(&mut self.stream).poll_next(cx) + } +} + +#[async_trait] +impl SourceStream for HttpStream { + type Url = reqwest::Url; + type Error = reqwest::Error; + + async fn create(url: Self::Url) -> Self { + let client = Client::new(); + info!("Requesting content length"); + let response = client.get(url.as_str()).send().await.unwrap(); + + let mut content_length = None; + if let Some(length) = response.headers().get(reqwest::header::CONTENT_LENGTH) { + let length = u64::from_str(length.to_str().unwrap()).unwrap(); + info!("Got content length {length}"); + content_length = Some(length); + } else { + warn!("Content length header missing"); + } + + let stream = response.bytes_stream(); + Self { + stream: Box::new(stream), + client, + content_length, + url, + } + } + + async fn content_length(&self) -> Option { + self.content_length + } + async fn seek(&mut self, pos: u64) { + info!("Seeking"); + self.stream = Box::new( + self.client + .get(self.url.as_str()) + .header( + "Range", + format!( + "bytes={pos}-{}", + self.content_length + .map(|l| l.to_string()) + .unwrap_or_default() + ), + ) + .send() + .await + .unwrap() + .bytes_stream(), + ); + info!("Done seeking"); + } +} diff --git a/stream-download/src/lib.rs b/stream-download/src/lib.rs new file mode 100644 index 0000000..f02551b --- /dev/null +++ b/stream-download/src/lib.rs @@ -0,0 +1,166 @@ +use source::{Source, SourceHandle, SourceStream}; +use std::{ + io::{self, BufReader, Read, Seek, SeekFrom}, + thread, +}; +use symphonia::core::io::MediaSource; +use tempfile::NamedTempFile; +use tracing::debug; + +#[cfg(feature = "http")] +pub mod http; +pub mod source; + +#[derive(Debug)] +pub struct StreamDownload { + output_reader: BufReader, + handle: SourceHandle, + read_position: u64, +} + +impl StreamDownload { + #[cfg(feature = "http")] + pub fn new_http(url: reqwest::Url) -> Self { + Self::new::(url) + } + + pub fn new(url: S::Url) -> Self { + let tempfile = tempfile::Builder::new().tempfile().unwrap(); + let source = Source::new(tempfile.reopen().unwrap()); + let handle = source.source_handle(); + + if let Ok(handle) = tokio::runtime::Handle::try_current() { + handle.spawn(async move { + let stream = S::create(url).await; + source.download(stream).await; + }); + } else { + thread::spawn(move || { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + rt.block_on(async move { + let stream = S::create(url).await; + source.download(stream).await; + }); + }); + }; + + Self { + output_reader: BufReader::new(tempfile), + read_position: 0, + handle, + } + } + + pub fn from_stream(stream: S) -> Self { + let tempfile = tempfile::Builder::new().tempfile().unwrap(); + let source = Source::new(tempfile.reopen().unwrap()); + let handle = source.source_handle(); + + if let Ok(handle) = tokio::runtime::Handle::try_current() { + handle.spawn(async move { + source.download(stream).await; + }); + } else { + thread::spawn(move || { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap(); + + rt.block_on(async move { + source.download(stream).await; + }); + }); + }; + + Self { + output_reader: BufReader::new(tempfile), + handle, + read_position: 0, + } + } +} + +impl Read for StreamDownload { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + debug!("Read request buf len: {}", buf.len()); + + let requested_position = self.read_position + buf.len() as u64; + debug!( + "read: current position: {} requested position: {requested_position}", + self.read_position + ); + + if let Some(closest_set) = self.handle.downloaded().get(&self.read_position) { + debug!("Already downloaded {closest_set:?}"); + if closest_set.end >= requested_position { + let read_len = self.output_reader.read(buf); + if let Ok(read_len) = read_len { + self.read_position += read_len as u64; + } + return read_len; + } + } + self.handle.request_position(requested_position); + + debug!("waiting for position"); + self.handle.wait_for_requested_position(); + + debug!("reached requested position {requested_position}"); + self.output_reader.read(buf) + } +} + +impl Seek for StreamDownload { + fn seek(&mut self, pos: SeekFrom) -> io::Result { + let seek_pos = match pos { + SeekFrom::Start(pos) => pos, + SeekFrom::End(pos) => { + if let Some(length) = self.handle.content_length() { + (length as i64 + pos) as u64 + } else { + return Err(io::Error::new( + io::ErrorKind::Unsupported, + "Cannot seek from end when content length is unknown", + )); + } + } + SeekFrom::Current(pos) => (self.read_position as i64 + pos) as u64, + }; + + if let Some(closest_set) = self.handle.downloaded().get(&seek_pos) { + if closest_set.end >= seek_pos { + let new_pos = self.output_reader.seek(pos); + if let Ok(new_pos) = new_pos { + self.read_position = new_pos; + } + } + } + + self.handle.request_position(seek_pos); + debug!( + "seek: current position {seek_pos} requested position {:?}. waiting", + seek_pos + ); + self.handle.seek(seek_pos); + self.handle.wait_for_requested_position(); + + debug!("reached seek position"); + self.output_reader.seek(pos) + } +} + +impl MediaSource for StreamDownload { + fn is_seekable(&self) -> bool { + true + } + + // FIXME: Can this be implemented? + fn byte_len(&self) -> Option { + None + } +} diff --git a/stream-download/src/source.rs b/stream-download/src/source.rs new file mode 100644 index 0000000..9a21738 --- /dev/null +++ b/stream-download/src/source.rs @@ -0,0 +1,228 @@ +use async_trait::async_trait; +use bytes::Bytes; +use futures::{Stream, StreamExt}; +use parking_lot::{Condvar, Mutex, RwLock, RwLockReadGuard}; +use rangemap::RangeSet; +use std::{ + error::Error, + fs::File, + io::{BufWriter, Seek, SeekFrom, Write}, + sync::{ + atomic::{AtomicI64, Ordering}, + Arc, + }, +}; +use tokio::sync::mpsc; +use tracing::{debug, info, trace}; + +#[async_trait] +pub trait SourceStream: + Stream> + Unpin + Send + Sync + 'static +{ + type Url: Send; + type Error: Error + Send; + + async fn create(url: Self::Url) -> Self; + async fn content_length(&self) -> Option; + async fn seek(&mut self, position: u64); +} + +#[derive(Debug, Clone)] +pub struct SourceHandle { + downloaded: Arc>>, + requested_position: Arc, + position_reached: Arc<(Mutex, Condvar)>, + content_length_retrieved: Arc<(Mutex, Condvar)>, + content_length: Arc, + seek_tx: mpsc::Sender, +} + +impl SourceHandle { + pub fn downloaded(&self) -> RwLockReadGuard> { + self.downloaded.read() + } + + pub fn request_position(&self, position: u64) { + self.requested_position + .store(position as i64, Ordering::SeqCst); + } + + pub fn wait_for_requested_position(&self) { + let (mutex, cvar) = &*self.position_reached; + let mut waiter = mutex.lock(); + if !waiter.stream_done { + debug!("Waiting for requested position"); + cvar.wait_while(&mut waiter, |waiter| { + !waiter.stream_done && !waiter.position_reached + }); + if !waiter.stream_done { + waiter.position_reached = false; + } + + debug!("Position reached"); + } + } + + pub fn seek(&self, position: u64) { + self.seek_tx.try_send(position).ok(); + } + + pub fn content_length(&self) -> Option { + let (mutex, cvar) = &*self.content_length_retrieved; + let mut done = mutex.lock(); + if !*done { + cvar.wait_while(&mut done, |done| !*done); + } + let length = self.content_length.load(Ordering::SeqCst); + if length > -1 { + Some(length as u64) + } else { + None + } + } +} + +#[derive(Default, Debug)] +struct Waiter { + position_reached: bool, + stream_done: bool, +} + +pub struct Source { + writer: BufWriter, + downloaded: Arc>>, + position: u64, + requested_position: Arc, + position_reached: Arc<(Mutex, Condvar)>, + content_length_retrieved: Arc<(Mutex, Condvar)>, + content_length: Arc, + seek_tx: mpsc::Sender, + seek_rx: mpsc::Receiver, +} + +const PREFETCH_BYTES: u64 = 1024 * 256; + +impl Source { + pub fn new(tempfile: File) -> Self { + let (seek_tx, seek_rx) = mpsc::channel(32); + Self { + writer: BufWriter::new(tempfile), + downloaded: Default::default(), + position: Default::default(), + requested_position: Arc::new(AtomicI64::new(-1)), + position_reached: Default::default(), + content_length_retrieved: Default::default(), + seek_tx, + seek_rx, + content_length: Default::default(), + } + } + + pub async fn download(mut self, mut stream: S) { + info!("Starting file download"); + let content_length = stream.content_length().await; + if let Some(content_length) = content_length { + self.content_length + .swap(content_length as i64, Ordering::SeqCst); + } else { + self.content_length.swap(-1, Ordering::SeqCst); + } + + { + let (mutex, cvar) = &*self.content_length_retrieved; + *mutex.lock() = true; + cvar.notify_all(); + } + + let mut initial_buffer = 0; + loop { + if let Some(bytes) = stream.next().await { + let bytes = bytes.unwrap(); + self.writer.write_all(&bytes).unwrap(); + initial_buffer += bytes.len() as u64; + trace!("Prefetch: {}/{} bytes", initial_buffer, PREFETCH_BYTES); + if initial_buffer >= PREFETCH_BYTES { + self.position += initial_buffer; + self.downloaded.write().insert(0..initial_buffer); + break; + } + } else { + info!("File shorter than prefetch length"); + self.writer.flush().unwrap(); + self.position += initial_buffer; + self.downloaded.write().insert(0..initial_buffer); + let (mutex, cvar) = &*self.position_reached; + (mutex.lock()).stream_done = true; + cvar.notify_all(); + return; + } + } + + info!("Prefetch complete"); + loop { + tokio::select! { + bytes = stream.next() => { + if let Some(bytes) = bytes { + let bytes = bytes.unwrap(); + let chunk_len = bytes.len() as u64; + self.writer.write_all(&bytes).unwrap(); + let new_position = self.position + chunk_len; + + trace!("Received response chunk. position={}", new_position); + self.downloaded.write().insert(self.position..new_position); + let requested = self.requested_position.load(Ordering::SeqCst); + if requested > -1 { + debug!("downloader: requested {requested} current {}", new_position); + } + + if requested > -1 && new_position as i64 >= requested { + info!("Notifying"); + self.requested_position.store(-1, Ordering::SeqCst); + let (mutex, cvar) = &*self.position_reached; + (mutex.lock()).position_reached = true; + cvar.notify_all(); + } + self.position = new_position; + } else { + info!("Stream finished downloading"); + self.writer.flush().unwrap(); + let (mutex, cvar) = &*self.position_reached; + (mutex.lock()).stream_done = true; + cvar.notify_all(); + return; + } + }, + pos = self.seek_rx.recv() => { + if let Some(pos) = pos { + debug!("Received seek position {pos}"); + let do_seek = { + let downloaded = self.downloaded.read(); + if let Some(range) = downloaded.get(&pos) { + !range.contains(&self.position) + } else { + true + } + }; + + if do_seek { + stream.seek(pos).await; + self.writer.seek(SeekFrom::Start(pos)).unwrap(); + self.position = pos; + } + } + } + } + } + } + + pub fn source_handle(&self) -> SourceHandle { + SourceHandle { + downloaded: self.downloaded.clone(), + requested_position: self.requested_position.clone(), + position_reached: self.position_reached.clone(), + seek_tx: self.seek_tx.clone(), + content_length_retrieved: self.content_length_retrieved.clone(), + content_length: self.content_length.clone(), + } + } +}