You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
159 lines
4.6 KiB
159 lines
4.6 KiB
use std::sync::{Arc, Mutex};
|
|
use std::net::{TcpListener, TcpStream};
|
|
use std::thread;
|
|
use clap::Parser;
|
|
use std::io::Write;
|
|
use std::io::{self, BufRead};
|
|
use chrono::{DateTime, Utc};
|
|
use std::borrow::{Borrow, BorrowMut};
|
|
use retain_mut::RetainMut;
|
|
use std::collections::BTreeMap;
|
|
|
|
|
|
fn send_to_all(clients: &Arc<Mutex<Vec<ChatUser>>>, msg: &str) -> Result<(), ChatError> {
|
|
let now: DateTime<Utc> = Utc::now();
|
|
let time_msg = format!("[{}] {}", now.format("%H:%M:%S"), msg);
|
|
let mut broken_clients = vec![];
|
|
clients.lock()?.borrow_mut().retain_mut(
|
|
| client | {
|
|
if write!(client.socket, "{}\n", time_msg).is_err() {
|
|
broken_clients.push(client.name.clone());
|
|
return false
|
|
}
|
|
true
|
|
});
|
|
|
|
for name in broken_clients {
|
|
send_to_all(clients, &format!("* {} left the chat (broken pipe)", name))?;
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
|
|
/// A single connected chat participant, together with the socket used to reach them.
struct ChatUser {
    /// Numeric identifier of this user.
    id: u64,

    /// Nickname shown to other participants.
    name: String,

    /// Socket that messages for this user are written to.
    socket: TcpStream,
}

impl ChatUser {
    /// Bundles an id, a nickname and a socket into a new `ChatUser`.
    pub fn new(id: u64, name: String, socket: TcpStream) -> Self {
        ChatUser { id, name, socket }
    }
}
|
|
|
|
struct ChatServer {
|
|
user_map: BTreeMap<u64, ChatUser>,
|
|
user_id_counter: u64,
|
|
}
|
|
|
|
impl ChatServer {
|
|
pub fn new() -> Self {
|
|
Self{user_map: BtreeMap::new(), user_id_counter: 1}
|
|
}
|
|
|
|
// register a client by creating a ChatUser for it and assigning it a user id
|
|
pub fn register(&mut self, name: String, socket: TcpStream) {
|
|
let chat_user = ChatUser::new(user_id_counter, name, socket);
|
|
self.user_map.insert(chat_user.id, chat_user);
|
|
self.user_id_counter += 1;
|
|
}
|
|
}
|
|
|
|
/// Errors that can end a client session.
#[derive(Debug)]
enum ChatError {
    /// Reading from or writing to a socket failed.
    IOError(std::io::Error),

    /// A lock was poisoned because a thread panicked while holding it.
    MutexPoisonError(),

    /// The client did not follow the expected conversation flow (e.g. sent no nickname).
    Protocol(String),
}

impl std::fmt::Display for ChatError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ChatError::IOError(err) => write!(f, "I/O error: {}", err),
            ChatError::MutexPoisonError() => f.write_str("lock poisoned by a panicking thread"),
            ChatError::Protocol(reason) => write!(f, "protocol error: {}", reason),
        }
    }
}

impl std::error::Error for ChatError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            ChatError::IOError(err) => Some(err),
            ChatError::MutexPoisonError() | ChatError::Protocol(_) => None,
        }
    }
}

impl From<std::io::Error> for ChatError {
    fn from(error: std::io::Error) -> Self {
        ChatError::IOError(error)
    }
}

// Accepts any poisoned guard (Mutex, RwLock, ...). This is a superset of the previous
// MutexGuard-only conversion, so every existing `lock()?` keeps compiling.
impl<T> From<std::sync::PoisonError<T>> for ChatError {
    fn from(_error: std::sync::PoisonError<T>) -> Self {
        ChatError::MutexPoisonError()
    }
}
|
|
|
|
fn handle_client(mut stream: TcpStream, clients: Arc<Mutex<Vec<ChatUser>>>) -> Result<(), ChatError> {
|
|
let mut writer = stream.try_clone()?;
|
|
let reader = io::BufReader::new(&mut stream);
|
|
let mut lines = reader.lines();
|
|
|
|
// greet and recv nick
|
|
write!(writer, "Hello and welcome to RustChat!\nNick: ")?;
|
|
let nick = lines.next().ok_or(ChatError::Protocol(String::from("Could not recv nickname")))??;
|
|
// FIXME: check for nick name uniqueness
|
|
|
|
// print joined msg and register at server
|
|
send_to_all(&clients, &format!("* {} has joined the chat", nick))?;
|
|
let curr_users = clients.lock()?.borrow().len();
|
|
let users_in_room_msg = if curr_users == 1 { String::from("is 1 user") } else { format!("are {} users", curr_users) };
|
|
write!(writer, "Welcome, {}! You can now start to chat! There {} in the room.\n", nick, users_in_room_msg)?;
|
|
clients.lock()?.borrow_mut().push(ChatUser::new(nick.clone(), writer));
|
|
|
|
// read lines
|
|
for line in lines {
|
|
let msg_raw = line?;
|
|
if msg_raw.is_empty() {
|
|
continue;
|
|
}
|
|
let msg = format!("<{}> {}", nick, msg_raw);
|
|
send_to_all(&clients, &msg)?;
|
|
println!("{}", msg);
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
|
|
fn run_thread(stream: TcpStream, clients: Arc<Mutex<Vec<ChatUser>>>) {
|
|
let quit_msg = match handle_client(stream, clients) {
|
|
Ok(()) => String::from("Connection closed by user's choice"),
|
|
Err(err) => format!("Client foobar't their thread: {:?}", err),
|
|
};
|
|
println!(" * Client has left the chat ({})", quit_msg);
|
|
}
|
|
|
|
|
|
/// Chatserver in RUST!
|
|
#[derive(Parser, Debug)]
|
|
#[clap(about, version, author)]
|
|
struct Args {
|
|
/// IP address to bind to
|
|
#[clap(long, default_value = "0.0.0.0")]
|
|
host: String,
|
|
|
|
/// Port to run the server on
|
|
#[clap(short, long, default_value_t = 1337)]
|
|
port: i16,
|
|
}
|
|
|
|
fn main() -> std::io::Result<()> {
|
|
let args = Args::parse();
|
|
let binding_host = format!("{}:{}", args.host, args.port);
|
|
println!("Binding to {}", binding_host);
|
|
|
|
let listener = TcpListener::bind(binding_host)?;
|
|
let clients = Arc::new(Mutex::new(vec!()));
|
|
|
|
// accept connections and process them serially
|
|
for stream in listener.incoming() {
|
|
if let Ok(stream) = stream {
|
|
let clients_clone = Arc::clone(&clients);
|
|
thread::spawn(move || { run_thread(stream, clients_clone); });
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|