Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions bin/src/config.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use serde::Deserialize;
use serde::{Deserialize, Serialize};

use rayhunter::analysis::analyzer::AnalyzerConfig;

use crate::error::RayhunterError;

#[derive(Debug, Deserialize)]
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(default)]
pub struct Config {
pub qmdl_store_path: String,
Expand Down Expand Up @@ -32,11 +32,11 @@ impl Default for Config {
}
}

pub fn parse_config<P>(path: P) -> Result<Config, RayhunterError>
pub async fn parse_config<P>(path: P) -> Result<Config, RayhunterError>
where
P: AsRef<std::path::Path>,
{
if let Ok(config_file) = std::fs::read_to_string(&path) {
if let Ok(config_file) = tokio::fs::read_to_string(&path).await {
Ok(toml::from_str(&config_file).map_err(RayhunterError::ConfigFileParsingError)?)
} else {
Ok(Config::default())
Expand Down
130 changes: 88 additions & 42 deletions bin/src/daemon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,17 @@ mod qmdl_store;
mod server;
mod stats;

use std::net::SocketAddr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

use crate::config::{parse_args, parse_config};
use crate::diag::run_diag_read_thread;
use crate::error::RayhunterError;
use crate::pcap::get_pcap;
use crate::qmdl_store::RecordingStore;
use crate::server::{get_qmdl, get_zip, serve_static, ServerState};
use crate::stats::get_system_stats;
use crate::server::{get_config, get_qmdl, get_zip, serve_static, set_config, ServerState};
use crate::stats::{get_qmdl_manifest, get_system_stats};

use analysis::{
get_analysis_status, run_analysis_thread, start_analysis, AnalysisCtrlMessage, AnalysisStatus,
Expand All @@ -31,10 +35,8 @@ use diag::{
use log::{error, info};
use qmdl_store::RecordingStoreError;
use rayhunter::diag_device::DiagDevice;
use stats::get_qmdl_manifest;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::select;
use tokio::sync::mpsc::{self, Sender};
use tokio::sync::{oneshot, RwLock};
use tokio::task::JoinHandle;
Expand All @@ -56,6 +58,8 @@ fn get_router() -> AppRouter {
.route("/api/analysis-report/{name}", get(get_analysis_report))
.route("/api/analysis", get(get_analysis_status))
.route("/api/analysis/{name}", post(start_analysis))
.route("/api/config", get(get_config))
.route("/api/config", post(set_config))
.route("/", get(|| async { Redirect::permanent("/index.html") }))
.route("/{*path}", get(serve_static))
}
Expand All @@ -65,14 +69,14 @@ fn get_router() -> AppRouter {
// (i.e. user hit ctrl+c)
async fn run_server(
task_tracker: &TaskTracker,
config: &config::Config,
state: Arc<ServerState>,
server_shutdown_rx: oneshot::Receiver<()>,
) -> JoinHandle<()> {
info!("spinning up server");
let app = get_router().with_state(state);
let addr = SocketAddr::from(([0, 0, 0, 0], config.port));
let addr = SocketAddr::from(([0, 0, 0, 0], state.config.port));
let listener = TcpListener::bind(&addr).await.unwrap();
let app = get_router().with_state(state);

task_tracker.spawn(async move {
info!("The orca is hunting for stingrays...");
axum::serve(listener, app)
Expand Down Expand Up @@ -118,46 +122,61 @@ async fn init_qmdl_store(config: &config::Config) -> Result<RecordingStore, Rayh
// Start a thread that'll track when user hits ctrl+c. When that happens,
// trigger various cleanup tasks, including sending signals to other threads to
// shutdown
fn run_ctrl_c_thread(
fn run_shutdown_thread(
task_tracker: &TaskTracker,
diag_device_sender: Sender<DiagDeviceCtrlMessage>,
daemon_restart_rx: oneshot::Receiver<()>,
should_restart_flag: Arc<AtomicBool>,
server_shutdown_tx: oneshot::Sender<()>,
maybe_ui_shutdown_tx: Option<oneshot::Sender<()>>,
maybe_key_input_shutdown_tx: Option<oneshot::Sender<()>>,
qmdl_store_lock: Arc<RwLock<RecordingStore>>,
analysis_tx: Sender<AnalysisCtrlMessage>,
) -> JoinHandle<Result<(), RayhunterError>> {
info!("create shutdown thread");

task_tracker.spawn(async move {
match tokio::signal::ctrl_c().await {
Ok(()) => {
let mut qmdl_store = qmdl_store_lock.write().await;
if qmdl_store.current_entry.is_some() {
info!("Closing current QMDL entry...");
qmdl_store.close_current_entry().await?;
info!("Done!");
select! {
res = tokio::signal::ctrl_c() => {
if let Err(err) = res {
error!("Unable to listen for shutdown signal: {}", err);
}

server_shutdown_tx
.send(())
.expect("couldn't send server shutdown signal");
info!("sending UI shutdown");
if let Some(ui_shutdown_tx) = maybe_ui_shutdown_tx {
ui_shutdown_tx
.send(())
.expect("couldn't send ui shutdown signal");
}
diag_device_sender
.send(DiagDeviceCtrlMessage::Exit)
.await
.expect("couldn't send Exit message to diag thread");
analysis_tx
.send(AnalysisCtrlMessage::Exit)
.await
.expect("couldn't send Exit message to analysis thread");
should_restart_flag.store(false, Ordering::Relaxed);
}
Err(err) => {
error!("Unable to listen for shutdown signal: {}", err);
res = daemon_restart_rx => {
if let Err(err) = res {
error!("Unable to listen for shutdown signal: {}", err);
}

should_restart_flag.store(true, Ordering::Relaxed);
}
};

let mut qmdl_store = qmdl_store_lock.write().await;
if qmdl_store.current_entry.is_some() {
info!("Closing current QMDL entry...");
qmdl_store.close_current_entry().await?;
info!("Done!");
}

server_shutdown_tx
.send(())
.expect("couldn't send server shutdown signal");
if let Some(ui_shutdown_tx) = maybe_ui_shutdown_tx {
let _ = ui_shutdown_tx.send(());
}
if let Some(key_input_shutdown_tx) = maybe_key_input_shutdown_tx {
let _ = key_input_shutdown_tx.send(());
}
diag_device_sender
.send(DiagDeviceCtrlMessage::Exit)
.await
.expect("couldn't send Exit message to diag thread");
analysis_tx
.send(AnalysisCtrlMessage::Exit)
.await
.expect("couldn't send Exit message to analysis thread");
Ok(())
})
}
Expand All @@ -167,8 +186,19 @@ async fn main() -> Result<(), RayhunterError> {
env_logger::init();

let args = parse_args();
let config = parse_config(&args.config_path)?;

loop {
let config = parse_config(&args.config_path).await?;
if !run_with_config(&args, config).await? {
return Ok(());
}
}
}

async fn run_with_config(
args: &config::Args,
config: config::Config,
) -> Result<bool, RayhunterError> {
// TaskTrackers give us an interface to spawn tokio threads, and then
// eventually await all of them ending
let task_tracker = TaskTracker::new();
Expand All @@ -181,6 +211,7 @@ async fn main() -> Result<(), RayhunterError> {
let (ui_update_tx, ui_update_rx) = mpsc::channel::<display::DisplayState>(1);
let (analysis_tx, analysis_rx) = mpsc::channel::<AnalysisCtrlMessage>(5);
let mut maybe_ui_shutdown_tx = None;
let mut maybe_key_input_shutdown_tx = None;
if !config.debug_mode {
let (ui_shutdown_tx, ui_shutdown_rx) = oneshot::channel();
maybe_ui_shutdown_tx = Some(ui_shutdown_tx);
Expand All @@ -206,10 +237,18 @@ async fn main() -> Result<(), RayhunterError> {
display::update_ui(&task_tracker, &config, ui_shutdown_rx, ui_update_rx);

info!("Starting Key Input service");
key_input::run_key_input_thread(&task_tracker, &config, diag_tx.clone());
let (key_input_shutdown_tx, key_input_shutdown_rx) = oneshot::channel();
maybe_key_input_shutdown_tx = Some(key_input_shutdown_tx);
key_input::run_key_input_thread(
&task_tracker,
&config,
diag_tx.clone(),
key_input_shutdown_rx,
);
}

let (daemon_restart_tx, daemon_restart_rx) = oneshot::channel::<()>();
let (server_shutdown_tx, server_shutdown_rx) = oneshot::channel::<()>();
info!("create shutdown thread");
let analysis_status_lock = Arc::new(RwLock::new(analysis_status));
run_analysis_thread(
&task_tracker,
Expand All @@ -219,29 +258,36 @@ async fn main() -> Result<(), RayhunterError> {
config.enable_dummy_analyzer,
config.analyzers.clone(),
);
run_ctrl_c_thread(
let should_restart_flag = Arc::new(AtomicBool::new(false));

run_shutdown_thread(
&task_tracker,
diag_tx.clone(),
daemon_restart_rx,
should_restart_flag.clone(),
server_shutdown_tx,
maybe_ui_shutdown_tx,
maybe_key_input_shutdown_tx,
qmdl_store_lock.clone(),
analysis_tx.clone(),
);
let state = Arc::new(ServerState {
config_path: args.config_path.clone(),
config,
qmdl_store_lock: qmdl_store_lock.clone(),
diag_device_ctrl_sender: diag_tx,
ui_update_sender: ui_update_tx,
debug_mode: config.debug_mode,
analysis_status_lock,
analysis_sender: analysis_tx,
daemon_restart_tx: Arc::new(RwLock::new(Some(daemon_restart_tx))),
});
run_server(&task_tracker, &config, state, server_shutdown_rx).await;
run_server(&task_tracker, state, server_shutdown_rx).await;

task_tracker.close();
task_tracker.wait().await;

info!("see you space cowboy...");
Ok(())
Ok(should_restart_flag.load(Ordering::Relaxed))
}

#[cfg(test)]
Expand Down
8 changes: 4 additions & 4 deletions bin/src/diag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ pub fn run_diag_read_thread(
pub async fn start_recording(
State(state): State<Arc<ServerState>>,
) -> Result<(StatusCode, String), (StatusCode, String)> {
if state.debug_mode {
if state.config.debug_mode {
return Err((StatusCode::FORBIDDEN, "server is in debug mode".to_string()));
}

Expand All @@ -179,7 +179,7 @@ pub async fn start_recording(
pub async fn stop_recording(
State(state): State<Arc<ServerState>>,
) -> Result<(StatusCode, String), (StatusCode, String)> {
if state.debug_mode {
if state.config.debug_mode {
return Err((StatusCode::FORBIDDEN, "server is in debug mode".to_string()));
}
state
Expand All @@ -199,7 +199,7 @@ pub async fn delete_recording(
State(state): State<Arc<ServerState>>,
Path(qmdl_name): Path<String>,
) -> Result<(StatusCode, String), (StatusCode, String)> {
if state.debug_mode {
if state.config.debug_mode {
return Err((StatusCode::FORBIDDEN, "server is in debug mode".to_string()));
}
let mut qmdl_store = state.qmdl_store_lock.write().await;
Expand Down Expand Up @@ -244,7 +244,7 @@ pub async fn delete_recording(
pub async fn delete_all_recordings(
State(state): State<Arc<ServerState>>,
) -> Result<(StatusCode, String), (StatusCode, String)> {
if state.debug_mode {
if state.config.debug_mode {
return Err((StatusCode::FORBIDDEN, "server is in debug mode".to_string()));
}
state
Expand Down
18 changes: 14 additions & 4 deletions bin/src/key_input.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use log::error;
use log::{error, info};
use std::time::{Duration, Instant};
use tokio::fs::File;
use tokio::io::AsyncReadExt;
use tokio::sync::mpsc::Sender;
use tokio::sync::oneshot;
use tokio_util::task::TaskTracker;

use crate::config;
Expand All @@ -20,6 +21,7 @@ pub fn run_key_input_thread(
task_tracker: &TaskTracker,
config: &config::Config,
diag_tx: Sender<DiagDeviceCtrlMessage>,
mut ui_shutdown_rx: oneshot::Receiver<()>,
) {
if config.key_input_mode == 0 {
return;
Expand All @@ -40,9 +42,17 @@ pub fn run_key_input_thread(
let mut last_event_time: Option<Instant> = None;

loop {
if let Err(e) = file.read_exact(&mut buffer).await {
error!("failed to read key input: {}", e);
return;
tokio::select! {
_ = &mut ui_shutdown_rx => {
info!("received key input shutdown");
return;
}
result = file.read_exact(&mut buffer) => {
if let Err(e) = result {
error!("failed to read key input: {}", e);
return;
}
}
}

let event = parse_event(buffer);
Expand Down
Loading
Loading