Skip to content

Commit c393392

Browse files
authored
Merge pull request #1864 from amcadmus/master
Merge devel into master
2 parents 5a32c49 + ee3b01c commit c393392

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+513
-2829
lines changed

.github/workflows/test_python.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,5 +68,6 @@ jobs:
6868
CC: gcc-${{ matrix.gcc }}
6969
CXX: g++-${{ matrix.gcc }}
7070
TENSORFLOW_VERSION: ${{ matrix.tf }}
71+
SETUPTOOLS_ENABLE_FEATURES: "legacy-editable"
7172
- run: dp --version
7273
- run: pytest --cov=deepmd source/tests && codecov

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ _build
2828
_templates
2929
API_CC
3030
doc/api_py/
31+
doc/api_core/
3132
dp/
3233
dp_test/
3334
dp_test_cc/

deepmd/env.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -224,13 +224,15 @@ def get_module(module_name: str) -> "ModuleType":
224224
"""
225225
if platform.system() == "Windows":
226226
ext = ".dll"
227+
prefix = ""
227228
#elif platform.system() == "Darwin":
228229
# ext = ".dylib"
229230
else:
230231
ext = ".so"
232+
prefix = "lib"
231233

232234
module_file = (
233-
(Path(__file__).parent / SHARED_LIB_MODULE / module_name)
235+
(Path(__file__).parent / SHARED_LIB_MODULE / (prefix + module_name))
234236
.with_suffix(ext)
235237
.resolve()
236238
)
@@ -324,8 +326,8 @@ def _get_package_constants(
324326
TF_VERSION = GLOBAL_CONFIG["tf_version"]
325327
TF_CXX11_ABI_FLAG = int(GLOBAL_CONFIG["tf_cxx11_abi_flag"])
326328

327-
op_module = get_module("libop_abi")
328-
op_grads_module = get_module("libop_grads")
329+
op_module = get_module("op_abi")
330+
op_grads_module = get_module("op_grads")
329331

330332
# FLOAT_PREC
331333
dp_float_prec = os.environ.get("DP_INTERFACE_PREC", "high").lower()

deepmd/infer/deep_eval.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,7 @@ def model_type(self) -> str:
7070
:type:str
7171
"""
7272
t_mt = self._get_tensor("model_attr/model_type:0")
73-
sess = tf.Session(graph=self.graph, config=default_tf_session_config)
74-
[mt] = run_sess(sess, [t_mt], feed_dict={})
73+
[mt] = run_sess(self.sess, [t_mt], feed_dict={})
7574
return mt.decode("utf-8")
7675

7776
@property
@@ -90,10 +89,16 @@ def model_version(self) -> str:
9089
# For deepmd-kit version 0.x - 1.x, set model version to 0.0
9190
return "0.0"
9291
else:
93-
sess = tf.Session(graph=self.graph, config=default_tf_session_config)
94-
[mt] = run_sess(sess, [t_mt], feed_dict={})
92+
[mt] = run_sess(self.sess, [t_mt], feed_dict={})
9593
return mt.decode("utf-8")
9694

95+
@property
96+
@lru_cache(maxsize=None)
97+
def sess(self) -> tf.Session:
98+
"""Get TF session."""
99+
# start a tf session associated to the graph
100+
return tf.Session(graph=self.graph, config=default_tf_session_config)
101+
97102
def _graph_compatable(
98103
self
99104
) -> bool :

deepmd/infer/deep_pot.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,6 @@ def __init__(
125125
for attr_name, tensor_name in self.tensors.items():
126126
self._get_tensor(tensor_name, attr_name)
127127

128-
# start a tf session associated to the graph
129-
self.sess = tf.Session(graph=self.graph, config=default_tf_session_config)
130128
self._run_default_sess()
131129
self.tmap = self.tmap.decode('UTF-8').split()
132130

deepmd/infer/deep_tensor.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,6 @@ def __init__(
8181
self.tensors.update(optional_tensors)
8282
self._support_gfv = True
8383

84-
# start a tf session associated to the graph
85-
self.sess = tf.Session(graph=self.graph, config=default_tf_session_config)
8684
self._run_default_sess()
8785
self.tmap = self.tmap.decode('UTF-8').split()
8886

deepmd/train/run_options.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
f"source commit: {GLOBAL_CONFIG['git_hash']}",
4848
f"source commit at: {GLOBAL_CONFIG['git_date']}",
4949
f"build float prec: {global_float_prec}",
50+
f"build variant: {GLOBAL_CONFIG['dp_variant']}",
5051
f"build with tf inc: {GLOBAL_CONFIG['tf_include_dir']}",
5152
f"build with tf lib: {GLOBAL_CONFIG['tf_libs'].replace(';', _sep)}" # noqa
5253
)

deepmd/train/trainer.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import logging
44
import os
55
import glob
6+
import platform
67
import time
78
import shutil
89
import google.protobuf.message
@@ -574,7 +575,11 @@ def save_checkpoint(self, cur_batch: int):
574575
os.remove(new_ff)
575576
except OSError:
576577
pass
577-
os.symlink(ori_ff, new_ff)
578+
if platform.system() != 'Windows':
579+
# by default one does not have access to create symlink on Windows
580+
os.symlink(ori_ff, new_ff)
581+
else:
582+
shutil.copyfile(ori_ff, new_ff)
578583
log.info("saved checkpoint %s" % self.save_ckpt)
579584

580585
def get_feed_dict(self, batch, is_training):

deepmd/utils/tabulate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def __init__(self,
7777
self.activation_fn = activation_fn
7878

7979
self.graph, self.graph_def = load_graph_def(self.model_file)
80-
self.sess = tf.Session(graph = self.graph)
80+
#self.sess = tf.Session(graph = self.graph)
8181

8282
self.sub_graph, self.sub_graph_def = self._load_sub_graph()
8383
self.sub_sess = tf.Session(graph = self.sub_graph)

0 commit comments

Comments
 (0)