diff --git a/Cargo.toml b/Cargo.toml index 22158a8..c311b27 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,4 +9,3 @@ edition = "2021" [dependencies] clap = { version = "3.0.0-rc.9", features = ["derive"] } chrono = "0.4" -retain_mut = "0.1.5" diff --git a/src/main.rs b/src/main.rs index 52e4d57..9a44d75 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,31 +5,11 @@ use clap::Parser; use std::io::Write; use std::io::{self, BufRead}; use chrono::{DateTime, Utc}; -use std::borrow::{Borrow, BorrowMut}; -use retain_mut::RetainMut; +//use std::borrow::{Borrow, BorrowMut}; +//use retain_mut::RetainMut; use std::collections::BTreeMap; -fn send_to_all(clients: &Arc>>, msg: &str) -> Result<(), ChatError> { - let now: DateTime = Utc::now(); - let time_msg = format!("[{}] {}", now.format("%H:%M:%S"), msg); - let mut broken_clients = vec![]; - clients.lock()?.borrow_mut().retain_mut( - | client | { - if write!(client.socket, "{}\n", time_msg).is_err() { - broken_clients.push(client.name.clone()); - return false - } - true - }); - - for name in broken_clients { - send_to_all(clients, &format!("* {} left the chat (broken pipe)", name))?; - } - - Ok(()) -} - /// Represent one single chat user and their network stuff struct ChatUser { @@ -56,15 +36,50 @@ struct ChatServer { impl ChatServer { pub fn new() -> Self { - Self{user_map: BtreeMap::new(), user_id_counter: 1} + Self{user_map: BTreeMap::new(), user_id_counter: 1} + } + + pub fn get_user_count(&self) -> usize { + self.user_map.len() } - // register a client by creating a ChatUser for it and assigning it a user id - pub fn register(&mut self, name: String, socket: TcpStream) { - let chat_user = ChatUser::new(user_id_counter, name, socket); + /// register a client by creating a ChatUser for it and assigning it a user id + pub fn register(&mut self, name: String, socket: TcpStream) -> u64 { + let chat_user = ChatUser::new(self.user_id_counter, name, socket); + let client_id = chat_user.id; self.user_map.insert(chat_user.id, chat_user); self.user_id_counter += 1; + + client_id + } + + pub fn deregister(&mut self, client_id: u64) -> Option { + self.user_map.remove(&client_id) } + + pub fn send_to_all(&mut self, msg: &str) -> Result<(), ChatError> { + let now: DateTime = Utc::now(); + let time_msg = format!("[{}] {}", now.format("%H:%M:%S"), msg); + let mut broken_clients = vec![]; + for (user_id, user) in self.user_map.iter_mut() { + if write!(user.socket, "{}\n", time_msg).is_err() { + broken_clients.push((*user_id, user.name.clone())); + } + } + + // remove users from map + for (user_id, _) in broken_clients.iter() { + self.deregister(*user_id); + } + + // send out the news + for (_, name) in broken_clients.iter() { + self.send_to_all(&format!("* {} has left the chat (broken pipe)", name))?; + } + + Ok(()) + } + } #[derive(Debug)] @@ -86,7 +101,7 @@ impl <'a, T> From>> for Chat } } -fn handle_client(mut stream: TcpStream, clients: Arc>>) -> Result<(), ChatError> { +fn handle_client(mut stream: TcpStream, server: Arc>, client_id: &mut Option) -> Result<(), ChatError> { let mut writer = stream.try_clone()?; let reader = io::BufReader::new(&mut stream); let mut lines = reader.lines(); @@ -97,11 +112,12 @@ fn handle_client(mut stream: TcpStream, clients: Arc>>) -> R // FIXME: check for nick name uniqueness // print joined msg and register at server - send_to_all(&clients, &format!("* {} has joined the chat", nick))?; - let curr_users = clients.lock()?.borrow().len(); + server.lock()?.send_to_all(&format!("* {} has joined the chat", nick))?; + let curr_users = server.lock()?.get_user_count(); let users_in_room_msg = if curr_users == 1 { String::from("is 1 user") } else { format!("are {} users", curr_users) }; write!(writer, "Welcome, {}! You can now start to chat! There {} in the room.\n", nick, users_in_room_msg)?; - clients.lock()?.borrow_mut().push(ChatUser::new(nick.clone(), writer)); + // server.lock()?.borrow_mut().push(ChatUser::new(nick.clone(), writer)); + *client_id = Some(server.lock()?.register(nick.clone(), writer)); // read lines for line in lines { @@ -110,19 +126,29 @@ fn handle_client(mut stream: TcpStream, clients: Arc>>) -> R continue; } let msg = format!("<{}> {}", nick, msg_raw); - send_to_all(&clients, &msg)?; + server.lock()?.send_to_all(&msg)?; println!("{}", msg); } Ok(()) } -fn run_thread(stream: TcpStream, clients: Arc>>) { - let quit_msg = match handle_client(stream, clients) { +fn run_thread(stream: TcpStream, server: Arc>) { + let mut client_id = None; + let quit_reason = match handle_client(stream, Arc::clone(&server), &mut client_id) { Ok(()) => String::from("Connection closed by user's choice"), Err(err) => format!("Client foobar't their thread: {:?}", err), }; - println!(" * Client has left the chat ({})", quit_msg); + + if let Some(client_id) = client_id { + let mut srv = server.lock().unwrap(); + let name = srv.deregister(client_id).map(|client| client.name).unwrap_or("".into()); + let quit_msg = format!("* {} has left the chat ({})", name, quit_reason); + println!("{}", quit_msg); + srv.send_to_all(&quit_msg).unwrap(); + } else { + println!("Client thread terminated before identifying themselves"); + } } @@ -145,13 +171,13 @@ fn main() -> std::io::Result<()> { println!("Binding to {}", binding_host); let listener = TcpListener::bind(binding_host)?; - let clients = Arc::new(Mutex::new(vec!())); + let server = Arc::new(Mutex::new(ChatServer::new())); // accept connections and process them serially for stream in listener.incoming() { if let Ok(stream) = stream { - let clients_clone = Arc::clone(&clients); - thread::spawn(move || { run_thread(stream, clients_clone); }); + let server_clone = Arc::clone(&server); + thread::spawn(move || { run_thread(stream, server_clone); }); } } Ok(())