You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

209 lines
6.1 KiB

use std::sync::{Arc, Mutex};
use std::net::{TcpListener, TcpStream};
use std::thread;
use clap::Parser;
use std::io::Write;
use std::io::{self, BufRead};
use chrono::{DateTime, Utc};
use std::collections::BTreeMap;
use serde::{Serialize, Deserialize};
use std::fs::File;
/// One connected chat participant: the id the server assigned to it, its
/// nickname, and the socket handle the server writes chat messages to.
struct ChatUser {
    /// Id handed out by the server when the user registered.
    id: u64,
    /// Nickname the user sent right after connecting.
    name: String,
    /// Socket handle used to deliver broadcast messages to this user.
    socket: TcpStream,
}

impl ChatUser {
    /// Bundles an id, a nickname and a socket into a `ChatUser`.
    pub fn new(id: u64, name: String, socket: TcpStream) -> Self {
        ChatUser {
            socket,
            name,
            id,
        }
    }
}
/// Chat room state: the registered users plus the configuration the server
/// was created with.
struct ChatServer {
    /// Configuration passed to `ChatServer::new`.
    config: ConfigArgs,
    /// Registered users, keyed by their user id.
    user_map: BTreeMap<u64, ChatUser>,
    /// Id that the next call to `register` will hand out (starts at 1).
    user_id_counter: u64,
}

impl ChatServer {
    /// Creates a server with no users; the first registered user gets id 1.
    pub fn new(config: ConfigArgs) -> Self {
        Self {
            config,
            user_map: BTreeMap::new(),
            user_id_counter: 1,
        }
    }

    /// Returns the number of currently registered users.
    pub fn get_user_count(&self) -> usize {
        self.user_map.len()
    }

    /// Registers a client by creating a `ChatUser` for it under a fresh id.
    ///
    /// Returns the id assigned to the new user.
    pub fn register(&mut self, name: String, socket: TcpStream) -> u64 {
        let client_id = self.user_id_counter;
        self.user_id_counter += 1;
        self.user_map
            .insert(client_id, ChatUser::new(client_id, name, socket));
        client_id
    }

    /// Removes the user with `client_id`, returning it if it was registered.
    pub fn deregister(&mut self, client_id: u64) -> Option<ChatUser> {
        self.user_map.remove(&client_id)
    }

    /// Broadcasts `msg` to every registered user, prefixed with the current
    /// UTC time as `[HH:MM:SS]` and terminated by a newline.
    ///
    /// Users whose socket cannot be written are removed, their socket is shut
    /// down, and "* <name> has left the chat (broken pipe)" is broadcast to
    /// the remaining users. The broadcast is recursive, so further broken
    /// sockets found while announcing are handled the same way.
    ///
    /// # Errors
    /// Write failures are never returned: they are handled by dropping the
    /// user. The only propagated error is from the recursive announcement,
    /// so in practice this returns `Ok(())`.
    pub fn send_to_all(&mut self, msg: &str) -> Result<(), ChatError> {
        let now: DateTime<Utc> = Utc::now();
        let time_msg = format!("[{}] {}", now.format("%H:%M:%S"), msg);

        let mut broken_ids = Vec::new();
        for (user_id, user) in self.user_map.iter_mut() {
            if writeln!(user.socket, "{}", time_msg).is_err() {
                broken_ids.push(*user_id);
            }
        }

        // Remove broken users from the map before announcing, so the
        // announcement is not written to them again.
        let mut departed_names = Vec::with_capacity(broken_ids.len());
        for user_id in broken_ids {
            if let Some(user) = self.deregister(user_id) {
                // Best effort: shutting the socket down also ends the reader
                // clone held by the user's own thread. That thread then stops
                // relaying messages for a user who is no longer registered.
                let _ = user.socket.shutdown(std::net::Shutdown::Both);
                departed_names.push(user.name);
            }
        }

        // send out the news
        for name in departed_names {
            self.send_to_all(&format!("* {} has left the chat (broken pipe)", name))?;
        }
        Ok(())
    }
}
/// Errors that can end a client session.
#[derive(Debug)]
enum ChatError {
    /// Reading from or writing to a socket failed.
    IOError(std::io::Error),
    /// The shared server mutex was poisoned by a thread that panicked while
    /// holding it.
    MutexPoisonError(),
    /// The client did not follow the chat protocol (e.g. it disconnected
    /// before sending a nickname); the string describes what went wrong.
    Protocol(String),
}

impl std::fmt::Display for ChatError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ChatError::IOError(err) => write!(f, "I/O error: {}", err),
            ChatError::MutexPoisonError() => write!(f, "server mutex poisoned"),
            ChatError::Protocol(reason) => write!(f, "protocol error: {}", reason),
        }
    }
}

impl std::error::Error for ChatError {
    /// Exposes the underlying I/O error, if any, for error-chain reporting.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            ChatError::IOError(err) => Some(err),
            ChatError::MutexPoisonError() | ChatError::Protocol(_) => None,
        }
    }
}

impl From<std::io::Error> for ChatError {
    fn from(error: std::io::Error) -> Self {
        ChatError::IOError(error)
    }
}

/// Converts any poisoned-lock error (mutex or rwlock guard) into
/// `ChatError::MutexPoisonError`; the contained guard is dropped.
impl<T> From<std::sync::PoisonError<T>> for ChatError {
    fn from(_error: std::sync::PoisonError<T>) -> Self {
        ChatError::MutexPoisonError()
    }
}
/// Serves one client connection from greeting to disconnect.
///
/// The steps are:
/// 1. Send a greeting and read the first line as the nickname.
/// 2. Announce the join and register the client with `server`.
/// 3. Broadcast every non-empty line the client sends as `<nick> line`,
///    also echoing it to stdout.
///
/// `client_id` is set as soon as registration succeeds, so the caller can
/// deregister the client even if this function later returns an error.
///
/// # Errors
/// * `ChatError::Protocol` if the connection ends before a nickname arrives.
/// * `ChatError::IOError` on any socket read/write failure.
/// * `ChatError::MutexPoisonError` if the server lock is poisoned.
fn handle_client(mut stream: TcpStream, server: Arc<Mutex<ChatServer>>, client_id: &mut Option<u64>) -> Result<(), ChatError> {
    // Second handle to the same socket: `writer` is handed to the server for
    // broadcasts while `stream` stays here for reading.
    let mut writer = stream.try_clone()?;
    let mut lines = io::BufReader::new(&mut stream).lines();

    // greet and recv nick
    write!(writer, "Hello and welcome to RustChat!\nNick: ")?;
    let nick = lines
        .next()
        .ok_or_else(|| ChatError::Protocol(String::from("Could not recv nickname")))??;
    // FIXME: check for nick name uniqueness

    // Announce, greet and register under ONE lock acquisition. Otherwise
    // other clients can join or leave between the steps, making the count
    // stale, and broadcasts sent between the welcome and registration would
    // never reach this client.
    {
        let mut srv = server.lock()?;
        srv.send_to_all(&format!("* {} has joined the chat", nick))?;
        // Count the joining user too: they are in the room from here on, and
        // a lone user should not be told that 0 users are in the room.
        let curr_users = srv.get_user_count() + 1;
        let users_in_room_msg = if curr_users == 1 {
            String::from("is 1 user")
        } else {
            format!("are {} users", curr_users)
        };
        writeln!(
            writer,
            "Welcome, {}! You can now start to chat! There {} in the room.",
            nick, users_in_room_msg
        )?;
        *client_id = Some(srv.register(nick.clone(), writer));
    }

    // relay every non-empty line as a chat message
    for line in lines {
        let msg_raw = line?;
        if msg_raw.is_empty() {
            continue;
        }
        let msg = format!("<{}> {}", nick, msg_raw);
        server.lock()?.send_to_all(&msg)?;
        println!("{}", msg);
    }
    Ok(())
}
/// Thread entry point for one connection: runs `handle_client`, then cleans
/// up after it.
///
/// If the client had registered, it is removed from `server`, and its
/// departure (with the reason) is logged and broadcast. A client that
/// `ChatServer::send_to_all` already removed after a failed write is not
/// announced a second time.
fn run_thread(stream: TcpStream, server: Arc<Mutex<ChatServer>>) {
    let mut client_id = None;
    let quit_reason = match handle_client(stream, Arc::clone(&server), &mut client_id) {
        Ok(()) => String::from("Connection closed by user's choice"),
        Err(err) => format!("Client foobar't their thread: {:?}", err),
    };

    let client_id = match client_id {
        Some(client_id) => client_id,
        None => {
            println!("Client thread terminated before identifying themselves");
            return;
        }
    };

    // A poisoned lock means another thread panicked while holding it. Log and
    // give up instead of panicking this thread as well.
    let mut srv = match server.lock() {
        Ok(srv) => srv,
        Err(_) => {
            println!(
                "Could not deregister client {}: server mutex poisoned ({})",
                client_id, quit_reason
            );
            return;
        }
    };

    match srv.deregister(client_id) {
        Some(client) => {
            let quit_msg = format!("* {} has left the chat ({})", client.name, quit_reason);
            println!("{}", quit_msg);
            if let Err(err) = srv.send_to_all(&quit_msg) {
                println!("Could not announce departure of client {}: {:?}", client_id, err);
            }
        }
        // Already removed by `send_to_all` after a failed write. That function
        // announced the departure itself; broadcasting again would show a
        // bogus "<unknown user> has left the chat".
        None => println!("Client {} was already removed ({})", client_id, quit_reason),
    }
}
// Command-line arguments.
// NOTE: clap's derive turns the `///` doc comments in this struct into the
// `--help` text, so those lines are user-facing output, not just docs.
/// Chatserver in RUST!
#[derive(Parser, Debug)]
#[clap(about, version, author)]
struct Args {
    // Opened by `main` and parsed as YAML into `ConfigArgs`.
    /// Path to config file
    #[clap(short, long, default_value = "config.yaml")]
    config: String,
    // When given, overrides `host` from the config file.
    /// IP address to bind to
    #[clap(long)]
    host: Option<String>,
    // When given, overrides `port` from the config file.
    /// Port to run the server on
    #[clap(short, long)]
    port: Option<u16>,
}
/// Server settings read from the YAML config file.
///
/// `main` lets the `--host` / `--port` command-line options override
/// `host` and `port` after loading.
#[derive(Debug, Serialize, Deserialize)]
struct ConfigArgs {
    /// Host or IP address the listener binds to (joined as `host:port`).
    host: String,
    /// TCP port the listener binds to.
    port: u16,
    /// Greeting message.
    // NOTE(review): not read anywhere in the code shown here. The greeting
    // sent by `handle_client` is a hard-coded string. Confirm whether it
    // was meant to come from this field.
    greeting_msg: String,
}
fn main() -> std::io::Result<()> {
// argument parsing
let args = Args::parse();
// read config file
let config_reader = File::open(args.config).expect("Could not open config file");
let mut config: ConfigArgs = serde_yaml::from_reader(&config_reader).expect("Could not parse config file");
if let Some(host) = args.host {
config.host = host;
}
if let Some(port) = args.port {
config.port = port;
}
let binding_host = format!("{}:{}", config.host, config.port);
println!("Binding to {}", binding_host);
let listener = TcpListener::bind(binding_host)?;
let server = Arc::new(Mutex::new(ChatServer::new(config)));
// accept connections and process them serially
for stream in listener.incoming() {
if let Ok(stream) = stream {
let server_clone = Arc::clone(&server);
thread::spawn(move || { run_thread(stream, server_clone); });
}
}
Ok(())
}