diff --git a/sherpa-onnx/c-api/cxx-api.cc b/sherpa-onnx/c-api/cxx-api.cc index f73433ba02..58e603e90b 100644 --- a/sherpa-onnx/c-api/cxx-api.cc +++ b/sherpa-onnx/c-api/cxx-api.cc @@ -821,4 +821,33 @@ bool FileExists(const std::string &filename) { return SherpaOnnxFileExists(filename.c_str()); } +// ============================================================ +// For Offline Punctuation +// ============================================================ +OfflinePunctuation OfflinePunctuation::Create(const OfflinePunctuationConfig &config) { + struct SherpaOnnxOfflinePunctuationConfig c; + memset(&c, 0, sizeof(c)); + c.model.ct_transformer = config.model.ct_transformer.c_str(); + c.model.num_threads = config.model.num_threads; + c.model.debug = config.model.debug; + c.model.provider = config.model.provider.c_str(); + + const SherpaOnnxOfflinePunctuation *punct = SherpaOnnxCreateOfflinePunctuation(&c); + return OfflinePunctuation(punct); +} + +OfflinePunctuation::OfflinePunctuation(const SherpaOnnxOfflinePunctuation *p) + : MoveOnly(p) {} + +void OfflinePunctuation::Destroy(const SherpaOnnxOfflinePunctuation *p) const { + SherpaOnnxDestroyOfflinePunctuation(p_); +} + +std::string OfflinePunctuation::AddPunctuation(const std::string &text) const { + const char *result = SherpaOfflinePunctuationAddPunct(p_, text.c_str()); + std::string ans(result); + SherpaOfflinePunctuationFreeText(result); + return ans; +} + } // namespace sherpa_onnx::cxx diff --git a/sherpa-onnx/c-api/cxx-api.h b/sherpa-onnx/c-api/cxx-api.h index 2fc2c89f88..c285476aa3 100644 --- a/sherpa-onnx/c-api/cxx-api.h +++ b/sherpa-onnx/c-api/cxx-api.h @@ -673,6 +673,33 @@ SHERPA_ONNX_API std::string GetGitSha1(); SHERPA_ONNX_API std::string GetGitDate(); SHERPA_ONNX_API bool FileExists(const std::string &filename); +// ============================================================================ +// Offline Punctuation +// ============================================================================ + +struct OfflinePunctuationModelConfig { + std::string ct_transformer; + int32_t num_threads = 1; + bool debug = false; + std::string provider = "cpu"; +}; + +struct OfflinePunctuationConfig { + OfflinePunctuationModelConfig model; +}; + +class SHERPA_ONNX_API OfflinePunctuation + : public MoveOnly { + public: + static OfflinePunctuation Create(const OfflinePunctuationConfig &config); + + void Destroy(const SherpaOnnxOfflinePunctuation *p) const; + + std::string AddPunctuation(const std::string &text) const; + private: + explicit OfflinePunctuation(const SherpaOnnxOfflinePunctuation *p); +}; + } // namespace sherpa_onnx::cxx #endif // SHERPA_ONNX_C_API_CXX_API_H_