|  | 
|  | 1 | +#pragma once | 
|  | 2 | + | 
|  | 3 | +#include <mutex> | 
|  | 4 | +#include <thread> | 
|  | 5 | +#include <vector> | 
|  | 6 | + | 
|  | 7 | +#include "communicator.h" | 
|  | 8 | +#include "recorders.h" | 
|  | 9 | + | 
|  | 10 | +namespace celerity::detail { | 
|  | 11 | +struct runtime_testspy; | 
|  | 12 | +} | 
|  | 13 | + | 
|  | 14 | +namespace celerity::detail::divergence_checker_detail { | 
|  | 15 | +using task_hash = size_t; | 
|  | 16 | +using divergence_map = std::unordered_map<task_hash, std::vector<node_id>>; | 
|  | 17 | + | 
|  | 18 | +/** | 
|  | 19 | + * @brief Stores the hashes of tasks for each node. | 
|  | 20 | + * | 
|  | 21 | + * The data is stored densely so it can easily be exchanged through MPI collective operations. | 
|  | 22 | + */ | 
|  | 23 | +struct per_node_task_hashes { | 
|  | 24 | +  public: | 
|  | 25 | +	per_node_task_hashes(const size_t max_hash_count, const size_t num_nodes) : m_data(max_hash_count * num_nodes), m_max_hash_count(max_hash_count){}; | 
|  | 26 | +	const task_hash& operator()(const node_id nid, const size_t i) const { return m_data.at(nid * m_max_hash_count + i); } | 
|  | 27 | +	task_hash* data() { return m_data.data(); } | 
|  | 28 | + | 
|  | 29 | +  private: | 
|  | 30 | +	std::vector<task_hash> m_data; | 
|  | 31 | +	size_t m_max_hash_count; | 
|  | 32 | +}; | 
|  | 33 | + | 
|  | 34 | +/** | 
|  | 35 | + *  @brief This class checks for divergences of tasks between nodes. | 
|  | 36 | + * | 
|  | 37 | + *  It is responsible for collecting the task hashes from all nodes and checking for differences -> divergence. | 
|  | 38 | + *  When a divergence is found, the task record for the diverging task is printed and the program is terminated. | 
|  | 39 | + *  Additionally it will also print a warning when a deadlock is suspected. | 
|  | 40 | + */ | 
|  | 41 | + | 
|  | 42 | +class divergence_block_chain { | 
|  | 43 | +	friend struct divergence_block_chain_testspy; | 
|  | 44 | + | 
|  | 45 | +  public: | 
|  | 46 | +	divergence_block_chain(task_recorder& task_recorder, std::unique_ptr<communicator> comm) | 
|  | 47 | +	    : m_local_nid(comm->get_local_nid()), m_num_nodes(comm->get_num_nodes()), m_per_node_hash_counts(comm->get_num_nodes()), | 
|  | 48 | +	      m_communicator(std::move(comm)) { | 
|  | 49 | +		task_recorder.add_callback([this](const task_record& task) { add_new_task(task); }); | 
|  | 50 | +	} | 
|  | 51 | + | 
|  | 52 | +	divergence_block_chain(const divergence_block_chain&) = delete; | 
|  | 53 | +	divergence_block_chain(divergence_block_chain&&) = delete; | 
|  | 54 | + | 
|  | 55 | +	~divergence_block_chain() = default; | 
|  | 56 | + | 
|  | 57 | +	divergence_block_chain& operator=(const divergence_block_chain&) = delete; | 
|  | 58 | +	divergence_block_chain& operator=(divergence_block_chain&&) = delete; | 
|  | 59 | + | 
|  | 60 | +	bool check_for_divergence(); | 
|  | 61 | + | 
|  | 62 | +  private: | 
|  | 63 | +	node_id m_local_nid; | 
|  | 64 | +	size_t m_num_nodes; | 
|  | 65 | + | 
|  | 66 | +	std::vector<task_hash> m_local_hashes; | 
|  | 67 | +	std::vector<task_record> m_task_records; | 
|  | 68 | +	size_t m_tasks_checked = 0; | 
|  | 69 | +	size_t m_hashes_added = 0; | 
|  | 70 | + | 
|  | 71 | +	std::vector<int> m_per_node_hash_counts; | 
|  | 72 | +	std::mutex m_task_records_mutex; | 
|  | 73 | + | 
|  | 74 | +	std::chrono::time_point<std::chrono::steady_clock> m_last_cleared = std::chrono::steady_clock::now(); | 
|  | 75 | + | 
|  | 76 | +	std::unique_ptr<communicator> m_communicator; | 
|  | 77 | + | 
|  | 78 | +	void divergence_out(const divergence_map& check_map, const int task_num); | 
|  | 79 | + | 
|  | 80 | +	void add_new_hashes(); | 
|  | 81 | +	void clear(const int min_progress); | 
|  | 82 | +	std::pair<int, int> collect_hash_counts(); | 
|  | 83 | +	per_node_task_hashes collect_hashes(const int min_hash_count) const; | 
|  | 84 | +	divergence_map create_check_map(const per_node_task_hashes& task_hashes, const int task_num) const; | 
|  | 85 | + | 
|  | 86 | +	void check_for_deadlock() const; | 
|  | 87 | + | 
|  | 88 | +	static void log_node_divergences(const divergence_map& check_map, const int task_num); | 
|  | 89 | +	static void log_task_record(const divergence_map& check_map, const task_record& task, const task_hash hash); | 
|  | 90 | +	void log_task_record_once(const divergence_map& check_map, const int task_num); | 
|  | 91 | + | 
|  | 92 | +	void add_new_task(const task_record& task); | 
|  | 93 | +	task_record thread_save_get_task_record(const size_t task_num); | 
|  | 94 | +}; | 
|  | 95 | + | 
|  | 96 | +class divergence_checker { | 
|  | 97 | +	friend struct ::celerity::detail::runtime_testspy; | 
|  | 98 | + | 
|  | 99 | +  public: | 
|  | 100 | +	divergence_checker(task_recorder& task_recorder, std::unique_ptr<communicator> comm, bool test_mode = false) | 
|  | 101 | +	    : m_block_chain(task_recorder, std::move(comm)) { | 
|  | 102 | +		if(!test_mode) { start(); } | 
|  | 103 | +	} | 
|  | 104 | + | 
|  | 105 | +	divergence_checker(const divergence_checker&) = delete; | 
|  | 106 | +	divergence_checker(const divergence_checker&&) = delete; | 
|  | 107 | + | 
|  | 108 | +	divergence_checker& operator=(const divergence_checker&) = delete; | 
|  | 109 | +	divergence_checker& operator=(divergence_checker&&) = delete; | 
|  | 110 | + | 
|  | 111 | +	~divergence_checker() { stop(); } | 
|  | 112 | + | 
|  | 113 | +  private: | 
|  | 114 | +	void start() { | 
|  | 115 | +		m_thread = std::thread(&divergence_checker::run, this); | 
|  | 116 | +		m_is_running = true; | 
|  | 117 | +	} | 
|  | 118 | + | 
|  | 119 | +	void stop() { | 
|  | 120 | +		m_is_running = false; | 
|  | 121 | +		if(m_thread.joinable()) { m_thread.join(); } | 
|  | 122 | +	} | 
|  | 123 | + | 
|  | 124 | +	void run() { | 
|  | 125 | +		bool is_finished = false; | 
|  | 126 | +		while(!is_finished || m_is_running) { | 
|  | 127 | +			is_finished = m_block_chain.check_for_divergence(); | 
|  | 128 | + | 
|  | 129 | +			std::this_thread::sleep_for(std::chrono::milliseconds(100)); | 
|  | 130 | +		} | 
|  | 131 | +	} | 
|  | 132 | + | 
|  | 133 | +	std::thread m_thread; | 
|  | 134 | +	bool m_is_running = false; | 
|  | 135 | +	divergence_block_chain m_block_chain; | 
|  | 136 | +}; | 
|  | 137 | +}; // namespace celerity::detail::divergence_checker_detail | 
0 commit comments