Skip to content

Commit 1504fe4

Browse files
authored
Merge branch 'bartgol/eamxx/pysession-header-only' (PR #7704)
Avoids recompilation of several copies of the class [BFB]
2 parents 520eb03 + 7b4acf0 commit 1504fe4

File tree

3 files changed

+82
-88
lines changed

3 files changed

+82
-88
lines changed

components/eamxx/src/share/core/CMakeLists.txt

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,12 +73,11 @@ if (EAMXX_ENABLE_PYTHON)
7373

7474
find_package(pybind11 REQUIRED HINTS ${pybind11_ROOT})
7575

76-
# Add py sources and link pybind11
77-
target_sources (eamxx_core PUBLIC
78-
eamxx_pysession.cpp)
79-
8076
target_link_libraries(eamxx_core PUBLIC pybind11::embed)
8177
target_compile_definitions (eamxx_core PUBLIC EAMXX_HAS_PYTHON)
78+
79+
# Used in eamxx_pysession.hpp to get current path
80+
target_link_libraries(eamxx_core PUBLIC stdc++fs)
8281
endif()
8382

8483
if (NOT SCREAM_LIB_ONLY)

components/eamxx/src/share/core/eamxx_pysession.cpp

Lines changed: 0 additions & 84 deletions
This file was deleted.

components/eamxx/src/share/core/eamxx_pysession.hpp

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22
#define EAMXX_PY_SESSION_HPP
33

44
#include <pybind11/pybind11.h>
5+
#include <pybind11/embed.h>
56

7+
#include <ekat_assert.hpp>
8+
#include <ekat_fpe.hpp>
9+
10+
#include <filesystem>
611
#include <string>
712
#include <any>
813

@@ -67,6 +72,80 @@ class PySession {
6772
std::any py_guard;
6873
};
6974

75+
// ==================== IMPLEMENTATION ===================== //
76+
77+
inline void PySession::initialize () {
78+
if (num_customers==0) {
79+
// Note: if Py interpreter is already inited, we ASSUME someone else
80+
// is handling the interpreter initialization/finalization
81+
if (not Py_IsInitialized()) {
82+
py_guard = std::make_shared<pybind11::scoped_interpreter>();
83+
}
84+
}
85+
++num_customers;
86+
}
87+
88+
inline void PySession::finalize () {
89+
EKAT_REQUIRE_MSG (num_customers>0,
90+
"Error! Invalid number of customers.\n"
91+
" Did you call PySession::finalize() without calling PySession::initialize()?\n");
92+
93+
--num_customers;
94+
if (num_customers==0) {
95+
py_guard.reset();
96+
}
97+
}
98+
99+
inline void PySession::add_path (const std::string& path)
100+
{
101+
EKAT_REQUIRE_MSG (is_initialized(),
102+
"Error! Cannot modify python's sys.path, since PySession was not initialized yet.\n");
103+
104+
try {
105+
// Import the sys module
106+
pybind11::module sysModule = pybind11::module::import("sys");
107+
108+
// Get the sys.path list
109+
pybind11::list sysPath = sysModule.attr("path");
110+
111+
// Append the new path to sys.path
112+
sysPath.append(path);
113+
} catch (const pybind11::error_already_set& e) {
114+
std::cerr << "Error: " << e.what() << std::endl;
115+
throw std::runtime_error("Could not modify sys.path. Aborting.");
116+
}
117+
}
118+
119+
inline void PySession::add_curr_path ()
120+
{
121+
auto curr_path = std::filesystem::current_path();
122+
add_path(curr_path.string());
123+
}
124+
125+
inline pybind11::module PySession::safe_import (const std::string& module_name) const
126+
{
127+
pybind11::module m;
128+
129+
// Disable FPEs while loading the module, then immediately re-enable them
130+
auto fpes = ekat::get_enabled_fpes();
131+
ekat::disable_all_fpes();
132+
try {
133+
m = pybind11::module::import(module_name.c_str());
134+
} catch (const pybind11::error_already_set& e) {
135+
std::cout << "[PySession::safe_import] Error! Python module import failed.\n"
136+
" - module name: " + module_name + "\n"
137+
" - pybind11 error: " + std::string(e.what()) + "\n"
138+
"Did you forget to call PySession::add_path to add the module location to sys.path?\n";
139+
throw e;
140+
}
141+
ekat::enable_fpes(fpes);
142+
143+
EKAT_REQUIRE_MSG (not m.is_none(),
144+
"Error! Could not import module '" + module_name + "'.\n");
145+
146+
return m;
147+
}
148+
70149
} // namespace scream
71150

72151
#endif // EAMXX_PY_SESSION_HPP

0 commit comments

Comments
 (0)