unilink  0.4.3
A simple C++ library for unified async communication
tcp_server.cc
Go to the documentation of this file.
1 /*
2  * Copyright 2025 Jinwoo Sung
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
18 
19 #include <atomic>
20 #include <boost/asio.hpp>
21 #include <future>
22 #include <iostream>
23 #include <mutex>
24 #include <thread>
25 #include <unordered_map>
26 
33 
34 namespace unilink {
35 namespace transport {
36 
37 namespace net = boost::asio;
38 using tcp = net::ip::tcp;
39 
41  std::atomic<bool> stopping_{false};
42  std::atomic<size_t> next_client_id_{0};
43 
44  std::unique_ptr<net::io_context> owned_ioc_;
45  bool owns_ioc_;
47  net::io_context& ioc_;
48  std::thread ioc_thread_;
49 
50  std::unique_ptr<interface::TcpAcceptorInterface> acceptor_;
52 
60 
61  mutable std::mutex sessions_mutex_;
62  std::unordered_map<size_t, std::shared_ptr<TcpServerSession>> sessions_;
63 
64  size_t max_clients_;
66  bool paused_accept_ = false;
67 
68  std::shared_ptr<TcpServerSession> current_session_;
69 
  // Constructs an Impl that runs on the process-wide shared io_context owned
  // by concurrency::IoContextManager (owns_ioc_ stays false: the manager, not
  // this object, controls the io thread's lifetime).
  //
  // @param cfg Server configuration. cfg.max_connections > 0 enables the
  //            client limit; 0 means "no limit".
  // @throws std::runtime_error if the TCP acceptor cannot be created.
  explicit Impl(const config::TcpServerConfig& cfg)
      : owns_ioc_(false),
        uses_global_ioc_(true),
        ioc_(concurrency::IoContextManager::instance().get_context()),
        cfg_(cfg),
        max_clients_(cfg.max_connections > 0 ? static_cast<size_t>(cfg.max_connections) : 0),
        client_limit_enabled_(cfg.max_connections > 0) {
    try {
      acceptor_ = std::make_unique<BoostTcpAcceptor>(ioc_);
    } catch (const std::exception& e) {
      // Re-throw with context so callers can see why construction failed.
      throw std::runtime_error("Failed to create TCP acceptor: " + std::string(e.what()));
    }
  }
83 
  // Dependency-injection constructor (primarily for tests): the caller
  // supplies both the acceptor implementation and the io_context, and keeps
  // responsibility for running that context.
  //
  // @throws std::runtime_error if `acceptor` is null.
  Impl(const config::TcpServerConfig& cfg, std::unique_ptr<interface::TcpAcceptorInterface> acceptor,
       net::io_context& ioc)
      : owns_ioc_(false),
        uses_global_ioc_(false),
        ioc_(ioc),
        acceptor_(std::move(acceptor)),
        cfg_(cfg),
        max_clients_(cfg.max_connections > 0 ? static_cast<size_t>(cfg.max_connections) : 0),
        client_limit_enabled_(cfg.max_connections > 0) {
    if (!acceptor_) {
      throw std::runtime_error("Failed to create TCP acceptor");
    }
  }
97 
  // Best-effort teardown: flags shutdown, stops the io_context only when this
  // Impl owns it, joins the io thread (or detaches it when the destructor is
  // somehow running on that very thread, where join would deadlock), then
  // closes the acceptor and all sessions. All exceptions are swallowed —
  // destructors must not throw.
  ~Impl() {
    try {
      stopping_.store(true);
      if (owns_ioc_) {
        ioc_.stop();
      }
      if (ioc_thread_.joinable()) {
        if (std::this_thread::get_id() != ioc_thread_.get_id()) {
          ioc_thread_.join();
        } else {
          // Joining our own thread would deadlock; detach instead.
          ioc_thread_.detach();
        }
      }
      perform_cleanup();
    } catch (...) {
    }
  }
115 
116  void notify_state() {
117  if (stopping_.load()) return;
118  OnState cb;
119  try {
120  {
121  std::lock_guard<std::mutex> lock(sessions_mutex_);
122  cb = on_state_;
123  }
124  if (cb) {
125  cb(state_.get_state());
126  }
127  } catch (...) {
128  }
129  }
130 
  // Binds the acceptor to cfg_.bind_address:cfg_.port and starts accepting.
  // On a bind failure, optionally retries (cfg_.enable_port_retry) up to
  // cfg_.max_port_retries times, waiting cfg_.port_retry_interval_ms between
  // attempts via an asio timer on the io thread.
  //
  // NOTE(review): a few statements (apparently error-state transitions before
  // the notify_state() calls) are elided from this extraction — confirm
  // against the repository before relying on the error paths.
  //
  // @param self        Keeps the server alive across the async retry timer.
  // @param retry_count Number of bind attempts already made.
  void attempt_port_binding(std::shared_ptr<TcpServer> self, int retry_count) {
    if (stopping_.load()) return;
    boost::system::error_code ec;

    // Parse the configured bind address (IPv4 or IPv6).
    auto address = net::ip::make_address(cfg_.bind_address, ec);
    if (ec) {
      UNILINK_LOG_ERROR("tcp_server", "bind", "Invalid bind address: " + cfg_.bind_address + ", " + ec.message());
      notify_state();
      return;
    }

    // Open the acceptor with the protocol family matching the address.
    if (!acceptor_->is_open()) {
      acceptor_->open(address.is_v6() ? tcp::v6() : tcp::v4(), ec);
      if (ec) {
        UNILINK_LOG_ERROR("tcp_server", "open", "Failed to open acceptor: " + ec.message());
        notify_state();
        return;
      }
    }

    acceptor_->bind(tcp::endpoint(address, cfg_.port), ec);
    if (ec) {
      if (cfg_.enable_port_retry && retry_count < cfg_.max_port_retries) {
        // Schedule another attempt. The timer captures itself in its own
        // handler so it stays alive until the wait completes.
        auto timer = std::make_shared<net::steady_timer>(ioc_);
        timer->expires_after(std::chrono::milliseconds(cfg_.port_retry_interval_ms));
        timer->async_wait([self, retry_count, timer](const boost::system::error_code& timer_ec) {
          if (!timer_ec) {
            auto impl = self->get_impl();
            if (!impl->stopping_.load()) {
              impl->attempt_port_binding(self, retry_count + 1);
            }
          }
        });
        return;
      } else {
        // Retries exhausted (or disabled): report and give up.
        UNILINK_LOG_ERROR("tcp_server", "bind",
                          "Failed to bind to port " + std::to_string(cfg_.port) + ": " + ec.message());
        notify_state();
        return;
      }
    }

    acceptor_->listen(boost::asio::socket_base::max_listen_connections, ec);
    if (ec) {
      UNILINK_LOG_ERROR("tcp_server", "listen",
                        "Failed to listen on port " + std::to_string(cfg_.port) + ": " + ec.message());
      notify_state();
      return;
    }

    // Listening succeeded: publish the state and start the accept loop.
    notify_state();
    do_accept(self);
  }
189 
  // Arms one async_accept and re-arms after each completed connection.
  // The handler:
  //  - on error: flags Error (unless the accept was deliberately aborted)
  //    and retries after 100 ms while the server is not closed/stopping;
  //  - on success: enforces the optional client limit, registers a new
  //    TcpServerSession under a fresh client id, wires its callbacks, then
  //    immediately accepts the next client.
  // `self` (a shared_ptr) is captured by every pending handler so the server
  // stays alive for the duration of the async chain.
  void do_accept(std::shared_ptr<TcpServer> self) {
    if (stopping_.load() || !acceptor_ || !acceptor_->is_open()) return;

    acceptor_->async_accept([self](auto ec, tcp::socket sock) {
      auto impl = self->get_impl();
      if (impl->stopping_.load()) {
        return;
      }
      if (ec) {
        // operation_aborted means the acceptor was closed on purpose; any
        // other error is surfaced as an Error state.
        if (ec != boost::asio::error::operation_aborted) {
          impl->state_.set_state(base::LinkState::Error);
          impl->notify_state();
        }
        // Retry accepting after a short back-off unless shutting down.
        if (!impl->state_.is_state(base::LinkState::Closed) && !impl->stopping_.load()) {
          auto timer = std::make_shared<net::steady_timer>(impl->ioc_);
          timer->expires_after(std::chrono::milliseconds(100));
          timer->async_wait([self, timer](const boost::system::error_code&) {
            auto impl = self->get_impl();
            if (!impl->stopping_.load()) {
              impl->do_accept(self);
            }
          });
        }
        return;
      }

      // Describe the peer for the connect callback ("ip:port", or "unknown"
      // when the endpoint cannot be queried).
      boost::system::error_code ep_ec;
      auto rep = sock.remote_endpoint(ep_ec);
      std::string client_info = "unknown";
      if (!ep_ec) {
        client_info = rep.address().to_string() + ":" + std::to_string(rep.port());
      }

      // Client limit reached: refuse the socket and pause the accept loop;
      // it is resumed from the session close handler once a slot frees up.
      if (impl->client_limit_enabled_) {
        std::lock_guard<std::mutex> lock(impl->sessions_mutex_);
        if (impl->sessions_.size() >= impl->max_clients_) {
          boost::system::error_code close_ec;
          sock.close(close_ec);
          impl->paused_accept_ = true;
          return;
        }
      }

      auto new_session = std::make_shared<TcpServerSession>(
          impl->ioc_, std::move(sock), impl->cfg_.backpressure_threshold, impl->cfg_.idle_timeout_ms);

      // Register the session; the most recently accepted client becomes
      // current_session_ (the target of the single-client write API).
      size_t client_id;
      {
        std::lock_guard<std::mutex> lock(impl->sessions_mutex_);
        client_id = impl->next_client_id_.fetch_add(1);
        impl->sessions_.emplace(client_id, new_session);
        impl->current_session_ = new_session;
      }

      // Session callbacks capture the server weakly so that a destroyed
      // server is not resurrected by late session events.
      std::weak_ptr<TcpServer> weak_self = self;

      new_session->on_bytes([weak_self, client_id](memory::ConstByteSpan data) {
        auto self = weak_self.lock();
        if (!self) return;
        auto impl = self->get_impl();

        // Copy handlers out under the lock, invoke them without it.
        OnBytes cb;
        MultiClientDataHandler multi_cb;
        {
          std::lock_guard<std::mutex> lock(impl->sessions_mutex_);
          cb = impl->on_bytes_;
          multi_cb = impl->on_multi_data_;
        }
        if (cb) cb(data);
        if (multi_cb) {
          std::string str_data = common::safe_convert::uint8_to_string(data.data(), data.size());
          multi_cb(client_id, str_data);
        }
      });

      if (impl->on_bp_) new_session->on_backpressure(impl->on_bp_);

      // NOTE(review): this close handler captures new_session by value, so
      // the session holds a shared_ptr to itself through its own callback —
      // relies on the session clearing its handlers on close; confirm.
      new_session->on_close([weak_self, client_id, new_session] {
        auto self = weak_self.lock();
        if (!self) return;
        auto impl = self->get_impl();
        if (impl->stopping_.load()) return;

        MultiClientDisconnectHandler disconnect_cb;
        {
          std::lock_guard<std::mutex> lock(impl->sessions_mutex_);
          disconnect_cb = impl->on_multi_disconnect_;
        }
        if (disconnect_cb) disconnect_cb(client_id);

        bool was_current = false;
        {
          std::lock_guard<std::mutex> lock(impl->sessions_mutex_);
          impl->sessions_.erase(client_id);
          // Resume a paused accept loop now that a slot is free.
          if (impl->paused_accept_ && (!impl->client_limit_enabled_ || impl->sessions_.size() < impl->max_clients_)) {
            impl->paused_accept_ = false;
            net::post(impl->ioc_, [self] { self->get_impl()->do_accept(self); });
          }
          was_current = (impl->current_session_ == new_session);
          if (was_current) {
            // Fall back to any other live session, or none.
            if (!impl->sessions_.empty())
              impl->current_session_ = impl->sessions_.begin()->second;
            else
              impl->current_session_.reset();
          }
        }
        if (was_current) {
          impl->state_.set_state(base::LinkState::Listening);
          impl->notify_state();
        }
      });

      // Publish the new connection, then start the session's read loop and
      // immediately accept the next client.
      MultiClientConnectHandler connect_cb;
      {
        std::lock_guard<std::mutex> lock(impl->sessions_mutex_);
        connect_cb = impl->on_multi_connect_;
      }
      if (connect_cb) connect_cb(client_id, client_info);

      impl->state_.set_state(base::LinkState::Connected);
      impl->notify_state();
      new_session->start();
      impl->do_accept(self);
    });
  }
315 
    // Body of perform_cleanup() — the signature line is elided in this
    // extraction; the name is grounded by the calls in ~Impl() and stop().
    // Closes the acceptor, stops every session, and publishes Closed.
    // Sessions are copied out of the map under the lock and stopped after it
    // is released, so session close callbacks can re-acquire sessions_mutex_
    // without deadlocking.
    try {
      boost::system::error_code ec;
      if (acceptor_ && acceptor_->is_open()) {
        acceptor_->close(ec);
      }

      std::vector<std::shared_ptr<TcpServerSession>> sessions_copy;
      {
        std::lock_guard<std::mutex> lock(sessions_mutex_);
        sessions_copy.reserve(sessions_.size());
        for (auto& kv : sessions_) {
          sessions_copy.push_back(kv.second);
        }
        sessions_.clear();
        current_session_.reset();
      }

      // Stop sessions outside the lock (stop() may fire close callbacks).
      for (auto& session : sessions_copy) {
        if (session) {
          session->stop();
        }
      }

      state_.set_state(base::LinkState::Closed);
      notify_state();
    } catch (...) {
      // Cleanup is best-effort; swallow everything during teardown.
    }
  }
345 
  // Shuts the server down exactly once (idempotent via stopping_.exchange).
  // Clears all user callbacks first so no handler fires mid-teardown, then
  // runs perform_cleanup() either inline (when already on the io thread) or
  // dispatched onto the io thread with a 2-second wait.
  //
  // @param self shared_ptr used to keep the server alive for the dispatched
  //             cleanup; nullptr when invoked from the destructor (cleanup
  //             then runs directly on the calling thread).
  void stop(std::shared_ptr<TcpServer> self) {
    if (stopping_.exchange(true)) {
      return;
    }

    // Drop every user callback under the lock: handlers already copied out
    // may still finish, but no new ones will be invoked.
    {
      std::lock_guard<std::mutex> lock(sessions_mutex_);
      on_bytes_ = nullptr;
      on_state_ = nullptr;
      on_bp_ = nullptr;
      on_multi_connect_ = nullptr;
      on_multi_data_ = nullptr;
      on_multi_disconnect_ = nullptr;
    }

    // Already on the io thread: clean up synchronously (waiting on the
    // future below would deadlock here).
    if (ioc_.get_executor().running_in_this_thread()) {
      perform_cleanup();
      if (owns_ioc_) ioc_.stop();
      return;
    }

    bool has_active_ioc = owns_ioc_ || (uses_global_ioc_ && concurrency::IoContextManager::instance().is_running());

    if (has_active_ioc && self) {
      auto cleanup_promise = std::make_shared<std::promise<void>>();
      auto cleanup_future = cleanup_promise->get_future();

      std::weak_ptr<TcpServer> weak_self = self;
      net::dispatch(ioc_, [weak_self, cleanup_promise]() {
        if (auto shared_self = weak_self.lock()) {
          shared_self->get_impl()->perform_cleanup();
        }
        cleanup_promise->set_value();
      });

      // NOTE(review): on timeout, this fallback can race with the still-
      // queued io-thread cleanup; perform_cleanup looks tolerant of double
      // execution (idempotent close/clear) — confirm.
      if (cleanup_future.wait_for(std::chrono::seconds(2)) == std::future_status::timeout) {
        perform_cleanup();
      }
    } else {
      perform_cleanup();
    }

    // For a privately-owned context, join the io thread and reset the
    // context so a subsequent start() can reuse it.
    if (owns_ioc_ && ioc_thread_.joinable()) {
      ioc_thread_.join();
      ioc_.restart();
    }
  }
393 };
394 
395 std::shared_ptr<TcpServer> TcpServer::create(const config::TcpServerConfig& cfg) {
396  return std::shared_ptr<TcpServer>(new TcpServer(cfg));
397 }
398 
399 std::shared_ptr<TcpServer> TcpServer::create(const config::TcpServerConfig& cfg,
400  std::unique_ptr<interface::TcpAcceptorInterface> acceptor,
401  net::io_context& ioc) {
402  return std::shared_ptr<TcpServer>(new TcpServer(cfg, std::move(acceptor), ioc));
403 }
404 
// Constructs a server on the shared global io_context (see Impl ctor).
TcpServer::TcpServer(const config::TcpServerConfig& cfg) : impl_(std::make_unique<Impl>(cfg)) {}

// Dependency-injection constructor: external acceptor + io_context.
TcpServer::TcpServer(const config::TcpServerConfig& cfg, std::unique_ptr<interface::TcpAcceptorInterface> acceptor,
                     net::io_context& ioc)
    : impl_(std::make_unique<Impl>(cfg, std::move(acceptor), ioc)) {}

  // (~TcpServer — the signature line is elided in this extraction.)
  // Stops the server if it has not already reached the Closed state.
  if (impl_ && !impl_->state_.is_state(base::LinkState::Closed)) {
    // Pass nullptr to stop() to indicate we are in destructor and cannot use shared_from_this
    impl_->stop(nullptr);
  }
}

// NOTE(review): the defaulted moves leave the moved-from object with a null
// impl_; most methods here dereference impl_ unconditionally, so using a
// moved-from TcpServer would be undefined behavior — confirm callers never do.
TcpServer::TcpServer(TcpServer&&) noexcept = default;
TcpServer& TcpServer::operator=(TcpServer&&) noexcept = default;
420 
// Starts listening. Safe to call repeatedly: returns immediately when the
// server is already Listening/Connected/Connecting. Ensures the shared
// io_context is running (when this instance uses the global manager), then
// kicks off port binding on the io thread.
void TcpServer::start() {
  auto impl = get_impl();
  auto current = impl->state_.get_state();
  if (current == base::LinkState::Listening || current == base::LinkState::Connected ||
      current == base::LinkState::Connecting) {
    return;
  }
  impl->stopping_.store(false);

  // Lazily start the process-wide io thread if it is not running yet.
  if (impl->uses_global_ioc_) {
    auto& manager = concurrency::IoContextManager::instance();
    if (!manager.is_running()) {
      manager.start();
    }
  }

  if (!impl->acceptor_) {
    impl->state_.set_state(base::LinkState::Error);
    impl->notify_state();
    return;
  }

  // NOTE(review): owns_ioc_ is initialized to false in both visible Impl
  // constructors, so this branch looks unreachable from this file — confirm.
  if (impl->owns_ioc_) {
    impl->ioc_thread_ = std::thread([impl] { impl->ioc_.run(); });
  }
  auto self = shared_from_this();
  // Bind inline when already on the io thread; otherwise hop onto it.
  if (impl->ioc_.get_executor().running_in_this_thread()) {
    if (!impl->stopping_.load()) {
      impl->attempt_port_binding(self, 0);
    }
  } else {
    net::dispatch(impl->ioc_, [self] {
      auto impl = self->get_impl();
      if (impl->stopping_.load()) return;
      impl->attempt_port_binding(self, 0);
    });
  }
}
459 
// Synchronous stop; shared_from_this keeps the server alive while cleanup is
// dispatched to the io thread.
void TcpServer::stop() { impl_->stop(shared_from_this()); }

  // (Signature elided — presumably the asynchronous stop variant.) Queues a
  // stop() on the io thread instead of blocking the caller; no-op when a
  // shutdown is already in progress.
  auto impl = get_impl();
  if (impl->stopping_.load()) return;
  auto self = shared_from_this();
  net::post(impl->ioc_, [self] { self->stop(); });
}

  // (Signature elided — presumably the connection predicate.) True when the
  // current session exists and is still alive.
  auto impl = get_impl();
  std::lock_guard<std::mutex> lock(impl->sessions_mutex_);
  return impl->current_session_ && impl->current_session_->alive();
}

  // (Signature elided — presumably async_write_copy(ConstByteSpan).) Copies
  // `data` into the current session's write path. The session pointer is
  // snapshotted under the lock; the write itself happens without it.
  auto impl = get_impl();
  if (impl->stopping_.load()) return;
  std::shared_ptr<TcpServerSession> session;
  {
    std::lock_guard<std::mutex> lock(impl->sessions_mutex_);
    session = impl->current_session_;
  }

  if (session && session->alive()) {
    session->async_write_copy(data);
  }
}
488 
489 void TcpServer::async_write_move(std::vector<uint8_t>&& data) {
490  auto impl = get_impl();
491  if (impl->stopping_.load()) return;
492  std::shared_ptr<TcpServerSession> session;
493  {
494  std::lock_guard<std::mutex> lock(impl->sessions_mutex_);
495  session = impl->current_session_;
496  }
497 
498  if (session && session->alive()) {
499  session->async_write_move(std::move(data));
500  }
501 }
502 
503 void TcpServer::async_write_shared(std::shared_ptr<const std::vector<uint8_t>> data) {
504  auto impl = get_impl();
505  if (impl->stopping_.load() || !data) return;
506  std::shared_ptr<TcpServerSession> session;
507  {
508  std::lock_guard<std::mutex> lock(impl->sessions_mutex_);
509  session = impl->current_session_;
510  }
511 
512  if (session && session->alive()) {
513  session->async_write_shared(std::move(data));
514  }
515 }
516 
  // Body of the raw-bytes callback setter (signature elided in this
  // extraction): stores the handler under sessions_mutex_, matching the
  // locked reads in the accept path.
  std::lock_guard<std::mutex> lock(impl_->sessions_mutex_);
  impl_->on_bytes_ = std::move(cb);
}
  // Body of the state callback setter (signature elided).
  std::lock_guard<std::mutex> lock(impl_->sessions_mutex_);
  impl_->on_state_ = std::move(cb);
}
  // Body of the backpressure callback setter (signature elided): stores the
  // handler and forwards it to the current session, if any.
  auto impl = get_impl();
  {
    std::lock_guard<std::mutex> lock(impl->sessions_mutex_);
    impl->on_bp_ = std::move(cb);
  }
  std::shared_ptr<TcpServerSession> session;
  {
    // NOTE(review): this second lock could be merged with the one above.
    std::lock_guard<std::mutex> lock(impl->sessions_mutex_);
    session = impl->current_session_;
  }

  // NOTE(review): on_bp_ is read here without sessions_mutex_ — racy if
  // another thread replaces the handler concurrently; confirm intended.
  if (session) session->on_backpressure(impl->on_bp_);
}
539 
540 bool TcpServer::broadcast(const std::string& message) {
541  auto impl = get_impl();
542  auto shared_data = std::make_shared<const std::vector<uint8_t>>(message.begin(), message.end());
543  std::lock_guard<std::mutex> lock(impl->sessions_mutex_);
544  bool sent = false;
545  for (auto& entry : impl->sessions_) {
546  auto& session = entry.second;
547  if (session && session->alive()) {
548  session->async_write_shared(shared_data);
549  sent = true;
550  }
551  }
552  return sent;
553 }
554 
555 bool TcpServer::send_to_client(size_t client_id, const std::string& message) {
556  auto impl = get_impl();
557  std::lock_guard<std::mutex> lock(impl->sessions_mutex_);
558  auto it = impl->sessions_.find(client_id);
559  if (it != impl->sessions_.end() && it->second && it->second->alive()) {
560  auto binary_view = common::safe_convert::string_to_bytes(message);
561  it->second->async_write_copy(memory::ConstByteSpan(binary_view.first, binary_view.second));
562  return true;
563  }
564  return false;
565 }
566 
  // Body of the connected-client counter (signature elided — presumably a
  // const count getter): counts only sessions that are still alive, under
  // sessions_mutex_.
  auto impl = get_impl();
  std::lock_guard<std::mutex> lock(impl->sessions_mutex_);
  size_t alive = 0;
  for (const auto& entry : impl->sessions_)
    if (entry.second && entry.second->alive()) ++alive;
  return alive;
}
575 
576 std::vector<size_t> TcpServer::get_connected_clients() const {
577  auto impl = get_impl();
578  std::lock_guard<std::mutex> lock(impl->sessions_mutex_);
579  std::vector<size_t> connected_clients;
580  connected_clients.reserve(impl->sessions_.size());
581  for (const auto& entry : impl->sessions_)
582  if (entry.second && entry.second->alive()) connected_clients.push_back(entry.first);
583  return connected_clients;
584 }
585 
  // Setter bodies for the multi-client handlers (signatures elided in this
  // extraction). Each stores its handler under sessions_mutex_, matching the
  // locked reads inside the accept/data/close callbacks.
  std::lock_guard<std::mutex> l(impl_->sessions_mutex_);
  impl_->on_multi_connect_ = std::move(h);
}
  std::lock_guard<std::mutex> l(impl_->sessions_mutex_);
  impl_->on_multi_data_ = std::move(h);
}
  std::lock_guard<std::mutex> l(impl_->sessions_mutex_);
  impl_->on_multi_disconnect_ = std::move(h);
}
598 
599 void TcpServer::set_client_limit(size_t max) {
600  auto impl = get_impl();
601  std::lock_guard<std::mutex> l(impl->sessions_mutex_);
602  impl->max_clients_ = max;
603  impl->client_limit_enabled_ = true;
604  if (impl->paused_accept_ && impl->sessions_.size() < impl->max_clients_) {
605  impl->paused_accept_ = false;
606  net::post(impl->ioc_, [self = shared_from_this()] { self->get_impl()->do_accept(self); });
607  }
608 }
609 
  // Body of the limit-clearing method (signature elided — presumably the
  // counterpart of set_client_limit): disables the cap entirely and resumes
  // the accept loop if it was paused at the old limit.
  auto impl = get_impl();
  std::lock_guard<std::mutex> l(impl->sessions_mutex_);
  impl->client_limit_enabled_ = false;
  impl->max_clients_ = 0;
  if (impl->paused_accept_) {
    impl->paused_accept_ = false;
    net::post(impl->ioc_, [self = shared_from_this()] { self->get_impl()->do_accept(self); });
  }
}
620 
621 base::LinkState TcpServer::get_state() const { return get_impl()->state_.get_state(); }
622 
623 } // namespace transport
624 } // namespace unilink
#define UNILINK_LOG_ERROR(component, operation, message)
Definition: logger.hpp:279