diff --git a/src/main.cpp b/src/main.cpp index a4fe019..4970411 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,13 +1,11 @@ //this snippet just opens a server and stops it from closing. #include "nn.hpp" #include -int main(){ - std::string msg; - runServer(9090,"PassTest"); - - //this part stops it from reaching the "return 0" line. - //the timer is so it does not eat up CPU power +int main(){ + runServer(9090, "PassTest"); + // this part stops it from reaching the "return 0" line. + // the timer is so it does not eat up CPU power while(true){ std::this_thread::sleep_for(std::chrono::milliseconds(100)); } diff --git a/src/nn.cpp b/src/nn.cpp index 784dbf5..5887a44 100644 --- a/src/nn.cpp +++ b/src/nn.cpp @@ -1,9 +1,8 @@ -//This library is made by @Nullora on Github. The link can be found here, with documentation aswell: https://github.com/Nullora/NovusNet +//This library is made by @Nullora on Github. The link can be found here, and documentation aswell: https://github.com/Nullora/NovusNet //NovusNet is a c++ networking library made to facilitate connection between devices while keeping it fast and secure. //It's fully free and anyone can distribute/use it. //Last updated: 27/3/26 #include "nn.hpp" -#include #include #include #include @@ -12,78 +11,176 @@ #include #include #include -#include #include #include -#include -#include -#include #include #include #include -#include -#include +#include +#include +#include +#include +#include +#include +#include + +// FIX: Separate send/recv mutexes per connection. +// OpenSSL allows concurrent SSL_read + SSL_write on the same object only when +// each direction is serialised independently. One combined mutex would block +// reads while a large file send is in progress (and vice-versa). +struct ConnMutexes { + std::mutex send_mtx; + std::mutex recv_mtx; +}; + +// ─── Global state ───────────────────────────────────────────────────────────── +static std::map s_clients; +static std::atomic s_client_index{0}; +static std::function s_messageCallback; +static SSL_CTX* s_ssl_ctx = nullptr; +static SSL* s_client_ssl = nullptr; +static std::mutex s_clients_mutex; +static std::map> s_conn_mutexes; +static ConnMutexes s_client_conn_mutexes; // client side + +// ─── Internal helpers ───────────────────────────────────────────────────────── + +static SSL* getSSL(int id) { + std::lock_guard lk(s_clients_mutex); + auto it = s_clients.find(id); + return it != s_clients.end() ? it->second : s_client_ssl; +} + +static std::shared_ptr getConn(int id) { + std::lock_guard lk(s_clients_mutex); + auto it = s_conn_mutexes.find(id); + return it != s_conn_mutexes.end() ? it->second : nullptr; +} +// FIX: _sendRaw / _recvRaw are unlocked. Callers must hold the appropriate mutex +// before calling them. This lets sendFile call _sendRaw for the filename without +// deadlocking against itself (original code called sendMsg while holding io_lock, +// but sendMsg never acquired that lock, so the lock was completely ignored). + +static bool _sendRaw(SSL* ssl, const std::string& msg) { + uint32_t len_net = htonl((uint32_t)msg.size()); + const char* hdr = reinterpret_cast(&len_net); + int hdr_left = (int)sizeof(len_net), hdr_sent = 0; + while (hdr_left > 0) { + int r = SSL_write(ssl, hdr + hdr_sent, hdr_left); + if (r <= 0) return false; + hdr_sent += r; + hdr_left -= r; + } + int left = (int)msg.size(), sent = 0; + while (left > 0) { + int r = SSL_write(ssl, msg.c_str() + sent, left); + if (r <= 0) return false; + sent += r; + left -= r; + } + return true; +} + +static std::string _recvRaw(SSL* ssl) { + uint32_t len_net = 0; + int hdr_left = (int)sizeof(len_net), hdr_recv = 0; + while (hdr_left > 0) { + int r = SSL_read(ssl, reinterpret_cast(&len_net) + hdr_recv, hdr_left); + if (r <= 0) return "EXITED(C-178)"; + hdr_recv += r; + hdr_left -= r; + } + int len = (int)ntohl(len_net); + if (len <= 0 || len > 4 * 1024 * 1024) return "EXITED(C-178)"; + std::string msg(len, '\0'); + int left = len, recvd = 0; + while (left > 0) { + int r = SSL_read(ssl, msg.data() + recvd, left); + if (r <= 0) return "EXITED(C-178)"; + recvd += r; + left -= r; + } + return msg; +} + +static auto printProgress = [](uint64_t done, uint64_t total) { + int pct = (int)((done * 100) / total); + int filled = pct / 5; + std::cout << "\r["; + for (int i = 0; i < 20; ++i) std::cout << (i < filled ? '#' : '-'); + std::cout << "] " << pct << "% " << std::flush; +}; -std::map clients; -std::atomic clients_index = 0; -std::function messageCallback; -SSL_CTX* ssl_ctx = nullptr; -SSL* client_ssl = nullptr; -std::mutex clients_mutex; +// ─── Public API ─────────────────────────────────────────────────────────────── -void onMessage(std::function callback){ - messageCallback = callback; +void onMessage(std::function callback) { + s_messageCallback = callback; } -SSL* getSSL(int id) { - std::lock_guard lock(clients_mutex); - auto it = clients.find(id); - if (it != clients.end()) return it->second; - return client_ssl; // Fallback for client-side +// FIX: sendMsg now acquires the per-client send mutex before touching SSL. +// Previously it called SSL_write with no lock at all; two simultaneous sendMsg +// calls for the same client (e.g. broadcasting from different threads) would +// corrupt the SSL state and crash. +bool sendMsg(std::string msg, int id) { + SSL* ssl = getSSL(id); + if (!ssl) return false; + auto conn = getConn(id); + if (conn) { + std::lock_guard lk(conn->send_mtx); + return _sendRaw(ssl, msg); + } + // Client side: no entry in s_conn_mutexes, use the dedicated client mutex. + std::lock_guard lk(s_client_conn_mutexes.send_mtx); + return _sendRaw(ssl, msg); } +// FIX: recvMsg now acquires the per-client recv mutex. On the server side this +// is only ever called during the auth handshake (before the receive thread +// starts). On the client side it is the normal way to receive messages. +std::string recvMsg(int id) { + SSL* ssl = getSSL(id); + if (!ssl) return "EXITED(C-178)"; + auto conn = getConn(id); + if (conn) { + std::lock_guard lk(conn->recv_mtx); + return _recvRaw(ssl); + } + std::lock_guard lk(s_client_conn_mutexes.recv_mtx); + return _recvRaw(ssl); +} -// spawns a background thread that accepts incoming connections. -// each accepted client gets their own recv thread. void runServer(int port, std::string password) { - // initialize OpenSSL and load certificates SSL_CTX* ctx = SSL_CTX_new(TLS_server_method()); - if (!SSL_CTX_use_certificate_file(ctx, "cert.pem", SSL_FILETYPE_PEM)) { - ERR_print_errors_fp(stderr); - SSL_CTX_free(ctx); - return; - } - if (!SSL_CTX_use_PrivateKey_file(ctx, "key.pem", SSL_FILETYPE_PEM)) { + if (!SSL_CTX_use_certificate_file(ctx, "cert.pem", SSL_FILETYPE_PEM) || + !SSL_CTX_use_PrivateKey_file(ctx, "key.pem", SSL_FILETYPE_PEM)) { ERR_print_errors_fp(stderr); SSL_CTX_free(ctx); return; } - ssl_ctx = ctx; + s_ssl_ctx = ctx; + int server_fd = socket(AF_INET, SOCK_STREAM, 0); - if(server_fd < 0){ perror("socket failed"); return; } + if (server_fd < 0) { perror("socket"); return; } int opt = 1; setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)); sockaddr_in addr{}; - addr.sin_family = AF_INET; + addr.sin_family = AF_INET; addr.sin_addr.s_addr = INADDR_ANY; - addr.sin_port = htons(port); - - int b = bind(server_fd, (sockaddr*)&addr, sizeof(addr)); - if(b < 0){ perror("bind failed"); return; } - int l = listen(server_fd, 32); - if(l < 0){ perror("listen failed"); return; } + addr.sin_port = htons(port); + if (bind(server_fd, (sockaddr*)&addr, sizeof(addr)) < 0) { perror("bind"); return; } + if (listen(server_fd, 32) < 0) { perror("listen"); return; } std::cout << "Server on port " << port << "\n"; - std::thread([server_fd,password]() { + std::thread([server_fd, password]() { while (true) { sockaddr_in client_addr{}; - socklen_t len = sizeof(client_addr); - int new_fd = accept(server_fd, (sockaddr*)&client_addr, &len); - if(new_fd < 0) continue; - //wrap fd in ssl - SSL* ssl = SSL_new(ssl_ctx); + socklen_t clen = sizeof(client_addr); + int new_fd = accept(server_fd, (sockaddr*)&client_addr, &clen); + if (new_fd < 0) continue; + + SSL* ssl = SSL_new(s_ssl_ctx); SSL_set_fd(ssl, new_fd); if (SSL_accept(ssl) <= 0) { ERR_print_errors_fp(stderr); @@ -91,237 +188,201 @@ void runServer(int port, std::string password) { close(new_fd); continue; } - //check password (Access control) - // register the new client - int ci = ++clients_index; + + int ci = ++s_client_index; + auto conn = std::make_shared(); + { + std::lock_guard lk(s_clients_mutex); + s_clients[ci] = ssl; + s_conn_mutexes[ci] = conn; + } + + // Auth: receive password under the recv mutex. + std::string recvdp; { - std::lock_guard lock(clients_mutex); - clients[ci] = ssl; + std::lock_guard lk(conn->recv_mtx); + recvdp = _recvRaw(ssl); } - //check password (Access control) - std::string recvdp = recvMsg(ci); - if(recvdp!=password){ + + if (recvdp != password) { std::cout << "KICKED (wrong password): " << inet_ntoa(client_addr.sin_addr) << "\n"; { - std::lock_guard lock(clients_mutex); - SSL_shutdown(clients[ci]); - SSL_free(clients[ci]); - clients.erase(ci); + std::lock_guard lk(s_clients_mutex); + SSL_shutdown(ssl); + SSL_free(ssl); + s_clients.erase(ci); + s_conn_mutexes.erase(ci); } close(new_fd); - }else{ + } else { std::cout << "CONNECTED: " << inet_ntoa(client_addr.sin_addr) << "\n"; - sendMsg(std::to_string(ci),ci); - // stabilize client_index to pass to thread - int new_ci = ci; - SSL* new_ssl = ssl; - std::thread([new_ci,new_ssl,new_fd]() { + { + std::lock_guard lk(conn->send_mtx); + _sendRaw(ssl, std::to_string(ci)); + } + + std::thread([ci, ssl, new_fd, conn]() { while (true) { - std::string msg = recvMsg(new_ci); - if (msg=="EXITED(C-178)"){ - std::cout<<"client "< lk(conn->recv_mtx); + msg = _recvRaw(ssl); + } + if (msg == "EXITED(C-178)") { + std::cout << "client " << ci << " disconnected, cleaning.\n"; { - std::lock_guard lock(clients_mutex); - SSL_shutdown(clients[new_ci]); - SSL_free(clients[new_ci]); - clients.erase(new_ci); + std::lock_guard lk(s_clients_mutex); + SSL_shutdown(ssl); + SSL_free(ssl); + s_clients.erase(ci); + s_conn_mutexes.erase(ci); } close(new_fd); break; } - if (messageCallback) messageCallback(new_ci, msg); + if (s_messageCallback) s_messageCallback(ci, msg); } }).detach(); } } }).detach(); } -int runClient(std::string ip, int port,std::string password) { - int client = socket(AF_INET, SOCK_STREAM, 0); - if(client<0){ - perror("socket failed"); - return -1; - } - struct sockaddr_in serverAddress; - serverAddress.sin_family = AF_INET; - serverAddress.sin_port = htons(port); - inet_pton(AF_INET, ip.c_str(), &serverAddress.sin_addr); - - if(connect(client, (struct sockaddr*)&serverAddress, sizeof(serverAddress)) == -1){ - perror("connect failed"); - close(client); + +int runClient(std::string ip, int port, std::string password) { + int fd = socket(AF_INET, SOCK_STREAM, 0); + if (fd < 0) { perror("socket"); return -1; } + + sockaddr_in srv{}; + srv.sin_family = AF_INET; + srv.sin_port = htons(port); + inet_pton(AF_INET, ip.c_str(), &srv.sin_addr); + + if (connect(fd, (sockaddr*)&srv, sizeof(srv)) < 0) { + perror("connect"); + close(fd); return -1; } - SSL_CTX* client_ctx = SSL_CTX_new(TLS_client_method()); - SSL_CTX_set_verify(client_ctx, SSL_VERIFY_PEER, nullptr); - SSL_CTX_load_verify_locations(client_ctx, "cert.pem", nullptr); - SSL* ssl = SSL_new(client_ctx); - SSL_CTX_free(client_ctx); - SSL_set_fd(ssl, client); + + SSL_CTX* ctx = SSL_CTX_new(TLS_client_method()); + SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER, nullptr); + SSL_CTX_load_verify_locations(ctx, "cert.pem", nullptr); + // SSL_new increments the CTX refcount, so it's safe to free ctx here. + // The context will be freed when the SSL object is freed. + SSL* ssl = SSL_new(ctx); + SSL_CTX_free(ctx); + SSL_set_fd(ssl, fd); + if (SSL_connect(ssl) <= 0) { ERR_print_errors_fp(stderr); SSL_free(ssl); - close(client); + close(fd); return -1; } - if(SSL_get_verify_result(ssl) != X509_V_OK) { + if (SSL_get_verify_result(ssl) != X509_V_OK) { + std::cerr << "Certificate verification failed\n"; SSL_free(ssl); - close(client); + close(fd); return -1; } - client_ssl = ssl; - sendMsg(password,0); - std::cout << "Request sent\n"; - return client; -} -//NMTS (Novus Message Transfer System) -bool sendMsg(std::string msg, int id) { - SSL* ssl; - { - std::lock_guard lock(clients_mutex); - ssl = (clients.count(id)) ? clients[id] : client_ssl; - } - int msglength = msg.size(); - uint32_t msglengthC = htonl(msglength); - int bytesL = msglength; - int bytesS = 0; - int result=0; - - int headerL = sizeof(msglengthC); - int headerS = 0; - while(headerL > 0){ - result = SSL_write(ssl, (char*)&msglengthC + headerS, headerL); - if(result <= 0){ - perror("send failed"); - return false; - } - headerS += result; - headerL -= result; - } - while (bytesL > 0) { - result = SSL_write(ssl, msg.c_str() + bytesS, bytesL); - if (result <= 0) { - perror("send failed"); - return false; - } - bytesS += result; - bytesL -= result; - } - return true; -} -std::string recvMsg(int id) { - SSL* ssl; + + s_client_ssl = ssl; + // FIX: password is now sent under the client send mutex. { - std::lock_guard lock(clients_mutex); - ssl = (clients.count(id)) ? clients[id] : client_ssl; - } - uint32_t msgL_htonl; - int result=0; - - int headerL = sizeof(msgL_htonl); - int headerR = 0; - while(headerL > 0){ - result = SSL_read(ssl, (char*)&msgL_htonl + headerR, headerL); - if(result <= 0){ - perror("recv failed"); - return "EXITED(C-178)"; - } - headerR += result; - headerL -= result; - } - int msgL = ntohl(msgL_htonl); - if(msgL>4*1024*1024 || msgL<=0) return "EXITED(C-178)"; - int bytesR = 0; - int bytesL = msgL; - std::string msg(msgL, 0); - - while (bytesL > 0) { - result = SSL_read(ssl, msg.data() + bytesR, bytesL); - if (result <= 0) { - perror("recv failed"); - return "EXITED(C-178)"; - } - bytesR += result; - bytesL -= result; + std::lock_guard lk(s_client_conn_mutexes.send_mtx); + _sendRaw(ssl, password); } - return msg; + std::cout << "Request sent\n"; + return fd; } -auto printProgress = [](uint64_t done, uint64_t total) { - int percent = (done * 100) / total; - int filled = percent / 5; // 20 chars wide - std::cout << "\r["; - for (int i = 0; i < 20; i++) - std::cout << (i < filled ? '#' : '-'); - std::cout << "] " << percent << "% " << std::flush; -}; -//NFTP (Novus File Transfer Protocol) + +// FIX: sendFile acquires the send mutex and calls _sendRaw directly for the +// filename. The original code called sendMsg() (no lock) after acquiring +// io_lock — the lock did nothing because sendMsg ignored it entirely. bool sendFile(std::string filepath, int id) { SSL* ssl = getSSL(id); if (!ssl) return false; + auto conn = getConn(id); + + std::mutex& send_mtx = conn ? conn->send_mtx : s_client_conn_mutexes.send_mtx; + std::lock_guard lk(send_mtx); int fd = open(filepath.c_str(), O_RDONLY); - if (fd < 0) return false; + if (fd < 0) { perror("sendFile: open"); return false; } struct stat st; fstat(fd, &st); - uint64_t size = st.st_size; - uint64_t netsize = htobe64(size); - - // Header: Size - if (SSL_write(ssl, &netsize, sizeof(netsize)) <= 0) return false; + uint64_t size = (uint64_t)st.st_size; + uint64_t net_sz = htobe64(size); + + if (SSL_write(ssl, &net_sz, sizeof(net_sz)) <= 0) { close(fd); return false; } - // Header: Name std::string filename = filepath.substr(filepath.find_last_of("/\\") + 1); - if (!sendMsg(filename, id)) return false; - - char buffer[16384]; - uint64_t totalSent = 0; - while (totalSent < size) { - ssize_t bytesRead = read(fd, buffer, sizeof(buffer)); - if (bytesRead <= 0) break; - - int offset = 0; - while (offset < bytesRead) { - int sent = SSL_write(ssl, buffer + offset, bytesRead - offset); - if (sent <= 0) { - close(fd); - return false; - } - offset += sent; - totalSent += sent; + if (!_sendRaw(ssl, filename)) { close(fd); return false; } + + char buf[16384]; + uint64_t total_sent = 0; + while (total_sent < size) { + ssize_t rd = read(fd, buf, sizeof(buf)); + if (rd <= 0) break; + int off = 0; + while (off < (int)rd) { + int sent = SSL_write(ssl, buf + off, (int)rd - off); + if (sent <= 0) { close(fd); return false; } + off += sent; + total_sent += (uint64_t)sent; } - printProgress(totalSent, size); + printProgress(total_sent, size); } + std::cout << "\n"; close(fd); - return true; + return total_sent == size; } -bool recvFile(std::string folderpath, int id){ - SSL* ssl = (clients.count(id)) ? clients[id] : client_ssl; - char buffer[16384]; - uint64_t netsize; - if(SSL_read(ssl,&netsize,sizeof(netsize))<=0){ - perror("file recv failed"); + +// FIX: same as sendFile — holds recv_mtx and calls _recvRaw directly. +bool recvFile(std::string folderpath, int id) { + SSL* ssl = getSSL(id); + if (!ssl) return false; + auto conn = getConn(id); + + std::mutex& recv_mtx = conn ? conn->recv_mtx : s_client_conn_mutexes.recv_mtx; + std::lock_guard lk(recv_mtx); + + uint64_t net_sz = 0; + if (SSL_read(ssl, &net_sz, sizeof(net_sz)) <= 0) { perror("recvFile: size"); return false; } + uint64_t filesize = be64toh(net_sz); + if (filesize == 0 || filesize > 10ULL * 1024 * 1024 * 1024) return false; + + std::string filename = _recvRaw(ssl); + if (filename.empty() || + filename == "EXITED(C-178)" || + filename.size() > 255 || + filename.find('/') != std::string::npos || + filename.find('\\') != std::string::npos) return false; - } - uint64_t filesize = be64toh(netsize); - if(filesize>10ULL*1024*1024*1024 || filesize<=0) return false; - std::string filename = recvMsg(id); - if(filename.empty() || filename.size() > 255) return false; - if(filename.find('/') != std::string::npos) return false; - uint64_t bytesL = filesize; - uint64_t bytesR = 0; - int result; + std::string fullpath = folderpath + "/" + filename; int outfd = open(fullpath.c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0644); - if(outfd < 0) return false; - while(bytesL>0){ - result = SSL_read(ssl, buffer, std::min(bytesL, (uint64_t)sizeof(buffer))); - if(result<=0) return false; - bytesL -= result; - bytesR += result; - write(outfd, buffer, result); - printProgress(bytesR,filesize); + if (outfd < 0) { perror("recvFile: create"); return false; } + + char buf[16384]; + uint64_t left = filesize, recvd = 0; + while (left > 0) { + int to_read = (int)std::min(left, (uint64_t)sizeof(buf)); + int r = SSL_read(ssl, buf, to_read); + if (r <= 0) { close(outfd); return false; } + if (write(outfd, buf, r) != r) { + perror("recvFile: write"); + close(outfd); + return false; + } + left -= (uint64_t)r; + recvd += (uint64_t)r; + printProgress(recvd, filesize); } + std::cout << "\n"; close(outfd); return true; } \ No newline at end of file