use std::sync::{Arc, Mutex}; use std::net::{TcpListener, TcpStream}; use std::thread; use clap::Parser; use std::io::Write; use std::io::{self, BufRead}; use chrono::{DateTime, Utc}; use std::collections::BTreeMap; use serde::{Serialize, Deserialize}; use std::fs::File; /// Represent one single chat user and their network stuff struct ChatUser { /// user id id: u64, /// Name of the chat user name: String, /// TCPStream socket object thingy socket: TcpStream, } impl ChatUser { pub fn new(id: u64, name: String, socket: TcpStream) -> Self { Self{id, name, socket} } } struct ChatServer { config: ConfigArgs, user_map: BTreeMap, user_id_counter: u64, } impl ChatServer { pub fn new(config: ConfigArgs) -> Self { Self{config, user_map: BTreeMap::new(), user_id_counter: 1} } pub fn get_user_count(&self) -> usize { self.user_map.len() } /// register a client by creating a ChatUser for it and assigning it a user id pub fn register(&mut self, name: String, socket: TcpStream) -> u64 { let chat_user = ChatUser::new(self.user_id_counter, name, socket); let client_id = chat_user.id; self.user_map.insert(chat_user.id, chat_user); self.user_id_counter += 1; client_id } pub fn deregister(&mut self, client_id: u64) -> Option { self.user_map.remove(&client_id) } pub fn send_to_all(&mut self, msg: &str) -> Result<(), ChatError> { let now: DateTime = Utc::now(); let time_msg = format!("[{}] {}", now.format("%H:%M:%S"), msg); let mut broken_clients = vec![]; for (user_id, user) in self.user_map.iter_mut() { if write!(user.socket, "{}\n", time_msg).is_err() { broken_clients.push((*user_id, user.name.clone())); } } // remove users from map for (user_id, _) in broken_clients.iter() { self.deregister(*user_id); } // send out the news for (_, name) in broken_clients.iter() { self.send_to_all(&format!("* {} has left the chat (broken pipe)", name))?; } Ok(()) } } #[derive(Debug)] enum ChatError { IOError(std::io::Error), MutexPoisonError(), Protocol(String), } impl From for ChatError { fn from(error: std::io::Error) -> Self { ChatError::IOError(error) } } impl <'a, T> From>> for ChatError { fn from(_error: std::sync::PoisonError>) -> Self { ChatError::MutexPoisonError() } } fn handle_client(mut stream: TcpStream, server: Arc>, client_id: &mut Option) -> Result<(), ChatError> { let mut writer = stream.try_clone()?; let reader = io::BufReader::new(&mut stream); let mut lines = reader.lines(); // greet and recv nick write!(writer, "Hello and welcome to RustChat!\nNick: ")?; let nick = lines.next().ok_or(ChatError::Protocol(String::from("Could not recv nickname")))??; // FIXME: check for nick name uniqueness // print joined msg and register at server server.lock()?.send_to_all(&format!("* {} has joined the chat", nick))?; let curr_users = server.lock()?.get_user_count(); let users_in_room_msg = if curr_users == 1 { String::from("is 1 user") } else { format!("are {} users", curr_users) }; write!(writer, "Welcome, {}! You can now start to chat! There {} in the room.\n", nick, users_in_room_msg)?; // server.lock()?.borrow_mut().push(ChatUser::new(nick.clone(), writer)); *client_id = Some(server.lock()?.register(nick.clone(), writer)); // read lines for line in lines { let msg_raw = line?; if msg_raw.is_empty() { continue; } let msg = format!("<{}> {}", nick, msg_raw); server.lock()?.send_to_all(&msg)?; println!("{}", msg); } Ok(()) } fn run_thread(stream: TcpStream, server: Arc>) { let mut client_id = None; let quit_reason = match handle_client(stream, Arc::clone(&server), &mut client_id) { Ok(()) => String::from("Connection closed by user's choice"), Err(err) => format!("Client foobar't their thread: {:?}", err), }; if let Some(client_id) = client_id { let mut srv = server.lock().unwrap(); let name = srv.deregister(client_id).map(|client| client.name).unwrap_or("".into()); let quit_msg = format!("* {} has left the chat ({})", name, quit_reason); println!("{}", quit_msg); srv.send_to_all(&quit_msg).unwrap(); } else { println!("Client thread terminated before identifying themselves"); } } /// Chatserver in RUST! #[derive(Parser, Debug)] #[clap(about, version, author)] struct Args { /// Path to config file #[clap(short, long, default_value = "config.yaml")] config: String, /// IP address to bind to #[clap(long)] host: Option, /// Port to run the server on #[clap(short, long)] port: Option, } #[derive(Debug, Serialize, Deserialize)] struct ConfigArgs { host: String, port: u16, greeting_msg: String, } fn main() -> std::io::Result<()> { // argument parsing let args = Args::parse(); // read config file let config_reader = File::open(args.config).expect("Could not open config file"); let mut config: ConfigArgs = serde_yaml::from_reader(&config_reader).expect("Could not parse config file"); if let Some(host) = args.host { config.host = host; } if let Some(port) = args.port { config.port = port; } let binding_host = format!("{}:{}", config.host, config.port); println!("Binding to {}", binding_host); let listener = TcpListener::bind(binding_host)?; let server = Arc::new(Mutex::new(ChatServer::new(config))); // accept connections and process them serially for stream in listener.incoming() { if let Ok(stream) = stream { let server_clone = Arc::clone(&server); thread::spawn(move || { run_thread(stream, server_clone); }); } } Ok(()) }