1+ #pragma once
2+
#include <cassert>
#include <chrono>
#include <cstddef>
#include <mutex>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>

#include "recorders.h"
7+
8+ namespace celerity ::detail {
/**
 * @brief Wrapper around a flat 1D vector that allows us to access it as a row-major 2D array.
 *
 * It is used to send the task hashes to other nodes using MPI while keeping the code simple and readable.
 * The storage holds `width * height` value-initialized elements; element `{i, j}` lives at flat index `i * width + j`.
 */
template <typename T>
struct mpi_2d_send_wrapper {
  public:
	/**
	 * @brief Allocates `width * height` value-initialized elements.
	 * @param width Number of elements per row, i.e. the exclusive upper bound of the second index.
	 * @param height Number of rows, i.e. the exclusive upper bound of the first index.
	 */
	mpi_2d_send_wrapper(size_t width, size_t height) : m_data(width * height), m_width(width) {}

	/**
	 * @brief Read-only access to the element in row `ij.first`, column `ij.second`.
	 * @param ij Pair of (row, column) indices; both must be non-negative and in range (checked with assert).
	 * @return Const reference to the addressed element.
	 */
	const T& operator[](std::pair<int, int> ij) const {
		// Validate each index on its own: checking only the flattened index lets a negative value wrap
		// around through unsigned arithmetic and lets an out-of-range column alias the next row.
		assert(ij.first >= 0 && ij.second >= 0);
		assert(static_cast<size_t>(ij.second) < m_width);
		const size_t flat_index = static_cast<size_t>(ij.first) * m_width + static_cast<size_t>(ij.second);
		// With the column in range, this is equivalent to checking the row against the height.
		assert(flat_index < m_data.size());
		return m_data[flat_index];
	}

	/** @brief Pointer to the contiguous underlying storage, laid out row after row. */
	T* data() { return m_data.data(); }

	/** @brief Read-only pointer to the contiguous underlying storage. */
	const T* data() const { return m_data.data(); }

  private:
	std::vector<T> m_data; ///< Flat storage of all rows.
	const size_t m_width;  ///< Row length used to flatten (row, column) into an index.
};
29+
/**
 * @brief Sliding, read-only view onto a vector that is owned elsewhere.
 *
 * It is used to give us the currently unhashed task records while keeping track of the offset and width.
 * Only a reference is stored, so the referenced vector has to outlive the window. The cached width starts at zero and
 * is refreshed by size(); element access and sliding are asserted against that cached width.
 */
template <typename T>
struct window {
  public:
	/** @brief Binds the view to `value`; offset and cached width both start at zero. */
	window(const std::vector<T>& value) : m_value(value) {}

	/** @brief Returns the `i`-th element counted from the current offset; asserts that `i` is below the cached width. */
	const T& operator[](size_t i) const {
		assert(i < m_width);
		return m_value[m_offset + i];
	}

	/** @brief Re-reads the underlying vector's size, caches the number of elements past the offset and returns it. */
	size_t size() { return m_width = m_value.size() - m_offset; }

	/** @brief Moves the start of the view forward by `i` elements and shrinks the cached width accordingly. */
	void slide(size_t i) {
		// `i` is unsigned, so the only way to leave the window is to exceed the cached width.
		assert(i <= m_width);
		m_width -= i;
		m_offset += i;
	}

  private:
	const std::vector<T>& m_value; ///< Viewed vector (not owned).
	size_t m_offset = 0;           ///< Index of the first element inside the view.
	size_t m_width = 0;            ///< Element count cached by the last size() / slide() call.
};
60+
61+ using task_hash = size_t ;
62+ using task_hash_data = mpi_2d_send_wrapper<task_hash>;
63+ using divergence_map = std::unordered_map<task_hash, std::vector<node_id>>;
64+
65+ /* * @brief This class is the base class for the divergence check.
66+ *
67+ * It is responsible for collecting the task hashes from all nodes and checking for differences -> divergence.
68+ * When a divergence is found, the task record for the diverging task is printed and the program is terminated.
69+ * Additionally it also checks for deadlocks and prints a warning if one is detected.
70+ */
71+ class abstract_block_chain {
72+ friend struct abstract_block_chain_testspy ;
73+
74+ public:
75+ virtual void stop () { m_is_running = false ; };
76+
77+ abstract_block_chain (const abstract_block_chain&) = delete ;
78+ abstract_block_chain& operator =(const abstract_block_chain&) = delete ;
79+ abstract_block_chain& operator =(abstract_block_chain&&) = delete ;
80+
81+ abstract_block_chain (abstract_block_chain&&) = default ;
82+
83+ abstract_block_chain (size_t num_nodes, node_id local_nid, const std::vector<task_record>& task_recorder, MPI_Comm comm)
84+ : m_local_nid(local_nid), m_num_nodes(num_nodes), m_sizes(num_nodes), m_task_recorder_window(task_recorder), m_comm(comm) {}
85+
86+ virtual ~abstract_block_chain () = default ;
87+
88+ protected:
89+ void start () { m_is_running = true ; };
90+
91+ virtual void run () = 0;
92+
93+ virtual void divergence_out (const divergence_map& check_map, const int task_num) = 0;
94+
95+ void add_new_hashes ();
96+ void clear (const int min_progress);
97+ virtual void allgather_sizes ();
98+ virtual void allgather_hashes (const int max_size, task_hash_data& data);
99+ std::pair<int , int > collect_sizes ();
100+ task_hash_data collect_hashes (const int max_size);
101+ divergence_map create_check_map (const task_hash_data& task_graphs, const int task_num) const ;
102+
103+ void check_for_deadlock () const ;
104+
105+ static void print_node_divergences (const divergence_map& check_map, const int task_num);
106+
107+ static void print_task_record (const divergence_map& check_map, const task_record& task, const task_hash hash);
108+
109+ virtual void dedub_print_task_record (const divergence_map& check_map, const int task_num) const ;
110+
111+ bool check_for_divergence ();
112+
113+ protected:
114+ node_id m_local_nid;
115+ size_t m_num_nodes;
116+
117+ std::vector<task_hash> m_hashes;
118+ std::vector<int > m_sizes;
119+
120+ bool m_is_running = true ;
121+
122+ window<task_record> m_task_recorder_window;
123+
124+ std::chrono::time_point<std::chrono::steady_clock> m_last_cleared = std::chrono::steady_clock::now();
125+
126+ MPI_Comm m_comm;
127+ };
128+
129+ class divergence_block_chain : public abstract_block_chain {
130+ public:
131+ void start ();
132+ void stop () override ;
133+
134+ divergence_block_chain (size_t num_nodes, node_id local_nid, const std::vector<task_record>& task_record, MPI_Comm comm, bool test_mode = false )
135+ : abstract_block_chain(num_nodes, local_nid, task_record, comm), m_test_mode(test_mode) {
136+ divergence_block_chain::start ();
137+ }
138+
139+ divergence_block_chain (const divergence_block_chain&) = delete ;
140+ divergence_block_chain& operator =(const divergence_block_chain&) = delete ;
141+ divergence_block_chain& operator =(divergence_block_chain&&) = delete ;
142+
143+ divergence_block_chain (divergence_block_chain&&) = default ;
144+
145+ ~divergence_block_chain () override { divergence_block_chain::stop (); }
146+
147+ private:
148+ void run () override ;
149+
150+ void divergence_out (const divergence_map& check_map, const int task_num) override ;
151+
152+ private:
153+ std::thread m_thread;
154+ bool m_test_mode = false ;
155+ };
156+ }; // namespace celerity::detail
0 commit comments