|
2 | 2 | #define EAMXX_PY_SESSION_HPP
|
3 | 3 |
|
4 | 4 | #include <pybind11/pybind11.h>
|
| 5 | +#include <pybind11/embed.h> |
5 | 6 |
|
| 7 | +#include <ekat_assert.hpp> |
| 8 | +#include <ekat_fpe.hpp> |
| 9 | + |
| 10 | +#include <filesystem> |
6 | 11 | #include <string>
|
7 | 12 | #include <any>
|
8 | 13 |
|
@@ -67,6 +72,80 @@ class PySession {
|
67 | 72 | std::any py_guard;
|
68 | 73 | };
|
69 | 74 |
|
| 75 | +// ==================== IMPLEMENTATION ===================== // |
| 76 | + |
| 77 | +inline void PySession::initialize () { |
| 78 | + if (num_customers==0) { |
| 79 | + // Note: if Py interpreter is already inited, we ASSUME someone else |
| 80 | + // is handling the interpreter initialization/finalization |
| 81 | + if (not Py_IsInitialized()) { |
| 82 | + py_guard = std::make_shared<pybind11::scoped_interpreter>(); |
| 83 | + } |
| 84 | + } |
| 85 | + ++num_customers; |
| 86 | +} |
| 87 | + |
| 88 | +inline void PySession::finalize () { |
| 89 | + EKAT_REQUIRE_MSG (num_customers>0, |
| 90 | + "Error! Invalid number of customers.\n" |
| 91 | + " Did you call PySession::finalize() without calling PySession::initialize()?\n"); |
| 92 | + |
| 93 | + --num_customers; |
| 94 | + if (num_customers==0) { |
| 95 | + py_guard.reset(); |
| 96 | + } |
| 97 | +} |
| 98 | + |
| 99 | +inline void PySession::add_path (const std::string& path) |
| 100 | +{ |
| 101 | + EKAT_REQUIRE_MSG (is_initialized(), |
| 102 | + "Error! Cannot modify python's sys.path, since PySession was not initialized yet.\n"); |
| 103 | + |
| 104 | + try { |
| 105 | + // Import the sys module |
| 106 | + pybind11::module sysModule = pybind11::module::import("sys"); |
| 107 | + |
| 108 | + // Get the sys.path list |
| 109 | + pybind11::list sysPath = sysModule.attr("path"); |
| 110 | + |
| 111 | + // Append the new path to sys.path |
| 112 | + sysPath.append(path); |
| 113 | + } catch (const pybind11::error_already_set& e) { |
| 114 | + std::cerr << "Error: " << e.what() << std::endl; |
| 115 | + throw std::runtime_error("Could not modify sys.path. Aborting."); |
| 116 | + } |
| 117 | +} |
| 118 | + |
| 119 | +inline void PySession::add_curr_path () |
| 120 | +{ |
| 121 | + auto curr_path = std::filesystem::current_path(); |
| 122 | + add_path(curr_path.string()); |
| 123 | +} |
| 124 | + |
| 125 | +inline pybind11::module PySession::safe_import (const std::string& module_name) const |
| 126 | +{ |
| 127 | + pybind11::module m; |
| 128 | + |
| 129 | + // Disable FPEs while loading the module, then immediately re-enable them |
| 130 | + auto fpes = ekat::get_enabled_fpes(); |
| 131 | + ekat::disable_all_fpes(); |
| 132 | + try { |
| 133 | + m = pybind11::module::import(module_name.c_str()); |
| 134 | + } catch (const pybind11::error_already_set& e) { |
| 135 | + std::cout << "[PySession::safe_import] Error! Python module import failed.\n" |
| 136 | + " - module name: " + module_name + "\n" |
| 137 | + " - pybind11 error: " + std::string(e.what()) + "\n" |
| 138 | + "Did you forget to call PySession::add_path to add the module location to sys.path?\n"; |
| 139 | + throw e; |
| 140 | + } |
| 141 | + ekat::enable_fpes(fpes); |
| 142 | + |
| 143 | + EKAT_REQUIRE_MSG (not m.is_none(), |
| 144 | + "Error! Could not import module '" + module_name + "'.\n"); |
| 145 | + |
| 146 | + return m; |
| 147 | +} |
| 148 | + |
70 | 149 | } // namespace scream
|
71 | 150 |
|
72 | 151 | #endif // EAMXX_PY_SESSION_HPP
|
0 commit comments