unilink  0.4.3
A simple C++ library for unified async communication
thread_safe_state.hpp
Go to the documentation of this file.
1 /*
2  * Copyright 2025 Jinwoo Sung
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
#include <algorithm>
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <functional>
#include <mutex>
#include <shared_mutex>
#include <vector>
27 
28 #include "unilink/base/common.hpp"
29 
30 namespace unilink {
31 namespace concurrency {
32 
40 template <typename StateType>
42  public:
43  using State = StateType;
44  using StateCallback = std::function<void(const State&)>;
45  using StateCallbackHandle = size_t;
46 
47  // Constructors
48  explicit ThreadSafeState(const State& initial_state = State{});
49  ThreadSafeState(const ThreadSafeState&) = delete;
53 
54  // State access methods
55  State get_state() const;
56  void set_state(const State& new_state);
57  void set_state(State&& new_state);
58 
59  // Atomic state operations
60  bool compare_and_set(const State& expected, const State& desired);
61  State exchange(const State& new_state);
62 
63  // State change notifications
67 
68  // Wait for state change
69  void wait_for_state(const State& expected_state, std::chrono::milliseconds timeout = std::chrono::milliseconds(1000));
70  void wait_for_state_change(std::chrono::milliseconds timeout = std::chrono::milliseconds(1000));
71 
72  // Utility methods
73  bool is_state(const State& expected_state) const;
75 
76  private:
77  mutable std::shared_mutex state_mutex_;
78  State state_;
79  std::atomic<bool> state_changed_{false};
80 
81  struct CallbackInfo {
82  StateCallbackHandle handle;
83  StateCallback callback;
84  };
85  std::vector<CallbackInfo> callbacks_;
86  StateCallbackHandle next_handle_{1};
87  mutable std::mutex callbacks_mutex_;
88 
89  std::condition_variable_any state_cv_;
90 
91  void notify_callbacks(const State& new_state);
92 };
93 
// Thin value wrapper around std::atomic<StateType>.
//
// StateType must satisfy std::atomic's requirements (trivially copyable).
// Unlike ThreadSafeState, this type offers no change callbacks and no waiting;
// every member maps directly onto one std::atomic operation.
template <typename StateType>
class AtomicState {
  // The single data member; kept private (class default access).
  std::atomic<StateType> state_;

 public:
  using State = StateType;

  // Starts with `initial_state` (value-initialized State by default).
  explicit AtomicState(const State& initial_state = State{});

  // Reads the current value.
  State get() const noexcept;

  // Replaces the current value.
  void set(const State& new_state) noexcept;
  void set(State&& new_state) noexcept;

  // Installs `desired` only if the current value equals `expected`;
  // returns whether the swap happened.
  bool compare_and_set(const State& expected, const State& desired) noexcept;

  // Installs `new_state` and returns the value it replaced.
  State exchange(const State& new_state) noexcept;

  // True when the current value compares equal to `expected_state`.
  bool is_state(const State& expected_state) const noexcept;
};
119 
124  public:
125  explicit ThreadSafeCounter(int64_t initial_value = 0);
126 
127  int64_t get() const noexcept;
128  int64_t increment() noexcept;
129  int64_t decrement() noexcept;
130  int64_t add(int64_t value) noexcept;
131  int64_t subtract(int64_t value) noexcept;
132 
133  bool compare_and_set(int64_t expected, int64_t desired) noexcept;
134  int64_t exchange(int64_t new_value) noexcept;
135 
136  void reset() noexcept;
137 
138  private:
139  std::atomic<int64_t> value_;
140 };
141 
146  public:
147  explicit ThreadSafeFlag(bool initial_value = false);
148 
149  bool get() const noexcept;
150  void set(bool value = true) noexcept;
151  void clear() noexcept;
152 
153  bool test_and_set() noexcept;
154  bool compare_and_set(bool expected, bool desired) noexcept;
155 
156  void wait_for_true(std::chrono::milliseconds timeout = std::chrono::milliseconds(1000)) const;
157  void wait_for_false(std::chrono::milliseconds timeout = std::chrono::milliseconds(1000)) const;
158 
159  private:
160  std::atomic<bool> flag_;
161  mutable std::condition_variable cv_;
162  mutable std::mutex cv_mutex_;
163 };
164 
165 // Specialization for LinkState
168 
169 // Template implementations (must be in header for template instantiation)
170 template <typename StateType>
171 ThreadSafeState<StateType>::ThreadSafeState(const State& initial_state) : state_(initial_state) {}
172 
173 template <typename StateType>
175  std::shared_lock<std::shared_mutex> lock(state_mutex_);
176  return state_;
177 }
178 
179 template <typename StateType>
181  {
182  std::unique_lock<std::shared_mutex> lock(state_mutex_);
183  state_ = new_state;
184  state_changed_.store(true);
185  }
186  notify_callbacks(new_state);
187  state_cv_.notify_all();
188 }
189 
190 template <typename StateType>
192  {
193  std::unique_lock<std::shared_mutex> lock(state_mutex_);
194  state_ = std::move(new_state);
195  state_changed_.store(true);
196  }
197  notify_callbacks(state_);
198  state_cv_.notify_all();
199 }
200 
201 template <typename StateType>
202 bool ThreadSafeState<StateType>::compare_and_set(const State& expected, const State& desired) {
203  std::unique_lock<std::shared_mutex> lock(state_mutex_);
204  if (state_ == expected) {
205  state_ = desired;
206  state_changed_.store(true);
207  lock.unlock();
208  notify_callbacks(desired);
209  state_cv_.notify_all();
210  return true;
211  }
212  return false;
213 }
214 
215 template <typename StateType>
216 StateType ThreadSafeState<StateType>::exchange(const State& new_state) {
217  State old_state;
218  {
219  std::unique_lock<std::shared_mutex> lock(state_mutex_);
220  old_state = state_;
221  state_ = new_state;
222  state_changed_.store(true);
223  }
224  notify_callbacks(new_state);
225  state_cv_.notify_all();
226  return old_state;
227 }
228 
229 template <typename StateType>
231  StateCallback callback) {
232  std::lock_guard<std::mutex> lock(callbacks_mutex_);
233  StateCallbackHandle handle = next_handle_++;
234  callbacks_.push_back({handle, std::move(callback)});
235  return handle;
236 }
237 
238 template <typename StateType>
240  std::lock_guard<std::mutex> lock(callbacks_mutex_);
241  callbacks_.erase(std::remove_if(callbacks_.begin(), callbacks_.end(),
242  [handle](const CallbackInfo& info) { return info.handle == handle; }),
243  callbacks_.end());
244 }
245 
246 template <typename StateType>
248  std::lock_guard<std::mutex> lock(callbacks_mutex_);
249  callbacks_.clear();
250 }
251 
252 template <typename StateType>
253 void ThreadSafeState<StateType>::wait_for_state(const State& expected_state, std::chrono::milliseconds timeout) {
254  std::unique_lock<std::shared_mutex> lock(state_mutex_);
255  state_cv_.wait_for(lock, timeout, [this, &expected_state] { return state_ == expected_state; });
256 }
257 
258 template <typename StateType>
259 void ThreadSafeState<StateType>::wait_for_state_change(std::chrono::milliseconds timeout) {
260  std::unique_lock<std::shared_mutex> lock(state_mutex_);
261  state_cv_.wait_for(lock, timeout, [this] { return state_changed_.load(); });
262  state_changed_.store(false);
263 }
264 
265 template <typename StateType>
266 bool ThreadSafeState<StateType>::is_state(const State& expected_state) const {
267  std::shared_lock<std::shared_mutex> lock(state_mutex_);
268  return state_ == expected_state;
269 }
270 
271 template <typename StateType>
273  state_cv_.notify_all();
274 }
275 
276 template <typename StateType>
277 void ThreadSafeState<StateType>::notify_callbacks(const State& new_state) {
278  std::lock_guard<std::mutex> lock(callbacks_mutex_);
279  for (const auto& info : callbacks_) {
280  try {
281  info.callback(new_state);
282  } catch (...) {
283  // Ignore callback exceptions to prevent state corruption
284  }
285  }
286 }
287 
288 // AtomicState template implementations
289 template <typename StateType>
290 AtomicState<StateType>::AtomicState(const State& initial_state) : state_(initial_state) {}
291 
292 template <typename StateType>
293 StateType AtomicState<StateType>::get() const noexcept {
294  return state_.load();
295 }
296 
297 template <typename StateType>
298 void AtomicState<StateType>::set(const State& new_state) noexcept {
299  state_.store(new_state);
300 }
301 
302 template <typename StateType>
303 void AtomicState<StateType>::set(State&& new_state) noexcept {
304  state_.store(new_state);
305 }
306 
307 template <typename StateType>
308 bool AtomicState<StateType>::compare_and_set(const State& expected, const State& desired) noexcept {
309  State expected_copy = expected;
310  return state_.compare_exchange_strong(expected_copy, desired);
311 }
312 
313 template <typename StateType>
314 StateType AtomicState<StateType>::exchange(const State& new_state) noexcept {
315  return state_.exchange(new_state);
316 }
317 
318 template <typename StateType>
319 bool AtomicState<StateType>::is_state(const State& expected_state) const noexcept {
320  return state_.load() == expected_state;
321 }
322 
323 // ThreadSafeCounter implementations
324 inline ThreadSafeCounter::ThreadSafeCounter(int64_t initial_value) : value_(initial_value) {}
325 
326 inline int64_t ThreadSafeCounter::get() const noexcept { return value_.load(); }
327 
328 inline int64_t ThreadSafeCounter::increment() noexcept { return value_.fetch_add(1) + 1; }
329 
330 inline int64_t ThreadSafeCounter::decrement() noexcept { return value_.fetch_sub(1) - 1; }
331 
332 inline int64_t ThreadSafeCounter::add(int64_t value) noexcept { return value_.fetch_add(value) + value; }
333 
334 inline int64_t ThreadSafeCounter::subtract(int64_t value) noexcept { return value_.fetch_sub(value) - value; }
335 
336 inline bool ThreadSafeCounter::compare_and_set(int64_t expected, int64_t desired) noexcept {
337  return value_.compare_exchange_strong(expected, desired);
338 }
339 
340 inline int64_t ThreadSafeCounter::exchange(int64_t new_value) noexcept { return value_.exchange(new_value); }
341 
342 inline void ThreadSafeCounter::reset() noexcept { value_.store(0); }
343 
344 // ThreadSafeFlag implementations
345 inline ThreadSafeFlag::ThreadSafeFlag(bool initial_value) : flag_(initial_value) {}
346 
347 inline bool ThreadSafeFlag::get() const noexcept { return flag_.load(); }
348 
349 inline void ThreadSafeFlag::set(bool value) noexcept {
350  flag_.store(value);
351  if (value) {
352  cv_.notify_all();
353  }
354 }
355 
356 inline void ThreadSafeFlag::clear() noexcept { flag_.store(false); }
357 
358 inline bool ThreadSafeFlag::test_and_set() noexcept { return flag_.exchange(true); }
359 
360 inline bool ThreadSafeFlag::compare_and_set(bool expected, bool desired) noexcept {
361  return flag_.compare_exchange_strong(expected, desired);
362 }
363 
364 inline void ThreadSafeFlag::wait_for_true(std::chrono::milliseconds timeout) const {
365  std::unique_lock<std::mutex> lock(cv_mutex_);
366  cv_.wait_for(lock, timeout, [this] { return flag_.load(); });
367 }
368 
369 inline void ThreadSafeFlag::wait_for_false(std::chrono::milliseconds timeout) const {
370  std::unique_lock<std::mutex> lock(cv_mutex_);
371  cv_.wait_for(lock, timeout, [this] { return !flag_.load(); });
372 }
373 
374 } // namespace concurrency
375 } // namespace unilink