1+ #pragma once
2+
3+ #include " recorders.h"
4+ #include < mutex>
5+ #include < thread>
6+ #include < vector>
7+
8+ namespace celerity ::detail {
9+ // in c++23 replace this with mdspan
10+ template <typename T>
11+ struct mpi_multidim_send_wrapper {
12+ public:
13+ const T& operator [](std::pair<int , int > ij) const {
14+ assert (ij.first * m_width + ij.second < m_data.size ());
15+ return m_data[ij.first * m_width + ij.second ];
16+ }
17+
18+ T* data () { return m_data.data (); }
19+
20+ mpi_multidim_send_wrapper (size_t width, size_t height) : m_data(width * height), m_width(width){};
21+
22+ private:
23+ std::vector<T> m_data;
24+ const size_t m_width;
25+ };
26+
27+ // Probably replace this in c++20 with span
28+ template <typename T>
29+ struct window {
30+ public:
31+ window (const std::vector<T>& value) : m_value(value) {}
32+
33+ const T& operator [](size_t i) const {
34+ assert (i >= 0 && i < m_width);
35+ return m_value[m_offset + i];
36+ }
37+
38+ size_t size () {
39+ m_width = m_value.size () - m_offset;
40+ return m_width;
41+ }
42+
43+ void slide (size_t i) {
44+ assert (i == 0 || (i >= 0 && i <= m_width));
45+ m_offset += i;
46+ m_width -= i;
47+ }
48+
49+ private:
50+ const std::vector<T>& m_value;
51+ size_t m_offset = 0 ;
52+ size_t m_width = 0 ;
53+ };
54+
55+ using task_hash = size_t ;
56+ using task_hash_data = mpi_multidim_send_wrapper<task_hash>;
57+ using divergence_map = std::unordered_map<task_hash, std::vector<node_id>>;
58+
59+ class abstract_block_chain {
60+ friend struct abstract_block_chain_testspy ;
61+
62+ public:
63+ virtual void stop () { m_is_running = false ; };
64+
65+ abstract_block_chain (const abstract_block_chain&) = delete ;
66+ abstract_block_chain& operator =(const abstract_block_chain&) = delete ;
67+ abstract_block_chain& operator =(abstract_block_chain&&) = delete ;
68+
69+ abstract_block_chain (abstract_block_chain&&) = default ;
70+
71+ abstract_block_chain (size_t num_nodes, node_id local_nid, const std::vector<task_record>& task_recorder, MPI_Comm comm)
72+ : m_local_nid(local_nid), m_num_nodes(num_nodes), m_sizes(num_nodes), m_task_recorder_window(task_recorder), m_comm(comm) {}
73+
74+ virtual ~abstract_block_chain () = default ;
75+
76+ protected:
77+ void start () { m_is_running = true ; };
78+
79+ virtual void run () = 0;
80+
81+ virtual void divergence_out (const divergence_map& check_map, const int task_num) = 0;
82+
83+ void add_new_hashes ();
84+ void clear (const int min_progress);
85+ virtual void allgather_sizes ();
86+ virtual void allgather_hashes (const int max_size, task_hash_data& data);
87+ std::pair<int , int > collect_sizes ();
88+ task_hash_data collect_hashes (const int max_size);
89+ divergence_map create_check_map (const task_hash_data& task_graphs, const int task_num) const ;
90+
91+ void check_for_deadlock () const ;
92+
93+ static void print_node_divergences (const divergence_map& check_map, const int task_num);
94+
95+ static void print_task_record (const divergence_map& check_map, const task_record& task, const task_hash hash);
96+
97+ virtual void dedub_print_task_record (const divergence_map& check_map, const int task_num) const ;
98+
99+ bool check_for_divergence ();
100+
101+ protected:
102+ node_id m_local_nid;
103+ size_t m_num_nodes;
104+
105+ std::vector<task_hash> m_hashes;
106+ std::vector<int > m_sizes;
107+
108+ bool m_is_running = true ;
109+
110+ window<task_record> m_task_recorder_window;
111+
112+ std::chrono::time_point<std::chrono::steady_clock> m_last_cleared = std::chrono::steady_clock::now();
113+
114+ MPI_Comm m_comm;
115+ };
116+
117+ class single_node_test_divergence_block_chain : public abstract_block_chain {
118+ public:
119+ single_node_test_divergence_block_chain (size_t num_nodes, node_id local_nid, const std::vector<task_record>& task_recorder, MPI_Comm comm,
120+ const std::vector<std::reference_wrapper<const std::vector<task_record>>>& other_task_records)
121+ : abstract_block_chain(num_nodes, local_nid, task_recorder, comm), m_other_hashes(other_task_records.size()) {
122+ for (auto & tsk_rcd : other_task_records) {
123+ m_other_task_records.push_back (window<task_record>(tsk_rcd));
124+ }
125+ }
126+
127+ private:
128+ void run () override {}
129+
130+ void divergence_out (const divergence_map& check_map, const int task_num) override ;
131+ void allgather_sizes () override ;
132+ void allgather_hashes (const int max_size, task_hash_data& data) override ;
133+
134+ void dedub_print_task_record (const divergence_map& check_map, const int task_num) const override ;
135+
136+ std::vector<std::vector<task_hash>> m_other_hashes;
137+ std::vector<window<task_record>> m_other_task_records;
138+
139+ int m_injected_delete_size = 0 ;
140+ };
141+
142+ class distributed_test_divergence_block_chain : public abstract_block_chain {
143+ public:
144+ distributed_test_divergence_block_chain (size_t num_nodes, node_id local_nid, const std::vector<task_record>& task_record, MPI_Comm comm)
145+ : abstract_block_chain(num_nodes, local_nid, task_record, comm) {}
146+
147+ private:
148+ void run () override {}
149+
150+ void divergence_out (const divergence_map& check_map, const int task_num) override ;
151+ };
152+
153+ class divergence_block_chain : public abstract_block_chain {
154+ public:
155+ void start ();
156+ void stop () override ;
157+
158+ divergence_block_chain (size_t num_nodes, node_id local_nid, const std::vector<task_record>& task_record, MPI_Comm comm)
159+ : abstract_block_chain(num_nodes, local_nid, task_record, comm) {
160+ start ();
161+ }
162+
163+ divergence_block_chain (const divergence_block_chain&) = delete ;
164+ divergence_block_chain& operator =(const divergence_block_chain&) = delete ;
165+ divergence_block_chain& operator =(divergence_block_chain&&) = delete ;
166+
167+ divergence_block_chain (divergence_block_chain&&) = default ;
168+
169+ ~divergence_block_chain () override { divergence_block_chain::stop (); }
170+
171+ private:
172+ void run () override ;
173+
174+ void divergence_out (const divergence_map& check_map, const int task_num) override ;
175+
176+ private:
177+ std::thread m_thread;
178+ };
179+ } // namespace celerity::detail
0 commit comments