Skip to content

Commit d3095cf

Browse files
authored
chore(cc): merge get backend codes (#4355)
Fix #4308. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced a new function to dynamically determine the backend framework based on the model file type. - **Improvements** - Enhanced backend detection logic in multiple classes, allowing for more flexible model initialization. - Simplified control flow in the initialization methods of various components. - **Bug Fixes** - Improved error handling for unsupported backends and model formats during initialization processes. <!-- end of auto-generated comment: release notes by coderabbit.ai --> Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
1 parent 6d9d8bb commit d3095cf

File tree

6 files changed

+23
-23
lines changed

6 files changed

+23
-23
lines changed

source/api_cc/include/common.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,13 @@ namespace deepmd {
1515
typedef double ENERGYTYPE;
1616
enum DPBackend { TensorFlow, PyTorch, Paddle, JAX, Unknown };
1717

18+
/**
19+
* @brief Get the backend of the model.
20+
* @param[in] model The model name.
21+
* @return The backend of the model.
22+
**/
23+
DPBackend get_backend(const std::string& model);
24+
1825
struct NeighborListData {
1926
/// Array stores the core region atom's index
2027
std::vector<int> ilist;

source/api_cc/src/DataModifier.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@ void DipoleChargeModifier::init(const std::string& model,
2828
<< std::endl;
2929
return;
3030
}
31-
// TODO: To implement detect_backend
32-
DPBackend backend = deepmd::DPBackend::TensorFlow;
31+
const DPBackend backend = get_backend(model);
3332
if (deepmd::DPBackend::TensorFlow == backend) {
3433
#ifdef BUILD_TENSORFLOW
3534
dcm = std::make_shared<deepmd::DipoleChargeModifierTF>(model, gpu_rank,

source/api_cc/src/DeepPot.cc

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -39,17 +39,7 @@ void DeepPot::init(const std::string& model,
3939
<< std::endl;
4040
return;
4141
}
42-
DPBackend backend;
43-
if (model.length() >= 4 && model.substr(model.length() - 4) == ".pth") {
44-
backend = deepmd::DPBackend::PyTorch;
45-
} else if (model.length() >= 3 && model.substr(model.length() - 3) == ".pb") {
46-
backend = deepmd::DPBackend::TensorFlow;
47-
} else if (model.length() >= 11 &&
48-
model.substr(model.length() - 11) == ".savedmodel") {
49-
backend = deepmd::DPBackend::JAX;
50-
} else {
51-
throw deepmd::deepmd_exception("Unsupported model file format");
52-
}
42+
const DPBackend backend = get_backend(model);
5343
if (deepmd::DPBackend::TensorFlow == backend) {
5444
#ifdef BUILD_TENSORFLOW
5545
dp = std::make_shared<deepmd::DeepPotTF>(model, gpu_rank, file_content);

source/api_cc/src/DeepSpin.cc

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,7 @@ void DeepSpin::init(const std::string& model,
3636
<< std::endl;
3737
return;
3838
}
39-
DPBackend backend;
40-
if (model.length() >= 4 && model.substr(model.length() - 4) == ".pth") {
41-
backend = deepmd::DPBackend::PyTorch;
42-
} else if (model.length() >= 3 && model.substr(model.length() - 3) == ".pb") {
43-
backend = deepmd::DPBackend::TensorFlow;
44-
} else {
45-
throw deepmd::deepmd_exception("Unsupported model file format");
46-
}
39+
const DPBackend backend = get_backend(model);
4740
if (deepmd::DPBackend::TensorFlow == backend) {
4841
#ifdef BUILD_TENSORFLOW
4942
dp = std::make_shared<deepmd::DeepSpinTF>(model, gpu_rank, file_content);

source/api_cc/src/DeepTensor.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@ void DeepTensor::init(const std::string &model,
3030
<< std::endl;
3131
return;
3232
}
33-
// TODO: To implement detect_backend
34-
DPBackend backend = deepmd::DPBackend::TensorFlow;
33+
const DPBackend backend = get_backend(model);
3534
if (deepmd::DPBackend::TensorFlow == backend) {
3635
#ifdef BUILD_TENSORFLOW
3736
dt = std::make_shared<deepmd::DeepTensorTF>(model, gpu_rank, name_scope_);

source/api_cc/src/common.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1399,3 +1399,15 @@ void deepmd::print_summary(const std::string& pre) {
13991399
<< "set tf inter_op_parallelism_threads: " << num_inter_nthreads
14001400
<< std::endl;
14011401
}
1402+
1403+
deepmd::DPBackend deepmd::get_backend(const std::string& model) {
1404+
if (model.length() >= 4 && model.substr(model.length() - 4) == ".pth") {
1405+
return deepmd::DPBackend::PyTorch;
1406+
} else if (model.length() >= 3 && model.substr(model.length() - 3) == ".pb") {
1407+
return deepmd::DPBackend::TensorFlow;
1408+
} else if (model.length() >= 11 &&
1409+
model.substr(model.length() - 11) == ".savedmodel") {
1410+
return deepmd::DPBackend::JAX;
1411+
}
1412+
throw deepmd::deepmd_exception("Unsupported model file format");
1413+
}

0 commit comments

Comments
 (0)