Skip to content

Commit 8cd030d

Browse files
authored
Merge Pull Request #2950 from E3SM-Project/scream/mahf708/nb-pyscream
Automatically Merged using E3SM Pull Request AutoTester PR Title: move pyscream to nb PR Author: mahf708 PR LABELS: AT: AUTOMERGE, python
2 parents 9ba5edc + d8ae8f7 commit 8cd030d

File tree

10 files changed

+110
-89
lines changed

10 files changed

+110
-89
lines changed

components/eamxx/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,8 +272,8 @@ set(SCREAM_BASE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
272272
set(SCREAM_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}/src)
273273
set(SCREAM_BIN_DIR ${CMAKE_CURRENT_BINARY_DIR})
274274

275-
option (EAMXX_ENABLE_PYBIND "Whether to enable python interface to eamxx, via pybind11" OFF)
276-
if (EAMXX_ENABLE_PYBIND)
275+
option (EAMXX_ENABLE_PYSCREAM "Whether to enable python interface to eamxx" OFF)
276+
if (EAMXX_ENABLE_PYSCREAM)
277277
# Pybind11 requires shared libraries
278278
set (BUILD_SHARED_LIBS ON)
279279
endif()

components/eamxx/src/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,6 @@ if (PROJECT_NAME STREQUAL "E3SM")
88
add_subdirectory(mct_coupling)
99
endif()
1010

11-
if (EAMXX_ENABLE_PYBIND)
11+
if (EAMXX_ENABLE_PYSCREAM)
1212
add_subdirectory(python)
1313
endif()

components/eamxx/src/python/libpyscream/CMakeLists.txt

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1-
find_package(pybind11 REQUIRED)
1+
# Detect the installed nanobind package and import it into CMake
2+
find_package(Python COMPONENTS Interpreter Development REQUIRED)
3+
execute_process(
4+
COMMAND "${PYTHON_EXECUTABLE}" -m nanobind --cmake_dir
5+
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE nanobind_ROOT)
6+
find_package(nanobind REQUIRED)
27
find_package(mpi4py REQUIRED)
38

4-
pybind11_add_module(pyscream_ext pyscream_ext.cpp)
9+
nanobind_add_module(pyscream_ext pyscream_ext.cpp)
510
target_link_libraries(pyscream_ext PUBLIC
611
mpi4py
712
scream_share

components/eamxx/src/python/libpyscream/pyatmproc.hpp

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@
1212

1313
#include <ekat/io/ekat_yaml.hpp>
1414

15-
#include <pybind11/pybind11.h>
16-
#include <pybind11/stl.h>
15+
#include <nanobind/nanobind.h>
16+
#include <nanobind/stl/list.h>
17+
#include <nanobind/stl/vector.h>
18+
19+
namespace nb = nanobind;
1720

1821
namespace scream {
1922

@@ -26,7 +29,7 @@ struct PyAtmProc {
2629

2730
std::shared_ptr<OutputManager> output_mgr;
2831

29-
PyAtmProc (const pybind11::dict& d, const std::string& name)
32+
PyAtmProc (const nb::dict& d, const std::string& name)
3033
{
3134
PyParamList params(d,name);
3235

@@ -102,7 +105,7 @@ struct PyAtmProc {
102105
ap->initialize(t0,RunType::Initial);
103106
}
104107

105-
pybind11::list read_ic (const std::string& ic_filename) {
108+
std::vector<std::string> read_ic (const std::string& ic_filename) {
106109
// Get input fields, and read them from file (if present).
107110
// If field is not in the IC, user is responsible for setting
108111
// it to an initial value
@@ -128,10 +131,10 @@ struct PyAtmProc {
128131
}
129132
scorpio::release_file(ic_filename);
130133

131-
return pybind11::cast(missing);
134+
return missing;
132135
}
133136

134-
pybind11::list list_fields(std::string ftype) {
137+
std::vector<std::string> list_fields(std::string ftype) {
135138
std::vector<std::string> fields_list;
136139
for (const auto& field_pair : fields) {
137140
const auto& field_identifier = field_pair.second.f.get_header().get_identifier();
@@ -144,18 +147,18 @@ struct PyAtmProc {
144147
fields_list.push_back(field_pair.first);
145148
}
146149
}
147-
return pybind11::cast(fields_list);
150+
return fields_list;
148151
}
149152

150-
pybind11::list list_all_fields() {
153+
std::vector<std::string> list_all_fields() {
151154
return list_fields("all");
152155
}
153156

154-
pybind11::list list_required_fields() {
157+
std::vector<std::string> list_required_fields() {
155158
return list_fields("required");
156159
}
157160

158-
pybind11::list list_computed_fields() {
161+
std::vector<std::string> list_computed_fields() {
159162
return list_fields("computed");
160163
}
161164

@@ -194,10 +197,10 @@ struct PyAtmProc {
194197
};
195198

196199
// Register type in the py module
197-
inline void pybind_pyatmproc(pybind11::module& m)
200+
inline void nb_pyatmproc(nb::module_& m)
198201
{
199-
pybind11::class_<PyAtmProc>(m,"AtmProc")
200-
.def(pybind11::init<const pybind11::dict&,const std::string&>())
202+
nb::class_<PyAtmProc>(m,"AtmProc")
203+
.def(nb::init<const nb::dict&,const std::string&>())
201204
.def("get_field",&PyAtmProc::get_field)
202205
.def("initialize",&PyAtmProc::initialize)
203206
.def("get_params",&PyAtmProc::get_params)

components/eamxx/src/python/libpyscream/pyfield.hpp

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@
44
#include "share/field/field.hpp"
55
#include "share/field/field_utils.hpp"
66

7-
#include <pybind11/pybind11.h>
8-
#include <pybind11/numpy.h>
9-
#include <pybind11/stl.h>
7+
#include <nanobind/nanobind.h>
8+
#include <nanobind/ndarray.h>
9+
#include <nanobind/stl/list.h>
10+
11+
namespace nb = nanobind;
1012

1113
namespace scream {
1214

@@ -24,7 +26,8 @@ struct PyField {
2426
f.allocate_view();
2527
}
2628

27-
pybind11::array get () const {
29+
template <typename FRAMEWORK>
30+
nb::ndarray<FRAMEWORK> get () const {
2831
const auto& fh = f.get_header();
2932
const auto& fid = fh.get_identifier();
3033

@@ -39,10 +42,14 @@ struct PyField {
3942
// NOTE: since the field may be padded, the strides do not necessarily
4043
// match the dims. Also, the strides must be grabbed from the
4144
// actual view, since the layout doesn't know them.
42-
pybind11::array::ShapeContainer shape (fid.get_layout().dims());
45+
int shape_t = f.rank();
46+
size_t shape[shape_t] = {0};
47+
for (int i=0; i<shape_t; ++i) {
48+
shape[i] = fid.get_layout().dims()[i];
49+
}
4350
std::vector<ssize_t> strides;
4451

45-
pybind11::dtype dt;
52+
nb::dlpack::dtype dt;
4653
switch (fid.data_type()) {
4754
case DataType::IntType:
4855
dt = get_dt_and_set_strides<int>(strides);
@@ -59,8 +66,8 @@ struct PyField {
5966

6067
// NOTE: you MUST set the parent handle, or else you won't have view semantic
6168
auto data = f.get_internal_view_data_unsafe<void,Host>();
62-
auto this_obj = pybind11::cast(this);
63-
return pybind11::array(dt,shape,strides,data,pybind11::handle(this_obj));
69+
auto this_obj = nb::cast(this);
70+
return nb::ndarray<FRAMEWORK>(data, shape_t, shape, nb::handle(this_obj), strides.data(), dt);
6471
}
6572

6673
void sync_to_host () {
@@ -75,7 +82,7 @@ struct PyField {
7582
private:
7683

7784
template<typename T>
78-
pybind11::dtype get_dt_and_set_strides (std::vector<ssize_t>& strides) const
85+
nb::dlpack::dtype get_dt_and_set_strides (std::vector<ssize_t>& strides) const
7986
{
8087
strides.resize(f.rank());
8188
switch (f.rank()) {
@@ -126,15 +133,15 @@ struct PyField {
126133
" - field rnak: " + std::to_string(f.rank()) + "\n");
127134
}
128135

129-
return pybind11::dtype::of<T>();
136+
return nb::dtype<T>();
130137
}
131138
};
132139

133-
inline void pybind_pyfield (pybind11::module& m) {
140+
inline void nb_pyfield (nb::module_& m) {
134141
// Field class
135-
pybind11::class_<PyField>(m,"Field")
136-
.def(pybind11::init<>())
137-
.def("get",&PyField::get)
142+
nb::class_<PyField>(m,"Field")
143+
.def(nb::init<>())
144+
.def("get",&PyField::get<nb::numpy>)
138145
.def("sync_to_host",&PyField::sync_to_host)
139146
.def("sync_to_dev",&PyField::sync_to_dev)
140147
.def("print",&PyField::print);

components/eamxx/src/python/libpyscream/pygrid.hpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,13 @@
55

66
#include "pyscream_ext.hpp"
77

8-
#include <pybind11/pybind11.h>
8+
#include <nanobind/nanobind.h>
9+
#include <nanobind/stl/string.h>
910

1011
#include <mpi.h>
1112

13+
namespace nb = nanobind;
14+
1215
namespace scream {
1316

1417
inline void create_grids_manager (int ncols, int nlevs, const std::string& latlon_nc_file)
@@ -38,9 +41,9 @@ inline void create_grids_manager (int ncols, int nlevs)
3841
create_grids_manager(ncols,nlevs,"");
3942
}
4043

41-
inline void pybind_pygrid (pybind11::module& m) {
42-
m.def("create_grids_manager",pybind11::overload_cast<int,int>(&create_grids_manager));
43-
m.def("create_grids_manager",pybind11::overload_cast<int,int,const std::string&>(&create_grids_manager));
44+
inline void nb_pygrid (nb::module_& m) {
45+
m.def("create_grids_manager",nb::overload_cast<int,int>(&create_grids_manager));
46+
m.def("create_grids_manager",nb::overload_cast<int,int,const std::string&>(&create_grids_manager));
4447
}
4548

4649
} // namespace scream

components/eamxx/src/python/libpyscream/pyparamlist.hpp

Lines changed: 42 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@
33

44
#include <ekat/ekat_parameter_list.hpp>
55

6-
#include <pybind11/pybind11.h>
7-
#include <pybind11/numpy.h>
8-
#include <pybind11/stl.h>
6+
#include <nanobind/nanobind.h>
7+
#include <nanobind/stl/list.h>
8+
#include <nanobind/stl/string.h>
9+
#include <nanobind/stl/vector.h>
910

1011
#include <functional>
1112

13+
namespace nb = nanobind;
14+
1215
namespace scream {
1316

1417
struct PyParamList {
@@ -19,11 +22,11 @@ struct PyParamList {
1922
: pl_ref(src)
2023
{}
2124

22-
PyParamList(const pybind11::dict& d)
25+
PyParamList(const nb::dict& d)
2326
: PyParamList(d,"")
2427
{}
2528

26-
PyParamList(const pybind11::dict& d, const std::string& name)
29+
PyParamList(const nb::dict& d, const std::string& name)
2730
: pl(name)
2831
, pl_ref(pl)
2932
{
@@ -69,70 +72,70 @@ struct PyParamList {
6972

7073
private:
7174

72-
void parse_dict(const pybind11::dict& d, ekat::ParameterList& p) {
75+
void parse_dict(const nb::dict& d, ekat::ParameterList& p) {
7376
for (auto item : d) {
74-
const std::string key = pybind11::str(item.first);
75-
if (pybind11::isinstance<pybind11::str>(item.second)) {
76-
auto pystr = pybind11::str(item.second);
77-
p.set<std::string>(key,pystr.cast<std::string>());
78-
} else if (pybind11::isinstance<pybind11::bool_>(item.second)) {
79-
auto pyint = pybind11::cast<pybind11::bool_>(item.second);
80-
p.set(key,pyint.cast<bool>());
81-
} else if (pybind11::isinstance<pybind11::int_>(item.second)) {
82-
auto pyint = pybind11::cast<pybind11::int_>(item.second);
83-
p.set(key,pyint.cast<int>());
84-
} else if (pybind11::isinstance<pybind11::float_>(item.second)) {
85-
auto pydouble = pybind11::cast<pybind11::float_>(item.second);
86-
p.set(key,pydouble.cast<double>());
87-
} else if (pybind11::isinstance<pybind11::list>(item.second)) {
88-
auto pylist = pybind11::cast<pybind11::list>(item.second);
77+
auto key = nb::cast<std::string>(item.first);
78+
if (nb::isinstance<nb::str>(item.second)) {
79+
auto pystr = nb::str(item.second);
80+
p.set<std::string>(key,nb::cast<std::string>(pystr));
81+
} else if (nb::isinstance<nb::bool_>(item.second)) {
82+
auto pyint = nb::cast<nb::bool_>(item.second);
83+
p.set(key,nb::cast<bool>(pyint));
84+
} else if (nb::isinstance<nb::int_>(item.second)) {
85+
auto pyint = nb::cast<nb::int_>(item.second);
86+
p.set(key,nb::cast<int>(pyint));
87+
} else if (nb::isinstance<nb::float_>(item.second)) {
88+
auto pydouble = nb::cast<nb::float_>(item.second);
89+
p.set(key,nb::cast<double>(pydouble));
90+
} else if (nb::isinstance<nb::list>(item.second)) {
91+
auto pylist = nb::cast<nb::list>(item.second);
8992
parse_list(pylist,p,key);
90-
} else if (pybind11::isinstance<pybind11::dict>(item.second)) {
91-
auto pydict = pybind11::cast<pybind11::dict>(item.second);
93+
} else if (nb::isinstance<nb::dict>(item.second)) {
94+
auto pydict = nb::cast<nb::dict>(item.second);
9295
parse_dict(pydict,p.sublist(key));
9396
} else {
9497
EKAT_ERROR_MSG ("Unsupported/unrecognized dict entry type.\n");
9598
}
9699
}
97100
}
98101

99-
void parse_list (const pybind11::list& l, ekat::ParameterList&p, const std::string& key) {
100-
EKAT_REQUIRE_MSG (pybind11::len(l)>0,
102+
void parse_list (const nb::list& l, ekat::ParameterList&p, const std::string& key) {
103+
EKAT_REQUIRE_MSG (nb::len(l)>0,
101104
"Error! Cannot deduce type for dictionary list entry '" + key + "'\n");
102105
auto first = l[0];
103-
bool are_ints = pybind11::isinstance<pybind11::int_>(first);
104-
bool are_floats = pybind11::isinstance<pybind11::float_>(first);
105-
bool are_strings = pybind11::isinstance<pybind11::str>(first);
106+
bool are_ints = nb::isinstance<nb::int_>(first);
107+
bool are_floats = nb::isinstance<nb::float_>(first);
108+
bool are_strings = nb::isinstance<nb::str>(first);
106109
if (are_ints) {
107-
parse_list_impl<int,pybind11::int_>(l,p,key);
110+
parse_list_impl<int,nb::int_>(l,p,key);
108111
} else if (are_floats) {
109-
parse_list_impl<double,pybind11::float_>(l,p,key);
112+
parse_list_impl<double,nb::float_>(l,p,key);
110113
} else if (are_strings) {
111-
parse_list_impl<std::string,pybind11::str>(l,p,key);
114+
parse_list_impl<std::string,nb::str>(l,p,key);
112115
} else {
113116
EKAT_ERROR_MSG ("Unrecognized/unsupported list entry type.\n");
114117
}
115118
}
116119

117120
template<typename Txx, typename Tpy>
118-
void parse_list_impl(const pybind11::list& l, ekat::ParameterList& p, const std::string& key) {
121+
void parse_list_impl(const nb::list& l, ekat::ParameterList& p, const std::string& key) {
119122
std::vector<Txx> vals;
120123
for (auto item : l) {
121-
EKAT_REQUIRE_MSG (pybind11::isinstance<Tpy>(item),
124+
EKAT_REQUIRE_MSG (nb::isinstance<Tpy>(item),
122125
"Error! Inconsistent types in list entries.\n");
123-
auto item_py = pybind11::cast<Tpy>(item);
124-
vals.push_back(item_py.template cast<Txx>());
126+
auto item_py = nb::cast<Tpy>(item);
127+
vals.push_back(nb::cast<Txx>(item_py));
125128
}
126129
p.set(key,vals);
127130
}
128131
};
129132

130-
inline void pybind_pyparamlist (pybind11::module& m)
133+
inline void nb_pyparamlist (nb::module_& m)
131134
{
132135
// Param list
133-
pybind11::class_<PyParamList>(m,"ParameterList")
134-
.def(pybind11::init<const pybind11::dict&>())
135-
.def(pybind11::init<const pybind11::dict&,const std::string&>())
136+
nb::class_<PyParamList>(m,"ParameterList")
137+
.def(nb::init<const nb::dict&>())
138+
.def(nb::init<const nb::dict&,const std::string&>())
136139
.def("sublist",&PyParamList::sublist)
137140
.def("print",&PyParamList::print)
138141
.def("set",&PyParamList::set<bool>)

0 commit comments

Comments
 (0)