diff --git a/CMakeLists.txt b/CMakeLists.txt index 41f4783ca2..30e9ba2f7e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -10,6 +10,7 @@ project(YDB-CPP-SDK VERSION ${YDB_SDK_VERSION} LANGUAGES C CXX ASM) option(YDB_SDK_INSTALL "Install YDB C++ SDK" Off) option(YDB_SDK_TESTS "Build YDB C++ SDK tests" Off) option(YDB_SDK_EXAMPLES "Build YDB C++ SDK examples" On) +option(YDB_SDK_ODBC "Build YDB ODBC driver" On) set(YDB_SDK_GOOGLE_COMMON_PROTOS_TARGET "" CACHE STRING "Name of cmake target preparing google common proto library") option(YDB_SDK_USE_RAPID_JSON "Search for rapid json library in system" ON) @@ -61,6 +62,10 @@ add_subdirectory(util) #_ydb_sdk_validate_public_headers() +if (YDB_SDK_ODBC) + add_subdirectory(odbc) +endif() + if (YDB_SDK_EXAMPLES) add_subdirectory(examples) endif() diff --git a/cmake/common.cmake b/cmake/common.cmake index 546ce4e81f..89ebb5eaca 100644 --- a/cmake/common.cmake +++ b/cmake/common.cmake @@ -115,7 +115,7 @@ function(generate_enum_serilization Tgt Input) endfunction() function(add_global_library_for TgtName MainName) - add_library(${TgtName} STATIC ${ARGN}) + _ydb_sdk_add_library(${TgtName} STATIC ${ARGN}) if(APPLE) target_link_options(${MainName} INTERFACE "SHELL:-Wl,-force_load,$${TgtName}>") else() @@ -182,7 +182,7 @@ endfunction() function(_ydb_sdk_add_library Tgt) cmake_parse_arguments(ARG - "INTERFACE" "" "" + "INTERFACE;OBJECT;SHARED" "" "" ${ARGN} ) @@ -192,6 +192,12 @@ function(_ydb_sdk_add_library Tgt) set(libraryMode "INTERFACE") set(includeMode "INTERFACE") endif() + if (ARG_OBJECT) + set(libraryMode "OBJECT") + endif() + if (ARG_SHARED) + set(libraryMode "SHARED") + endif() add_library(${Tgt} ${libraryMode}) target_include_directories(${Tgt} ${includeMode} $ @@ -201,6 +207,7 @@ function(_ydb_sdk_add_library Tgt) target_compile_definitions(${Tgt} ${includeMode} YDB_SDK_USE_STD_STRING ) + set_property(TARGET ${Tgt} PROPERTY POSITION_INDEPENDENT_CODE ON) endfunction() function(_ydb_sdk_validate_public_headers) @@ -255,4 +262,3 @@ function(_ydb_sdk_validate_public_headers) ) target_include_directories(validate_public_interface PUBLIC ${YDB_SDK_BINARY_DIR}/__validate_headers_dir/include) endfunction() - diff --git a/cmake/external_libs.cmake b/cmake/external_libs.cmake index 22d0603e77..a252c588ae 100644 --- a/cmake/external_libs.cmake +++ b/cmake/external_libs.cmake @@ -14,6 +14,10 @@ find_package(Brotli 1.1.0 REQUIRED) find_package(jwt-cpp REQUIRED) find_package(double-conversion REQUIRED) +if (YDB_SDK_ODBC) + find_package(ODBC REQUIRED) +endif() + # RapidJSON if (YDB_SDK_USE_RAPID_JSON) find_package(RapidJSON REQUIRED) diff --git a/cmake/testing.cmake b/cmake/testing.cmake index d2a2050e23..43a0ba126c 100644 --- a/cmake/testing.cmake +++ b/cmake/testing.cmake @@ -83,3 +83,35 @@ function(add_ydb_test) vcs_info(${YDB_TEST_NAME}) endfunction() + +if (YDB_SDK_ODBC) + function(add_odbc_test) + set(opts "") + set(oneval_args NAME WORKING_DIRECTORY OUTPUT_DIRECTORY) + set(multival_args SOURCES LINK_LIBRARIES LABELS) + cmake_parse_arguments(ODBC_TEST + "${opts}" + "${oneval_args}" + "${multival_args}" + ${ARGN} + ) + + add_ydb_test(GTEST + NAME ${ODBC_TEST_NAME} + SOURCES ${ODBC_TEST_SOURCES} + LINK_LIBRARIES + ${ODBC_TEST_LINK_LIBRARIES} + ODBC::ODBC + LABELS + integration + ${ODBC_TEST_LABELS} + ) + + target_compile_definitions(${ODBC_TEST_NAME} + PRIVATE + ODBC_DRIVER_PATH="$" + ) + + add_dependencies(${ODBC_TEST_NAME} ydb-odbc) + endfunction() +endif() diff --git a/odbc/CMakeLists.txt b/odbc/CMakeLists.txt new file mode 100644 index 0000000000..f814f00313 --- /dev/null +++ b/odbc/CMakeLists.txt @@ -0,0 +1,54 @@ +add_library(ydb-odbc SHARED + src/utils/cursor.cpp + src/utils/types.cpp + src/utils/util.cpp + src/utils/convert.cpp + src/odbc_driver.cpp + src/connection.cpp + src/statement.cpp + src/environment.cpp +) + +target_include_directories(ydb-odbc + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/include + ${CMAKE_CURRENT_SOURCE_DIR}/src + ${ODBC_INCLUDE_DIRS} +) + +target_link_libraries(ydb-odbc + PRIVATE + YDB-CPP-SDK::Query + YDB-CPP-SDK::Table + YDB-CPP-SDK::Scheme + YDB-CPP-SDK::Driver + ODBC::ODBC +) + +set_target_properties(ydb-odbc PROPERTIES + POSITION_INDEPENDENT_CODE ON +) + +install(TARGETS ydb-odbc + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} +) + +install(DIRECTORY include/ + DESTINATION include/ydb-odbc +) + +add_subdirectory(examples) +add_subdirectory(tests) + +include(GNUInstallDirs) + +install(FILES + odbcinst.ini + DESTINATION ${CMAKE_INSTALL_SYSCONFDIR}/odbcinst.d + RENAME ydb-odbc.ini +) + +install(FILES + odbc.ini + DESTINATION ${CMAKE_INSTALL_SYSCONFDIR} +) diff --git a/odbc/README.md b/odbc/README.md new file mode 100644 index 0000000000..c73f9b8704 --- /dev/null +++ b/odbc/README.md @@ -0,0 +1,80 @@ +# YDB ODBC Driver + +ODBC driver for YDB. + +## Requirements + +- CMake 3.10 or higher +- C/C++ compiler with C11 and C++20 support +- YDB C++ SDK +- unixODBC (for Linux/macOS) + +## Build + +```bash +cmake -DYDB_SDK_ODBC=1 --preset release-clang +cmake --build --preset default +``` + +## Configuration + +1. Make sure the driver is registered: +```bash +odbcinst -q -d +``` + +2. Check available data sources: +```bash +odbcinst -q -s +``` + +3. Edit `/etc/odbc.ini` to configure the connection: +```ini +[YDB] +Driver=YDB +Description=YDB Database Connection +Server=your-server:port +Database=/path/to/database +``` + +## Usage + +Example of connecting via isql: +```bash +isql -v YDB +``` + +Example usage in C: +```c +SQLHENV env; +SQLHDBC dbc; +SQLHSTMT stmt; + +// Initialize environment +SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env); +SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0); + +// Connect +SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc); +SQLConnect(dbc, (SQLCHAR*)"YDB", SQL_NTS, + (SQLCHAR*)"", SQL_NTS, + (SQLCHAR*)"", SQL_NTS); + +// Execute query +SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt); +SQLExecDirect(stmt, (SQLCHAR*)"SELECT * FROM mytable", SQL_NTS); + +// Cleanup +SQLFreeHandle(SQL_HANDLE_STMT, stmt); +SQLDisconnect(dbc); +SQLFreeHandle(SQL_HANDLE_DBC, dbc); +SQLFreeHandle(SQL_HANDLE_ENV, env); +``` + +## Parameters + +Use names $p1, $p2, ... for parameter names + +## License + +Apache License 2.0 diff --git a/odbc/examples/CMakeLists.txt b/odbc/examples/CMakeLists.txt new file mode 100644 index 0000000000..88b1f27cc6 --- /dev/null +++ b/odbc/examples/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(basic) +add_subdirectory(scheme) diff --git a/odbc/examples/basic/CMakeLists.txt b/odbc/examples/basic/CMakeLists.txt new file mode 100644 index 0000000000..b99d1175f4 --- /dev/null +++ b/odbc/examples/basic/CMakeLists.txt @@ -0,0 +1,14 @@ +add_executable(odbc_basic + main.cpp +) + +target_link_libraries(odbc_basic + PRIVATE + ODBC::ODBC +) +target_compile_definitions(odbc_basic + PRIVATE + ODBC_DRIVER_PATH="$" +) + +add_dependencies(odbc_basic ydb-odbc) diff --git a/odbc/examples/basic/main.cpp b/odbc/examples/basic/main.cpp new file mode 100644 index 0000000000..8084e32f3d --- /dev/null +++ b/odbc/examples/basic/main.cpp @@ -0,0 +1,132 @@ +#include +#include + +#include + +void PrintOdbcError(SQLSMALLINT handleType, SQLHANDLE handle) { + SQLCHAR sqlState[6] = {0}; + SQLINTEGER nativeError = 0; + SQLCHAR message[256] = {0}; + SQLSMALLINT textLength = 0; + SQLGetDiagRec(handleType, handle, 1, sqlState, &nativeError, message, sizeof(message), &textLength); + std::cerr << "ODBC error: [" << sqlState << "] " << message << std::endl; +} + +int main() { + SQLHENV henv = nullptr; + SQLHDBC hdbc = nullptr; + SQLHSTMT hstmt = nullptr; + SQLRETURN ret; + + std::cout << "1. Allocating environment handle" << std::endl; + ret = SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &henv); + if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { + std::cerr << "Error allocating environment handle" << std::endl; + return 1; + } + SQLSetEnvAttr(henv, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0); + + std::cout << "2. Allocating connection handle" << std::endl; + ret = SQLAllocHandle(SQL_HANDLE_DBC, henv, &hdbc); + if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { + std::cerr << "Error allocating connection handle" << std::endl; + SQLFreeHandle(SQL_HANDLE_ENV, henv); + return 1; + } + + std::cout << "3. Building connection string" << std::endl; + std::string connStr = "Driver=" ODBC_DRIVER_PATH ";Endpoint=localhost:2136;Database=/local;"; + SQLCHAR outConnStr[1024] = {0}; + SQLSMALLINT outConnStrLen = 0; + + std::cout << "4. Connecting with SQLDriverConnect" << std::endl; + ret = SQLDriverConnect(hdbc, NULL, (SQLCHAR*)connStr.c_str(), SQL_NTS, + outConnStr, sizeof(outConnStr), &outConnStrLen, SQL_DRIVER_COMPLETE); + if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { + std::cerr << "Error connecting with SQLDriverConnect" << std::endl; + PrintOdbcError(SQL_HANDLE_DBC, hdbc); + SQLFreeHandle(SQL_HANDLE_DBC, hdbc); + SQLFreeHandle(SQL_HANDLE_ENV, henv); + return 1; + } + + std::cout << "5. Allocating statement handle" << std::endl; + ret = SQLAllocHandle(SQL_HANDLE_STMT, hdbc, &hstmt); + if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { + std::cerr << "Error allocating statement handle" << std::endl; + SQLDisconnect(hdbc); + SQLFreeHandle(SQL_HANDLE_DBC, hdbc); + SQLFreeHandle(SQL_HANDLE_ENV, henv); + return 1; + } + + std::cout << "6. Executing query" << std::endl; + SQLCHAR query[] = R"( + DECLARE $p1 AS Int64?; + SELECT id, data from test_table WHERE id == $p1; + )"; + + int64_t paramValue = 1; + SQLLEN paramInd = 0; + ret = SQLBindParameter(hstmt, 1, SQL_PARAM_INPUT, SQL_C_SBIGINT, SQL_BIGINT, 0, 0, ¶mValue, 0, ¶mInd); + if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { + std::cerr << "Error binding parameter" << std::endl; + PrintOdbcError(SQL_HANDLE_STMT, hstmt); + SQLFreeHandle(SQL_HANDLE_STMT, hstmt); + SQLDisconnect(hdbc); + SQLFreeHandle(SQL_HANDLE_DBC, hdbc); + SQLFreeHandle(SQL_HANDLE_ENV, henv); + return 1; + } + + ret = SQLExecDirect(hstmt, query, SQL_NTS); + if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { + std::cerr << "Error executing query" << std::endl; + PrintOdbcError(SQL_HANDLE_STMT, hstmt); + SQLFreeHandle(SQL_HANDLE_STMT, hstmt); + SQLDisconnect(hdbc); + SQLFreeHandle(SQL_HANDLE_DBC, hdbc); + SQLFreeHandle(SQL_HANDLE_ENV, henv); + return 1; + } + + std::cout << "7. Fetching result" << std::endl; + + SQLLEN ind = 0; + int value1 = 0; + if (SQLBindCol(hstmt, 1, SQL_C_SLONG, &value1, 0, &ind) != SQL_SUCCESS) { + std::cerr << "Error binding column 1" << std::endl; + PrintOdbcError(SQL_HANDLE_STMT, hstmt); + return 1; + } + + SQLCHAR value2[1024] = {0}; + if (SQLBindCol(hstmt, 2, SQL_C_CHAR, &value2, 1024, &ind) != SQL_SUCCESS) { + std::cerr << "Error binding column 2" << std::endl; + PrintOdbcError(SQL_HANDLE_STMT, hstmt); + return 1; + } + + while ((ret = SQLFetch(hstmt)) == SQL_SUCCESS || ret == SQL_SUCCESS_WITH_INFO) { + if (ret != SQL_SUCCESS) { + std::cerr << "Error fetching result" << std::endl; + PrintOdbcError(SQL_HANDLE_STMT, hstmt); + return 1; + } + + std::cout << "Result column 1: " << value1 << std::endl; + std::cout << "Result column 2: " << value2 << std::endl; + + std::cout << "--------------------------------" << std::endl; + } + + std::cout << "8. Cleaning up" << std::endl; + + SQLCloseCursor(hstmt); + SQLFreeHandle(SQL_HANDLE_STMT, hstmt); + SQLDisconnect(hdbc); + SQLFreeHandle(SQL_HANDLE_DBC, hdbc); + SQLFreeHandle(SQL_HANDLE_ENV, henv); + + return 0; +} diff --git a/odbc/examples/scheme/CMakeLists.txt b/odbc/examples/scheme/CMakeLists.txt new file mode 100644 index 0000000000..ffab881aed --- /dev/null +++ b/odbc/examples/scheme/CMakeLists.txt @@ -0,0 +1,14 @@ +add_executable(odbc_scheme + main.cpp +) + +target_link_libraries(odbc_scheme + PRIVATE + ODBC::ODBC +) +target_compile_definitions(odbc_scheme + PRIVATE + ODBC_DRIVER_PATH="$" +) + +add_dependencies(odbc_scheme ydb-odbc) diff --git a/odbc/examples/scheme/main.cpp b/odbc/examples/scheme/main.cpp new file mode 100644 index 0000000000..3ae2cd6fe4 --- /dev/null +++ b/odbc/examples/scheme/main.cpp @@ -0,0 +1,116 @@ +#include +#include + +#include + +void PrintOdbcError(SQLSMALLINT handleType, SQLHANDLE handle) { + SQLCHAR sqlState[6] = {0}; + SQLINTEGER nativeError = 0; + SQLCHAR message[256] = {0}; + SQLSMALLINT textLength = 0; + SQLGetDiagRec(handleType, handle, 1, sqlState, &nativeError, message, sizeof(message), &textLength); + std::cerr << "ODBC error: [" << sqlState << "] " << message << std::endl; +} + +int main() { + SQLHENV henv = nullptr; + SQLHDBC hdbc = nullptr; + SQLHSTMT hstmt = nullptr; + SQLRETURN ret; + + std::cout << "1. Allocating environment handle" << std::endl; + ret = SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &henv); + if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { + std::cerr << "Error allocating environment handle" << std::endl; + return 1; + } + SQLSetEnvAttr(henv, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0); + + std::cout << "2. Allocating connection handle" << std::endl; + ret = SQLAllocHandle(SQL_HANDLE_DBC, henv, &hdbc); + if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { + std::cerr << "Error allocating connection handle" << std::endl; + SQLFreeHandle(SQL_HANDLE_ENV, henv); + return 1; + } + + std::cout << "3. Building connection string" << std::endl; + std::string connStr = "Driver=" ODBC_DRIVER_PATH ";Endpoint=localhost:2136;Database=/local;"; + SQLCHAR outConnStr[1024] = {0}; + SQLSMALLINT outConnStrLen = 0; + + std::cout << "4. Connecting with SQLDriverConnect" << std::endl; + ret = SQLDriverConnect(hdbc, NULL, (SQLCHAR*)connStr.c_str(), SQL_NTS, + outConnStr, sizeof(outConnStr), &outConnStrLen, SQL_DRIVER_COMPLETE); + if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { + std::cerr << "Error connecting with SQLDriverConnect" << std::endl; + PrintOdbcError(SQL_HANDLE_DBC, hdbc); + SQLFreeHandle(SQL_HANDLE_DBC, hdbc); + SQLFreeHandle(SQL_HANDLE_ENV, henv); + return 1; + } + + std::cout << "5. Allocating statement handle" << std::endl; + ret = SQLAllocHandle(SQL_HANDLE_STMT, hdbc, &hstmt); + if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { + std::cerr << "Error allocating statement handle" << std::endl; + SQLDisconnect(hdbc); + SQLFreeHandle(SQL_HANDLE_DBC, hdbc); + SQLFreeHandle(SQL_HANDLE_ENV, henv); + return 1; + } + + std::cout << "6. Getting tables" << std::endl; + + SQLCHAR pattern[] = "/local"; + SQLCHAR tableType[] = "TABLE"; + + ret = SQLTables(hstmt, NULL, 0, NULL, 0, pattern, SQL_NTS, tableType, SQL_NTS); + if (ret != SQL_SUCCESS && ret != SQL_SUCCESS_WITH_INFO) { + std::cerr << "Error executing query" << std::endl; + PrintOdbcError(SQL_HANDLE_STMT, hstmt); + SQLFreeHandle(SQL_HANDLE_STMT, hstmt); + SQLDisconnect(hdbc); + SQLFreeHandle(SQL_HANDLE_DBC, hdbc); + SQLFreeHandle(SQL_HANDLE_ENV, henv); + return 1; + } + + std::cout << "7. Fetching result" << std::endl; + + SQLLEN ind = 0; + SQLCHAR value1[1024] = {0}; + if (SQLBindCol(hstmt, 3, SQL_C_CHAR, &value1, 1024, &ind) != SQL_SUCCESS) { + std::cerr << "Error binding column 1" << std::endl; + PrintOdbcError(SQL_HANDLE_STMT, hstmt); + return 1; + } + + SQLCHAR value2[1024] = {0}; + if (SQLBindCol(hstmt, 4, SQL_C_CHAR, &value2, 1024, &ind) != SQL_SUCCESS) { + std::cerr << "Error binding column 2" << std::endl; + PrintOdbcError(SQL_HANDLE_STMT, hstmt); + return 1; + } + + while ((ret = SQLFetch(hstmt)) == SQL_SUCCESS || ret == SQL_SUCCESS_WITH_INFO) { + if (ret != SQL_SUCCESS) { + std::cerr << "Error fetching result" << std::endl; + PrintOdbcError(SQL_HANDLE_STMT, hstmt); + return 1; + } + + std::cout << "Table name: " << value1 << std::endl; + std::cout << "Table type: " << value2 << std::endl; + + std::cout << "--------------------------------" << std::endl; + } + + std::cout << "8. Cleaning up" << std::endl; + SQLFreeHandle(SQL_HANDLE_STMT, hstmt); + SQLDisconnect(hdbc); + SQLFreeHandle(SQL_HANDLE_DBC, hdbc); + SQLFreeHandle(SQL_HANDLE_ENV, henv); + + return 0; +} diff --git a/odbc/odbc.ini b/odbc/odbc.ini new file mode 100644 index 0000000000..6335b3ee38 --- /dev/null +++ b/odbc/odbc.ini @@ -0,0 +1,9 @@ +[ODBC Data Sources] +YDB=YDB ODBC Driver + +[YDB] +Driver=YDB +Description=YDB Database Connection +Server=grpc://localhost:2136 +Database=local +AuthMode=none \ No newline at end of file diff --git a/odbc/odbcinst.ini b/odbc/odbcinst.ini new file mode 100644 index 0000000000..fd0b3f2765 --- /dev/null +++ b/odbc/odbcinst.ini @@ -0,0 +1,4 @@ +[YDB] +Description=YDB ODBC Driver +Driver=/home/brgayazov/ydbwork/ydb-cpp-sdk/build/odbc/libydb-odbc.so +Setup=/home/brgayazov/ydbwork/ydb-cpp-sdk/build/odbc/libydb-odbc.so \ No newline at end of file diff --git a/odbc/src/connection.cpp b/odbc/src/connection.cpp new file mode 100644 index 0000000000..eba32c74ef --- /dev/null +++ b/odbc/src/connection.cpp @@ -0,0 +1,175 @@ +#include "connection.h" +#include "statement.h" + +#include +#include +#include + +#include +#include + +#include + +namespace NYdb { +namespace NOdbc { + +SQLRETURN TConnection::DriverConnect(const std::string& connectionString) { + std::map params; + size_t pos = 0; + while (pos < connectionString.size()) { + size_t eq = connectionString.find('=', pos); + if (eq == std::string::npos) { + break; + } + + size_t sc = connectionString.find(';', eq); + std::string key = connectionString.substr(pos, eq-pos); + std::string val = connectionString.substr(eq+1, (sc == std::string::npos ? std::string::npos : sc-eq-1)); + params[key] = val; + if (sc == std::string::npos) { + break; + } + pos = sc+1; + } + Endpoint_ = params["Endpoint"]; + Database_ = params["Database"]; + + if (Endpoint_.empty() || Database_.empty()) { + AddError("08001", 0, "Missing Endpoint or Database in connection string"); + return SQL_ERROR; + } + + YdbDriver_ = std::make_unique(NYdb::TDriverConfig() + .SetEndpoint(Endpoint_) + .SetDatabase(Database_)); + + YdbClient_ = std::make_unique(*YdbDriver_); + YdbSchemeClient_ = std::make_unique(*YdbDriver_); + YdbTableClient_ = std::make_unique(*YdbDriver_); + + return SQL_SUCCESS; +} + +SQLRETURN TConnection::Connect(const std::string& serverName, + const std::string& userName, + const std::string& auth) { + + char endpoint[256] = {0}; + char database[256] = {0}; + + //SQLGetPrivateProfileString(serverName.c_str(), "Endpoint", "", endpoint, sizeof(endpoint), nullptr); + //SQLGetPrivateProfileString(serverName.c_str(), "Database", "", database, sizeof(database), nullptr); + + Endpoint_ = endpoint; + Database_ = database; + + if (Endpoint_.empty() || Database_.empty()) { + AddError("08001", 0, "Missing Endpoint or Database in DSN"); + return SQL_ERROR; + } + + YdbDriver_ = std::make_unique(NYdb::TDriverConfig() + .SetEndpoint(Endpoint_) + .SetDatabase(Database_)); + + YdbClient_ = std::make_unique(*YdbDriver_); + YdbSchemeClient_ = std::make_unique(*YdbDriver_); + YdbTableClient_ = std::make_unique(*YdbDriver_); + + return SQL_SUCCESS; +} + +SQLRETURN TConnection::Disconnect() { + YdbClient_.reset(); + YdbDriver_.reset(); + return SQL_SUCCESS; +} + +SQLRETURN TConnection::GetDiagRec(SQLSMALLINT recNumber, SQLCHAR* sqlState, SQLINTEGER* nativeError, + SQLCHAR* messageText, SQLSMALLINT bufferLength, SQLSMALLINT* textLength) { + if (recNumber < 1 || recNumber > (SQLSMALLINT)Errors_.size()) { + return SQL_NO_DATA; + } + + const auto& err = Errors_[recNumber-1]; + if (sqlState) { + strncpy((char*)sqlState, err.SqlState.c_str(), 6); + } + + if (nativeError) { + *nativeError = err.NativeError; + } + + if (messageText && bufferLength > 0) { + strncpy((char*)messageText, err.Message.c_str(), bufferLength); + if (textLength) { + *textLength = (SQLSMALLINT)std::min((int)err.Message.size(), (int)bufferLength); + } + } + return SQL_SUCCESS; +} + +std::unique_ptr TConnection::CreateStatement() { + return std::make_unique(this); +} + +void TConnection::RemoveStatement(TStatement* stmt) { + Statements_.erase(std::remove_if(Statements_.begin(), Statements_.end(), + [stmt](const std::unique_ptr& s) { return s.get() == stmt; }), Statements_.end()); +} + +void TConnection::AddError(const std::string& sqlState, SQLINTEGER nativeError, const std::string& message) { + Errors_.push_back({sqlState, nativeError, message}); +} + +void TConnection::ClearErrors() { + Errors_.clear(); +} + +SQLRETURN TConnection::SetAutocommit(bool value) { + Autocommit_ = value; + if (Autocommit_ && Tx_) { + auto status = Tx_->Commit().ExtractValueSync(); + if (!status.IsSuccess()) { + AddError("08001", 0, "Failed to commit transaction"); + return SQL_ERROR; + } + Tx_.reset(); + } + return SQL_SUCCESS; +} + +bool TConnection::GetAutocommit() const { + return Autocommit_; +} + +const std::optional& TConnection::GetTx() { + return Tx_; +} + +void TConnection::SetTx(const NQuery::TTransaction& tx) { + Tx_ = tx; +} + +SQLRETURN TConnection::CommitTx() { + auto status = Tx_->Commit().ExtractValueSync(); + if (!status.IsSuccess()) { + AddError("08001", 0, "Failed to commit transaction"); + return SQL_ERROR; + } + Tx_.reset(); + return SQL_SUCCESS; +} + +SQLRETURN TConnection::RollbackTx() { + auto status = Tx_->Rollback().ExtractValueSync(); + if (!status.IsSuccess()) { + AddError("08001", 0, "Failed to rollback transaction"); + return SQL_ERROR; + } + Tx_.reset(); + return SQL_SUCCESS; +} + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/connection.h b/odbc/src/connection.h new file mode 100644 index 0000000000..fad8152777 --- /dev/null +++ b/odbc/src/connection.h @@ -0,0 +1,69 @@ +#pragma once + +#include "environment.h" + +#include +#include +#include +#include + +#include +#include + +#include +#include +#include + +namespace NYdb { +namespace NOdbc { + +class TStatement; + +class TConnection { +private: + std::unique_ptr YdbDriver_; + std::unique_ptr YdbClient_; + std::unique_ptr YdbTableClient_; + std::unique_ptr YdbSchemeClient_; + std::optional Tx_; + + TErrorList Errors_; + std::vector> Statements_; + std::string Endpoint_; + std::string Database_; + std::string AuthToken_; + + bool Autocommit_ = true; + +public: + SQLRETURN Connect(const std::string& serverName, + const std::string& userName, + const std::string& auth); + + SQLRETURN DriverConnect(const std::string& connectionString); + SQLRETURN Disconnect(); + SQLRETURN GetDiagRec(SQLSMALLINT recNumber, SQLCHAR* sqlState, SQLINTEGER* nativeError, + SQLCHAR* messageText, SQLSMALLINT bufferLength, SQLSMALLINT* textLength); + + std::unique_ptr CreateStatement(); + void RemoveStatement(TStatement* stmt); + + NYdb::NQuery::TQueryClient* GetClient() { return YdbClient_.get(); } + NYdb::NTable::TTableClient* GetTableClient() { return YdbTableClient_.get(); } + NScheme::TSchemeClient* GetSchemeClient() { return YdbSchemeClient_.get(); } + + void AddError(const std::string& sqlState, SQLINTEGER nativeError, const std::string& message); + void ClearErrors(); + + SQLRETURN SetAutocommit(bool value); + bool GetAutocommit() const; + + const std::optional& GetTx(); + void SetTx(const NQuery::TTransaction& tx); + + SQLRETURN CommitTx(); + SQLRETURN RollbackTx(); +}; + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/environment.cpp b/odbc/src/environment.cpp new file mode 100644 index 0000000000..a09a634879 --- /dev/null +++ b/odbc/src/environment.cpp @@ -0,0 +1,54 @@ +#include "environment.h" +#include "connection.h" + +namespace NYdb { +namespace NOdbc { + +TEnvironment::TEnvironment() : OdbcVersion_(SQL_OV_ODBC3) {} +TEnvironment::~TEnvironment() {} + +SQLRETURN TEnvironment::SetAttribute(SQLINTEGER attribute, SQLPOINTER value, SQLINTEGER stringLength) { + // TODO: реализовать обработку атрибутов + OdbcVersion_ = attribute == SQL_ATTR_ODBC_VERSION ? reinterpret_cast(value) : 0; + return SQL_SUCCESS; +} + +SQLRETURN TEnvironment::GetDiagRec(SQLSMALLINT recNumber, + SQLCHAR* sqlState, + SQLINTEGER* nativeError, + SQLCHAR* messageText, + SQLSMALLINT bufferLength, + SQLSMALLINT* textLength) { + + if (recNumber < 1 || recNumber > (SQLSMALLINT)Errors_.size()) { + return SQL_NO_DATA; + } + + const auto& err = Errors_[recNumber-1]; + if (sqlState) { + strncpy((char*)sqlState, err.SqlState.c_str(), 6); + } + + if (nativeError) { + *nativeError = err.NativeError; + } + + if (messageText && bufferLength > 0) { + strncpy((char*)messageText, err.Message.c_str(), bufferLength); + if (textLength) { + *textLength = (SQLSMALLINT)std::min((int)err.Message.size(), (int)bufferLength); + } + } + return SQL_SUCCESS; +} + +void TEnvironment::AddError(const std::string& sqlState, SQLINTEGER nativeError, const std::string& message) { + Errors_.push_back({sqlState, nativeError, message}); +} + +void TEnvironment::ClearErrors() { + Errors_.clear(); +} + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/environment.h b/odbc/src/environment.h new file mode 100644 index 0000000000..0190b91383 --- /dev/null +++ b/odbc/src/environment.h @@ -0,0 +1,40 @@ +#pragma once + +#include +#include + +#include +#include + +namespace NYdb { +namespace NOdbc { + +class TConnection; + +struct TErrorInfo { + std::string SqlState; + SQLINTEGER NativeError; + std::string Message; +}; + +using TErrorList = std::vector; + +class TEnvironment { +private: + SQLINTEGER OdbcVersion_; + TErrorList Errors_; + +public: + TEnvironment(); + ~TEnvironment(); + + SQLRETURN SetAttribute(SQLINTEGER attribute, SQLPOINTER value, SQLINTEGER stringLength); + SQLRETURN GetDiagRec(SQLSMALLINT recNumber, SQLCHAR* sqlState, SQLINTEGER* nativeError, + SQLCHAR* messageText, SQLSMALLINT bufferLength, SQLSMALLINT* textLength); + + void AddError(const std::string& sqlState, SQLINTEGER nativeError, const std::string& message); + void ClearErrors(); +}; + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/odbc_driver.cpp b/odbc/src/odbc_driver.cpp new file mode 100644 index 0000000000..3dec17d6af --- /dev/null +++ b/odbc/src/odbc_driver.cpp @@ -0,0 +1,417 @@ +#include "environment.h" +#include "connection.h" +#include "statement.h" + +#include "utils/util.h" + +#include +#include + +extern "C" { + +SQLRETURN SQL_API SQLAllocHandle(SQLSMALLINT handleType, + SQLHANDLE inputHandle, + SQLHANDLE* outputHandle) { + if (!outputHandle) { + return SQL_INVALID_HANDLE; + } + + try { + switch (handleType) { + case SQL_HANDLE_ENV: { + if (inputHandle != SQL_NULL_HANDLE) { + return SQL_INVALID_HANDLE; + } + + *outputHandle = new NYdb::NOdbc::TEnvironment(); + return SQL_SUCCESS; + } + + case SQL_HANDLE_DBC: { + if (!inputHandle) { + return SQL_INVALID_HANDLE; + } + + *outputHandle = new NYdb::NOdbc::TConnection(); + return SQL_SUCCESS; + } + + case SQL_HANDLE_STMT: { + auto conn = static_cast(inputHandle); + if (!conn) { + return SQL_INVALID_HANDLE; + } + auto stmt = conn->CreateStatement(); + *outputHandle = stmt.release(); + return SQL_SUCCESS; + } + + default: + return SQL_ERROR; + } + } catch (...) { + return SQL_ERROR; + } +} + +SQLRETURN SQL_API SQLFreeHandle(SQLSMALLINT handleType, SQLHANDLE handle) { + if (!handle) { + return SQL_INVALID_HANDLE; + } + + try { + switch (handleType) { + case SQL_HANDLE_ENV: { + auto env = static_cast(handle); + delete env; + return SQL_SUCCESS; + } + + case SQL_HANDLE_DBC: { + auto conn = static_cast(handle); + delete conn; + return SQL_SUCCESS; + } + + case SQL_HANDLE_STMT: { + auto stmt = static_cast(handle); + if (stmt->GetConnection()) { + stmt->GetConnection()->RemoveStatement(stmt); + } + delete stmt; + return SQL_SUCCESS; + } + + default: + return SQL_ERROR; + } + } catch (...) { + return SQL_ERROR; + } +} + +SQLRETURN SQL_API SQLSetEnvAttr(SQLHENV environmentHandle, + SQLINTEGER attribute, + SQLPOINTER value, + SQLINTEGER stringLength) { + auto env = static_cast(environmentHandle); + if (!env) { + return SQL_INVALID_HANDLE; + } + + return env->SetAttribute(attribute, value, stringLength); +} + +SQLRETURN SQL_API SQLDriverConnect(SQLHDBC connectionHandle, + SQLHWND /*WindowHandle*/, + SQLCHAR* inConnectionString, + SQLSMALLINT stringLength1, + SQLCHAR* /*outConnectionString*/, + SQLSMALLINT /*bufferLength*/, + SQLSMALLINT* /*stringLength2Ptr*/, + SQLUSMALLINT /*driverCompletion*/) { + auto conn = static_cast(connectionHandle); + if (!conn) { + return SQL_INVALID_HANDLE; + } + + return conn->DriverConnect(NYdb::NOdbc::GetString(inConnectionString, stringLength1)); +} + +SQLRETURN SQL_API SQLConnect(SQLHDBC connectionHandle, + SQLCHAR* serverName, SQLSMALLINT nameLength1, + SQLCHAR* userName, SQLSMALLINT nameLength2, + SQLCHAR* authentication, SQLSMALLINT nameLength3) { + auto conn = static_cast(connectionHandle); + if (!conn) { + return SQL_INVALID_HANDLE; + } + + return conn->Connect(NYdb::NOdbc::GetString(serverName, nameLength1), + NYdb::NOdbc::GetString(userName, nameLength2), + NYdb::NOdbc::GetString(authentication, nameLength3)); +} + +SQLRETURN SQL_API SQLDisconnect(SQLHDBC connectionHandle) { + auto conn = static_cast(connectionHandle); + if (!conn) { + return SQL_INVALID_HANDLE; + } + + return conn->Disconnect(); +} + +SQLRETURN SQL_API SQLExecDirect(SQLHSTMT statementHandle, + SQLCHAR* statementText, + SQLINTEGER textLength) { + auto stmt = static_cast(statementHandle); + if (!stmt) { + return SQL_INVALID_HANDLE; + } + + auto ret = stmt->Prepare(NYdb::NOdbc::GetString(statementText, textLength)); + if (ret != SQL_SUCCESS) { + return ret; + } + return stmt->Execute(); +} + +SQLRETURN SQL_API SQLPrepare(SQLHSTMT statementHandle, + SQLCHAR* statementText, + SQLINTEGER textLength) { + auto stmt = static_cast(statementHandle); + if (!stmt) { + return SQL_INVALID_HANDLE; + } + return stmt->Prepare(NYdb::NOdbc::GetString(statementText, textLength)); +} + +SQLRETURN SQL_API SQLExecute(SQLHSTMT statementHandle) { + auto stmt = static_cast(statementHandle); + if (!stmt) { + return SQL_INVALID_HANDLE; + } + return stmt->Execute(); +} + +SQLRETURN SQL_API SQLFetch(SQLHSTMT statementHandle) { + auto stmt = static_cast(statementHandle); + if (!stmt) { + return SQL_INVALID_HANDLE; + } + + return stmt->Fetch(); +} + +SQLRETURN SQL_API SQLGetData(SQLHSTMT statementHandle, + SQLUSMALLINT columnNumber, + SQLSMALLINT targetType, + SQLPOINTER targetValue, + SQLLEN bufferLength, + SQLLEN* strLenOrInd) { + auto stmt = static_cast(statementHandle); + if (!stmt) { + return SQL_INVALID_HANDLE; + } + + return stmt->GetData(columnNumber, targetType, targetValue, bufferLength, strLenOrInd); +} + +SQLRETURN SQL_API SQLBindCol(SQLHSTMT statementHandle, + SQLUSMALLINT columnNumber, + SQLSMALLINT targetType, + SQLPOINTER targetValue, + SQLLEN bufferLength, + SQLLEN* strLenOrInd) { + auto stmt = static_cast(statementHandle); + if (!stmt) { + return SQL_INVALID_HANDLE; + } + return stmt->BindCol(columnNumber, targetType, targetValue, bufferLength, strLenOrInd); +} + +SQLRETURN SQL_API SQLGetDiagRec(SQLSMALLINT handleType, + SQLHANDLE handle, + SQLSMALLINT recNumber, + SQLCHAR* sqlState, + SQLINTEGER* nativeError, + SQLCHAR* messageText, + SQLSMALLINT bufferLength, + SQLSMALLINT* textLength) { + if (!handle) { + return SQL_INVALID_HANDLE; + } + + try { + switch (handleType) { + case SQL_HANDLE_ENV: { + auto env = static_cast(handle); + return env->GetDiagRec(recNumber, sqlState, nativeError, messageText, bufferLength, textLength); + } + + case SQL_HANDLE_DBC: { + auto conn = static_cast(handle); + return conn->GetDiagRec(recNumber, sqlState, nativeError, messageText, bufferLength, textLength); + } + + case SQL_HANDLE_STMT: { + auto stmt = static_cast(handle); + return stmt->GetDiagRec(recNumber, sqlState, nativeError, messageText, bufferLength, textLength); + } + + default: + return SQL_ERROR; + } + } catch (...) { + return SQL_ERROR; + } +} + +SQLRETURN SQL_API SQLBindParameter(SQLHSTMT statementHandle, + SQLUSMALLINT paramNumber, + SQLSMALLINT inputOutputType, + SQLSMALLINT valueType, + SQLSMALLINT parameterType, + SQLULEN columnSize, + SQLSMALLINT decimalDigits, + SQLPOINTER parameterValuePtr, + SQLLEN bufferLength, + SQLLEN* strLenOrIndPtr) { + auto stmt = static_cast(statementHandle); + if (!stmt) { + return SQL_INVALID_HANDLE; + } + + return stmt->BindParameter(paramNumber, inputOutputType, valueType, parameterType, columnSize, decimalDigits, parameterValuePtr, bufferLength, strLenOrIndPtr); +} + +SQLRETURN SQL_API SQLEndTran(SQLSMALLINT handleType, SQLHANDLE handle, SQLSMALLINT completionType) { + if (!handle) { + return SQL_INVALID_HANDLE; + } + try { + switch (handleType) { + case SQL_HANDLE_DBC: { + auto conn = static_cast(handle); + if (completionType == SQL_COMMIT) { + return conn->CommitTx(); + } else if (completionType == SQL_ROLLBACK) { + return conn->RollbackTx(); + } else { + return SQL_ERROR; + } + } + case SQL_HANDLE_STMT: { + auto stmt = static_cast(handle); + auto conn = stmt->GetConnection(); + if (!conn) return SQL_INVALID_HANDLE; + if (completionType == SQL_COMMIT) { + return conn->CommitTx(); + } else if (completionType == SQL_ROLLBACK) { + return conn->RollbackTx(); + } else { + return SQL_ERROR; + } + } + case SQL_HANDLE_ENV: { + // TODO: if's list of connections in ENV, go through them and commit/rollback transactions + return SQL_SUCCESS; + } + default: + return SQL_ERROR; + } + } catch (...) { + return SQL_ERROR; + } +} + +SQLRETURN SQL_API SQLSetConnectAttr(SQLHDBC connectionHandle, SQLINTEGER attribute, SQLPOINTER value, SQLINTEGER stringLength) { + auto conn = static_cast(connectionHandle); + if (!conn) { + return SQL_INVALID_HANDLE; + } + if (attribute == SQL_ATTR_AUTOCOMMIT) { + if ((intptr_t)value == SQL_AUTOCOMMIT_ON) { + return conn->SetAutocommit(true); + } else if ((intptr_t)value == SQL_AUTOCOMMIT_OFF) { + return conn->SetAutocommit(false); + } else { + return SQL_ERROR; + } + } + // TODO: other attributes + return SQL_ERROR; +} + +SQLRETURN SQL_API SQLColumns(SQLHSTMT statementHandle, + SQLCHAR* catalogName, SQLSMALLINT nameLength1, + SQLCHAR* schemaName, SQLSMALLINT nameLength2, + SQLCHAR* tableName, SQLSMALLINT nameLength3, + SQLCHAR* columnName, SQLSMALLINT nameLength4) { + auto stmt = static_cast(statementHandle); + if (!stmt) { + return SQL_INVALID_HANDLE; + } + return stmt->Columns( + NYdb::NOdbc::GetString(catalogName, nameLength1), + NYdb::NOdbc::GetString(schemaName, nameLength2), + NYdb::NOdbc::GetString(tableName, nameLength3), + NYdb::NOdbc::GetString(columnName, nameLength4)); +} + +SQLRETURN SQL_API SQLTables(SQLHSTMT statementHandle, + SQLCHAR* catalogName, SQLSMALLINT nameLength1, + SQLCHAR* schemaName, SQLSMALLINT nameLength2, + SQLCHAR* tableName, SQLSMALLINT nameLength3, + SQLCHAR* tableType, SQLSMALLINT nameLength4) { + auto stmt = static_cast(statementHandle); + if (!stmt) { + return SQL_INVALID_HANDLE; + } + return stmt->Tables( + NYdb::NOdbc::GetString(catalogName, nameLength1), + NYdb::NOdbc::GetString(schemaName, nameLength2), + NYdb::NOdbc::GetString(tableName, nameLength3), + NYdb::NOdbc::GetString(tableType, nameLength4)); +} + +SQLRETURN SQL_API SQLCloseCursor(SQLHSTMT statementHandle) { + auto stmt = static_cast(statementHandle); + if (!stmt) { + return SQL_INVALID_HANDLE; + } + + return stmt->Close(false); +} + +SQLRETURN SQL_API SQLFreeStmt(SQLHSTMT statementHandle, SQLUSMALLINT option) { + auto stmt = static_cast(statementHandle); + if (!stmt) { + return SQL_INVALID_HANDLE; + } + switch (option) { + case SQL_CLOSE: + return stmt->Close(true); + case SQL_DROP: + return SQLFreeHandle(SQL_HANDLE_STMT, statementHandle); + case SQL_UNBIND: + stmt->UnbindColumns(); + return SQL_SUCCESS; + case SQL_RESET_PARAMS: + stmt->ResetParams(); + return SQL_SUCCESS; + default: + return SQL_ERROR; + } +} + +SQLRETURN SQL_API SQLFetchScroll(SQLHSTMT statementHandle, SQLSMALLINT fetchOrientation, SQLLEN fetchOffset) { + auto stmt = static_cast(statementHandle); + if (!stmt) { + return SQL_INVALID_HANDLE; + } + if (fetchOrientation == SQL_FETCH_NEXT) { + return stmt->Fetch(); + } else { + stmt->AddError("HYC00", 0, "Only SQL_FETCH_NEXT is supported"); + return SQL_ERROR; + } +} + +SQLRETURN SQL_API SQLRowCount(SQLHSTMT statementHandle, SQLLEN* rowCount) { + auto stmt = static_cast(statementHandle); + if (!stmt) { + return SQL_INVALID_HANDLE; + } + return stmt->RowCount(rowCount); +} + +SQLRETURN SQL_API SQLNumResultCols(SQLHSTMT statementHandle, SQLSMALLINT* colCount) { + auto stmt = static_cast(statementHandle); + if (!stmt) { + return SQL_INVALID_HANDLE; + } + return stmt->NumResultCols(colCount); +} + +} diff --git a/odbc/src/statement.cpp b/odbc/src/statement.cpp new file mode 100644 index 0000000000..2bf4c78fd1 --- /dev/null +++ b/odbc/src/statement.cpp @@ -0,0 +1,413 @@ +#include "statement.h" + +#include "utils/convert.h" +#include "utils/types.h" + +#include +#include + +namespace NYdb { +namespace NOdbc { + +TStatement::TStatement(TConnection* conn) + : Conn_(conn) {} + +SQLRETURN TStatement::Prepare(const std::string& statementText) { + Cursor_.reset(); + PreparedQuery_ = statementText; + IsPrepared_ = true; + return SQL_SUCCESS; +} + +SQLRETURN TStatement::Execute() { + if (!IsPrepared_ || PreparedQuery_.empty()) { + AddError("HY007", 0, "No prepared statement"); + return SQL_ERROR; + } + Cursor_.reset(); + auto* client = Conn_->GetClient(); + if (!client) { + return SQL_ERROR; + } + NYdb::TParams params = BuildParams(); + if (!Errors_.empty()) { + return SQL_ERROR; + } + if (!Conn_->GetTx()) { + auto sessionResult = client->GetSession().ExtractValueSync(); + if (!sessionResult.IsSuccess()) { + return SQL_ERROR; + } + auto session = sessionResult.GetSession(); + auto beginTxResult = session.BeginTransaction(NQuery::TTxSettings::SerializableRW()).ExtractValueSync(); + if (!beginTxResult.IsSuccess()) { + return SQL_ERROR; + } + Conn_->SetTx(beginTxResult.GetTransaction()); + } + auto session = Conn_->GetTx()->GetSession(); + auto iterator = session.StreamExecuteQuery(PreparedQuery_, + NQuery::TTxControl::Tx(*Conn_->GetTx()).CommitTx(Conn_->GetAutocommit()), params).ExtractValueSync(); + if (!iterator.IsSuccess()) { + return SQL_ERROR; + } + Cursor_ = CreateExecCursor(this, std::move(iterator)); + IsPrepared_ = false; + PreparedQuery_.clear(); + return SQL_SUCCESS; +} + +SQLRETURN TStatement::Fetch() { + if (!Cursor_) { + Cursor_.reset(); + return SQL_NO_DATA; + } + return Cursor_->Fetch() ? SQL_SUCCESS : SQL_NO_DATA; +} + +SQLRETURN TStatement::GetData(SQLUSMALLINT columnNumber, SQLSMALLINT targetType, + SQLPOINTER targetValue, SQLLEN bufferLength, SQLLEN* strLenOrInd) { + if (!Cursor_) { + return SQL_NO_DATA; + } + return Cursor_->GetData(columnNumber, targetType, targetValue, bufferLength, strLenOrInd); +} + +void TStatement::FillBoundColumns() { + if (!Cursor_) { + return; + } + for (const auto& col : BoundColumns_) { + Cursor_->GetData(col.ColumnNumber, col.TargetType, col.TargetValue, col.BufferLength, col.StrLenOrInd); + } +} + +SQLRETURN TStatement::GetDiagRec(SQLSMALLINT recNumber, SQLCHAR* sqlState, SQLINTEGER* nativeError, + SQLCHAR* messageText, SQLSMALLINT bufferLength, SQLSMALLINT* textLength) { + + if (recNumber < 1 || recNumber > (SQLSMALLINT)Errors_.size()) { + return SQL_NO_DATA; + } + + const auto& err = Errors_[recNumber-1]; + if (sqlState) { + strncpy((char*)sqlState, err.SqlState.c_str(), 6); + } + + if (nativeError) { + *nativeError = err.NativeError; + } + + if (messageText && bufferLength > 0) { + strncpy((char*)messageText, err.Message.c_str(), bufferLength); + if (textLength) { + *textLength = (SQLSMALLINT)std::min((int)err.Message.size(), (int)bufferLength); + } + } + return SQL_SUCCESS; +} + +SQLRETURN TStatement::BindCol(SQLUSMALLINT columnNumber, SQLSMALLINT targetType, SQLPOINTER targetValue, SQLLEN bufferLength, SQLLEN* strLenOrInd) { + if (!Cursor_) { + return SQL_NO_DATA; + } + + BoundColumns_.erase(std::remove_if(BoundColumns_.begin(), BoundColumns_.end(), + [columnNumber](const TBoundColumn& col) { return col.ColumnNumber == columnNumber; }), BoundColumns_.end()); + + if (!targetValue) { + return SQL_SUCCESS; + } + BoundColumns_.push_back({columnNumber, targetType, targetValue, bufferLength, strLenOrInd}); + return SQL_SUCCESS; +} + +SQLRETURN TStatement::BindParameter(SQLUSMALLINT paramNumber, + SQLSMALLINT inputOutputType, + SQLSMALLINT valueType, + SQLSMALLINT parameterType, + SQLULEN columnSize, + SQLSMALLINT decimalDigits, + SQLPOINTER parameterValuePtr, + SQLLEN bufferLength, + SQLLEN* strLenOrIndPtr) { + + if (inputOutputType != SQL_PARAM_INPUT) { + AddError("HYC00", 0, "Only input parameters are supported"); + return SQL_ERROR; + } + + BoundParams_.erase(std::remove_if(BoundParams_.begin(), BoundParams_.end(), + [paramNumber](const TBoundParam& p) { return p.ParamNumber == paramNumber; }), BoundParams_.end()); + + if (!parameterValuePtr) { + return SQL_SUCCESS; + } + BoundParams_.push_back({paramNumber, inputOutputType, valueType, parameterType, columnSize, decimalDigits, parameterValuePtr, bufferLength, strLenOrIndPtr}); + return SQL_SUCCESS; +} + +void TStatement::AddError(const std::string& sqlState, SQLINTEGER nativeError, const std::string& message) { + Errors_.push_back({sqlState, nativeError, message}); +} + +NYdb::TParams TStatement::BuildParams() { + Errors_.clear(); + NYdb::TParamsBuilder paramsBuilder; + for (const auto& param : BoundParams_) { + std::string paramName = "$p" + std::to_string(param.ParamNumber); + ConvertParam(param, paramsBuilder.AddParam(paramName)); + } + + return paramsBuilder.Build(); +} + +SQLRETURN TStatement::Columns(const std::string& catalogName, + const std::string& schemaName, + const std::string& tableName, + const std::string& columnName) { + Errors_.clear(); + Cursor_.reset(); + + std::vector columns = { + {"TABLE_CAT", SQL_VARCHAR, 128, SQL_NULLABLE}, + {"TABLE_SCHEM", SQL_VARCHAR, 128, SQL_NULLABLE}, + {"TABLE_NAME", SQL_VARCHAR, 128, SQL_NO_NULLS}, + {"COLUMN_NAME", SQL_VARCHAR, 128, SQL_NO_NULLS}, + {"DATA_TYPE", SQL_INTEGER, 0, SQL_NO_NULLS}, + {"TYPE_NAME", SQL_VARCHAR, 128, SQL_NO_NULLS}, + {"COLUMN_SIZE", SQL_INTEGER, 0, SQL_NULLABLE}, + {"BUFFER_LENGTH", SQL_INTEGER, 0, SQL_NULLABLE}, + {"DECIMAL_DIGITS", SQL_INTEGER, 0, SQL_NULLABLE}, + {"NUM_PREC_RADIX", SQL_INTEGER, 0, SQL_NULLABLE}, + {"NULLABLE", SQL_INTEGER, 0, SQL_NO_NULLS}, + {"REMARKS", SQL_VARCHAR, 762, SQL_NULLABLE}, + {"COLUMN_DEF", SQL_VARCHAR, 254, SQL_NULLABLE}, + {"SQL_DATA_TYPE", SQL_INTEGER, 0, SQL_NO_NULLS}, + {"SQL_DATETIME_SUB", SQL_INTEGER, 0, SQL_NULLABLE}, + {"CHAR_OCTET_LENGTH", SQL_INTEGER, 0, SQL_NULLABLE}, + {"ORDINAL_POSITION", SQL_INTEGER, 0, SQL_NO_NULLS}, + {"IS_NULLABLE", SQL_VARCHAR, 254, SQL_NO_NULLS} + }; + + auto entries = GetPatternEntries(tableName); + if (entries.empty()) { + AddError("HYC00", 0, "No tables found"); + return SQL_ERROR; + } + + TTable table; + table.reserve(entries.size()); + + for (const auto& entry : entries) { + if (entry.Type != NScheme::ESchemeEntryType::Table && + entry.Type != NScheme::ESchemeEntryType::ColumnTable) { + continue; + } + + auto status = Conn_->GetTableClient()->RetryOperationSync([path = entry.Name, &table, &columnName](NTable::TSession session) -> TStatus { + auto result = session.DescribeTable(path).ExtractValueSync(); + if (!result.IsSuccess()) { + return result; + } + auto columns = result.GetTableDescription().GetTableColumns(); + + auto columnIt = std::find_if(columns.begin(), columns.end(), [&columnName](const NTable::TTableColumn& column) { + return column.Name == columnName; + }); + + if (columnIt == columns.end()) { + return TStatus(EStatus::NOT_FOUND, { NYdb::NIssue::TIssue("Column not found") }); + } + + auto column = *columnIt; + + TTypeParser typeParser(column.Type); + + table.push_back({ + TValueBuilder().OptionalUtf8(std::nullopt).Build(), + TValueBuilder().OptionalUtf8(std::nullopt).Build(), + TValueBuilder().Utf8(path).Build(), + TValueBuilder().Utf8(column.Name).Build(), + TValueBuilder().Int16(GetTypeId(column.Type)).Build(), + TValueBuilder().Utf8(column.Type.ToString()).Build(), + TValueBuilder().OptionalInt32(std::nullopt).Build(), + TValueBuilder().OptionalInt32(std::nullopt).Build(), + TValueBuilder().OptionalInt16(GetDecimalDigits(column.Type)).Build(), + TValueBuilder().OptionalInt16(GetRadix(column.Type)).Build(), + TValueBuilder().Int16(column.NotNull && *column.NotNull ? SQL_NO_NULLS : SQL_NULLABLE).Build(), + TValueBuilder().OptionalUtf8(std::nullopt).Build(), + TValueBuilder().OptionalUtf8(std::nullopt).Build(), + TValueBuilder().Int16(GetTypeId(column.Type)).Build(), + TValueBuilder().OptionalInt16(std::nullopt).Build(), + TValueBuilder().OptionalInt32(8).Build(), + TValueBuilder().OptionalInt32(columnIt - columns.begin() + 1).Build(), + TValueBuilder().Utf8(column.NotNull && *column.NotNull ? "NO" : "YES").Build(), + }); + return TStatus(EStatus::SUCCESS, {}); + }); + + if (!status.IsSuccess()) { + return SQL_ERROR; + } + } + + Cursor_ = CreateVirtualCursor(this, columns, table); + return SQL_SUCCESS; +} + +SQLRETURN TStatement::Tables(const std::string& catalogName, + const std::string& schemaName, + const std::string& tableName, + const std::string& tableType) { + Errors_.clear(); + Cursor_.reset(); + + std::vector columns = { + {"TABLE_CAT", SQL_VARCHAR, 128, SQL_NULLABLE}, + {"TABLE_SCHEM", SQL_VARCHAR, 128, SQL_NULLABLE}, + {"TABLE_NAME", SQL_VARCHAR, 128, SQL_NO_NULLS}, + {"TABLE_TYPE", SQL_VARCHAR, 128, SQL_NO_NULLS}, + {"REMARKS", SQL_VARCHAR, 254, SQL_NULLABLE} + }; + + auto entries = GetPatternEntries(tableName); + if (entries.empty()) { + AddError("HYC00", 0, "No tables found"); + return SQL_ERROR; + } + + TTable table; + table.reserve(entries.size()); + + for (const auto& entry : entries) { + auto tableType = GetTableType(entry.Type); + if (!tableType) { + continue; + } + + table.push_back({ + TValueBuilder().OptionalUtf8(std::nullopt).Build(), + TValueBuilder().OptionalUtf8(std::nullopt).Build(), + TValueBuilder().Utf8(entry.Name).Build(), + TValueBuilder().Utf8(*tableType).Build(), + TValueBuilder().OptionalUtf8(std::nullopt).Build(), + }); + } + + Cursor_ = CreateVirtualCursor(this, columns, table); + return SQL_SUCCESS; +} + +std::vector TStatement::GetPatternEntries(const std::string& pattern) { + std::vector entries; + VisitEntry("", pattern, entries); + return entries; +} + +SQLRETURN TStatement::VisitEntry(const std::string& path, const std::string& pattern, std::vector& resultEntries) { + auto schemeClient = Conn_->GetSchemeClient(); + auto listDirectoryResult = schemeClient->ListDirectory(path + "/").ExtractValueSync(); + if (!listDirectoryResult.IsSuccess()) { + return SQL_ERROR; + } + for (const auto& entry : listDirectoryResult.GetChildren()) { + std::string fullPath = path + "/" + entry.Name; + if (entry.Type == NScheme::ESchemeEntryType::Directory || + entry.Type == NScheme::ESchemeEntryType::SubDomain) { + VisitEntry(fullPath, pattern, resultEntries); + } else if (IsPatternMatch(fullPath, pattern)) { + NScheme::TSchemeEntry entryCopy = entry; + entryCopy.Name = fullPath; + resultEntries.push_back(entryCopy); + } + } + return SQL_SUCCESS; +} + +bool TStatement::IsPatternMatch(const std::string& path, const std::string& pattern) { + return path.starts_with(pattern); +} + +std::optional TStatement::GetTableType(NScheme::ESchemeEntryType type) { + switch (type) { + case NScheme::ESchemeEntryType::Table: + return "TABLE"; + case NScheme::ESchemeEntryType::View: + return "VIEW"; + case NScheme::ESchemeEntryType::ColumnStore: + return "COLUMN_STORE"; + case NScheme::ESchemeEntryType::ColumnTable: + return "COLUMN_TABLE"; + case NScheme::ESchemeEntryType::Sequence: + return "SEQUENCE"; + case NScheme::ESchemeEntryType::Replication: + return "REPLICATION"; + case NScheme::ESchemeEntryType::Topic: + return "TOPIC"; + case NScheme::ESchemeEntryType::ExternalTable: + return "EXTERNAL_TABLE"; + case NScheme::ESchemeEntryType::ExternalDataSource: + return "EXTERNAL_DATA_SOURCE"; + case NScheme::ESchemeEntryType::ResourcePool: + return "RESOURCE_POOL"; + case NScheme::ESchemeEntryType::PqGroup: + return "PQ_GROUP"; + case NScheme::ESchemeEntryType::RtmrVolume: + return "RTMR_VOLUME"; + case NScheme::ESchemeEntryType::BlockStoreVolume: + return "BLOCK_STORE_VOLUME"; + case NScheme::ESchemeEntryType::CoordinationNode: + return "COORDINATION_NODE"; + case NScheme::ESchemeEntryType::Unknown: + return "UNKNOWN"; + case NScheme::ESchemeEntryType::Directory: + case NScheme::ESchemeEntryType::SubDomain: + return std::nullopt; + } +} + +SQLRETURN TStatement::Close(bool force) { + if (!force && !Cursor_) { + AddError("24000", 0, "Invalid handle"); + return SQL_ERROR; + } + + Cursor_.reset(); + PreparedQuery_.clear(); + IsPrepared_ = false; + Errors_.clear(); + return SQL_SUCCESS; +} + +void TStatement::UnbindColumns() { + BoundColumns_.clear(); +} + +void TStatement::ResetParams() { + BoundParams_.clear(); +} + +SQLRETURN TStatement::RowCount(SQLLEN* rowCount) { + if (!rowCount) { + return SQL_ERROR; + } + + *rowCount = -1; + return SQL_SUCCESS; +} + +SQLRETURN TStatement::NumResultCols(SQLSMALLINT* colCount) { + if (!colCount) { + return SQL_ERROR; + } + if (!Cursor_) { + *colCount = 0; + return SQL_SUCCESS; + } + *colCount = static_cast(Cursor_->GetColumnMeta().size()); + return SQL_SUCCESS; +} + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/statement.h b/odbc/src/statement.h new file mode 100644 index 0000000000..d6b47a1146 --- /dev/null +++ b/odbc/src/statement.h @@ -0,0 +1,85 @@ +#pragma once + +#include "connection.h" + +#include "utils/bindings.h" +#include "utils/cursor.h" + +#include + +#include +#include + +#include +#include +#include + + +namespace NYdb { +namespace NOdbc { + +class TStatement : public IBindingFiller { +public: + TStatement(TConnection* conn); + + SQLRETURN Prepare(const std::string& statementText); + SQLRETURN Execute(); + + SQLRETURN Fetch(); + SQLRETURN GetData(SQLUSMALLINT columnNumber, SQLSMALLINT targetType, + SQLPOINTER targetValue, SQLLEN bufferLength, SQLLEN* strLenOrInd); + + void FillBoundColumns() override; + + SQLRETURN Close(bool force = false); + void UnbindColumns(); + void ResetParams(); + + SQLRETURN GetDiagRec(SQLSMALLINT recNumber, SQLCHAR* sqlState, SQLINTEGER* nativeError, + SQLCHAR* messageText, SQLSMALLINT bufferLength, SQLSMALLINT* textLength); + + SQLRETURN BindCol(SQLUSMALLINT columnNumber, SQLSMALLINT targetType, SQLPOINTER targetValue, SQLLEN bufferLength, SQLLEN* strLenOrInd); + SQLRETURN BindParameter(SQLUSMALLINT paramNumber, SQLSMALLINT inputOutputType, SQLSMALLINT valueType, SQLSMALLINT parameterType, SQLULEN columnSize, SQLSMALLINT decimalDigits, SQLPOINTER parameterValuePtr, SQLLEN bufferLength, SQLLEN* strLenOrIndPtr); + + SQLRETURN Columns(const std::string& catalogName, + const std::string& schemaName, + const std::string& tableName, + const std::string& columnName); + + SQLRETURN Tables(const std::string& catalogName, + const std::string& schemaName, + const std::string& tableName, + const std::string& tableType); + + SQLRETURN RowCount(SQLLEN* rowCount); + SQLRETURN NumResultCols(SQLSMALLINT* colCount); + + TConnection* GetConnection() { + return Conn_; + } + + void AddError(const std::string& sqlState, SQLINTEGER nativeError, const std::string& message); + + NYdb::TParams BuildParams(); + +private: + std::vector GetPatternEntries(const std::string& pattern); + SQLRETURN VisitEntry(const std::string& path, const std::string& pattern, std::vector& resultEntries); + bool IsPatternMatch(const std::string& path, const std::string& pattern); + + std::optional GetTableType(NScheme::ESchemeEntryType type); + + TConnection* Conn_; + TErrorList Errors_; + + std::unique_ptr Cursor_; + + std::vector BoundColumns_; + std::vector BoundParams_; + + std::string PreparedQuery_; + bool IsPrepared_ = false; +}; + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/utils/bindings.h b/odbc/src/utils/bindings.h new file mode 100644 index 0000000000..df76de4e95 --- /dev/null +++ b/odbc/src/utils/bindings.h @@ -0,0 +1,37 @@ +#pragma once + +#include +#include + +namespace NYdb { +namespace NOdbc { + +struct TBoundParam { + SQLUSMALLINT ParamNumber; + SQLSMALLINT InputOutputType; + SQLSMALLINT ValueType; + SQLSMALLINT ParameterType; + SQLULEN ColumnSize; + SQLSMALLINT DecimalDigits; + SQLPOINTER ParameterValuePtr; + SQLLEN BufferLength; + SQLLEN* StrLenOrIndPtr; +}; + +struct TBoundColumn { + SQLUSMALLINT ColumnNumber; + SQLSMALLINT TargetType; + SQLPOINTER TargetValue; + SQLLEN BufferLength; + SQLLEN* StrLenOrInd; +}; + +class IBindingFiller { +public: + virtual void FillBoundColumns() = 0; + + virtual ~IBindingFiller() = default; +}; + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/utils/convert.cpp b/odbc/src/utils/convert.cpp new file mode 100644 index 0000000000..224f228e49 --- /dev/null +++ b/odbc/src/utils/convert.cpp @@ -0,0 +1,410 @@ +#include "convert.h" + +#include + +namespace NYdb { +namespace NOdbc { + +template +struct TSqlTypeTraits; + +template<> struct TSqlTypeTraits { using Type = std::string; }; +template<> struct TSqlTypeTraits { using Type = std::string; }; +template<> struct TSqlTypeTraits { using Type = SQLBIGINT; }; +template<> struct TSqlTypeTraits { using Type = SQLUBIGINT; }; +template<> struct TSqlTypeTraits { using Type = SQLINTEGER; }; +template<> struct TSqlTypeTraits { using Type = SQLUINTEGER; }; +template<> struct TSqlTypeTraits { using Type = SQLSMALLINT; }; +template<> struct TSqlTypeTraits { using Type = SQLSMALLINT; }; +template<> struct TSqlTypeTraits { using Type = SQLUSMALLINT; }; +template<> struct TSqlTypeTraits { using Type = SQLSCHAR; }; +template<> struct TSqlTypeTraits { using Type = SQLCHAR; }; +template<> struct TSqlTypeTraits { using Type = SQLDOUBLE; }; +template<> struct TSqlTypeTraits { using Type = SQLFLOAT; }; +template<> struct TSqlTypeTraits { using Type = SQLCHAR; }; + +template +struct TTypedValue { + using TSrcType = typename TSqlTypeTraits::Type; + + TSrcType Data; + + TTypedValue(const TBoundParam& param) { + Data = *static_cast(param.ParameterValuePtr); + } +}; + +template<> +TTypedValue::TTypedValue(const TBoundParam& param) { + Data = std::string(static_cast(param.ParameterValuePtr), param.BufferLength); +} + +template<> +TTypedValue::TTypedValue(const TBoundParam& param) { + Data = std::string(static_cast(param.ParameterValuePtr), param.BufferLength); +} + +class IConverter { +public: + virtual void AddToBuilder(const TBoundParam& param, TParamValueBuilder& builder) = 0; + + virtual ~IConverter() = default; +}; + +template +class TConverter : public IConverter { +public: + virtual void AddToBuilder(const TBoundParam& param, TParamValueBuilder& builder) override { + TTypedValue value(param); + Convert(param, std::move(value.Data), builder); + if (param.StrLenOrIndPtr && *param.StrLenOrIndPtr == SQL_NULL_DATA) { + builder.EmptyOptional(GetType()); + } + builder.Build(); + } + +private: + void Convert(const TBoundParam& param, TTypedValue::TSrcType&& data, TParamValueBuilder& builder); + TType GetType(); +}; + +class TConverterRegistry { +public: + static TConverterRegistry& GetInstance() { + static TConverterRegistry instance; + return instance; + } + + void RegisterConverter(SQLSMALLINT cType, SQLSMALLINT sqlType, std::unique_ptr converter) { + Converters_.emplace(std::make_pair(cType, sqlType), std::move(converter)); + } + + IConverter* GetConverter(SQLSMALLINT cType, SQLSMALLINT sqlType) { + auto it = Converters_.find(std::make_pair(cType, sqlType)); + if (it != Converters_.end()) { + return it->second.get(); + } + return nullptr; + } + +private: + std::map, std::unique_ptr> Converters_; +}; + +#define REGISTER_CONVERTER(CType, SqlType, YdbType) \ + struct TConverterRegistration##CType##SqlType { \ + TConverterRegistration##CType##SqlType() { \ + TConverterRegistry::GetInstance().RegisterConverter(CType, SqlType, std::make_unique>()); \ + } \ + }; \ + static const TConverterRegistration##CType##SqlType converterRegistration##CType##SqlType; \ + template<> \ + TType TConverter::GetType() { \ + return TTypeBuilder().Primitive(YdbType).Build(); \ + } \ + template<> \ + void TConverter::Convert(const TBoundParam& param, TTypedValue::TSrcType&& data, TParamValueBuilder& builder) + +// Integer types + +REGISTER_CONVERTER(SQL_C_SBIGINT, SQL_BIGINT, EPrimitiveType::Int64) { + builder.OptionalInt64(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_LONG, SQL_BIGINT, EPrimitiveType::Int64) { + builder.OptionalInt64(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_SHORT, SQL_BIGINT, EPrimitiveType::Int64) { + builder.OptionalInt64(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_TINYINT, SQL_BIGINT, EPrimitiveType::Int64) { + builder.OptionalInt64(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_UBIGINT, SQL_BIGINT, EPrimitiveType::Uint64) { + builder.OptionalUint64(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_ULONG, SQL_BIGINT, EPrimitiveType::Uint64) { + builder.OptionalUint64(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_USHORT, SQL_BIGINT, EPrimitiveType::Uint64) { + builder.OptionalUint64(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_UTINYINT, SQL_BIGINT, EPrimitiveType::Uint64) { + builder.OptionalUint64(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_SBIGINT, SQL_INTEGER, EPrimitiveType::Int32) { + builder.OptionalInt32(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_LONG, SQL_INTEGER, EPrimitiveType::Int32) { + builder.OptionalInt32(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_SHORT, SQL_INTEGER, EPrimitiveType::Int32) { + builder.OptionalInt32(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_TINYINT, SQL_INTEGER, EPrimitiveType::Int32) { + builder.OptionalInt32(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_UBIGINT, SQL_INTEGER, EPrimitiveType::Uint32) { + builder.OptionalUint32(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_ULONG, SQL_INTEGER, EPrimitiveType::Uint32) { + builder.OptionalUint32(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_USHORT, SQL_INTEGER, EPrimitiveType::Uint32) { + builder.OptionalUint32(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_UTINYINT, SQL_INTEGER, EPrimitiveType::Uint32) { + builder.OptionalUint32(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_SBIGINT, SQL_SMALLINT, EPrimitiveType::Int16) { + builder.OptionalInt16(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_LONG, SQL_SMALLINT, EPrimitiveType::Int16) { + builder.OptionalInt16(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_SHORT, SQL_SMALLINT, EPrimitiveType::Int16) { + builder.OptionalInt16(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_TINYINT, SQL_SMALLINT, EPrimitiveType::Int16) { + builder.OptionalInt16(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_UBIGINT, SQL_SMALLINT, EPrimitiveType::Uint16) { + builder.OptionalUint16(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_ULONG, SQL_SMALLINT, EPrimitiveType::Uint16) { + builder.OptionalUint16(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_USHORT, SQL_SMALLINT, EPrimitiveType::Uint16) { + builder.OptionalUint16(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_UTINYINT, SQL_SMALLINT, EPrimitiveType::Uint16) { + builder.OptionalUint16(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_SBIGINT, SQL_TINYINT, EPrimitiveType::Int8) { + builder.OptionalInt8(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_LONG, SQL_TINYINT, EPrimitiveType::Int8) { + builder.OptionalInt8(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_SHORT, SQL_TINYINT, EPrimitiveType::Int8) { + builder.OptionalInt8(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_TINYINT, SQL_TINYINT, EPrimitiveType::Int8) { + builder.OptionalInt8(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_UBIGINT, SQL_TINYINT, EPrimitiveType::Uint8) { + builder.OptionalUint8(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_ULONG, SQL_TINYINT, EPrimitiveType::Uint8) { + builder.OptionalUint8(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_USHORT, SQL_TINYINT, EPrimitiveType::Uint8) { + builder.OptionalUint8(static_cast(data)); +} + +REGISTER_CONVERTER(SQL_C_UTINYINT, SQL_TINYINT, EPrimitiveType::Uint8) { + builder.OptionalUint8(static_cast(data)); +} + +// Floating point types + +REGISTER_CONVERTER(SQL_C_FLOAT, SQL_REAL, EPrimitiveType::Float) { + builder.OptionalFloat(data); +} + +REGISTER_CONVERTER(SQL_C_DOUBLE, SQL_FLOAT, EPrimitiveType::Double) { + builder.OptionalDouble(data); +} + +REGISTER_CONVERTER(SQL_C_DOUBLE, SQL_DOUBLE, EPrimitiveType::Double) { + builder.OptionalDouble(data); +} + +// String types + +REGISTER_CONVERTER(SQL_C_CHAR, SQL_CHAR, EPrimitiveType::Utf8) { + builder.OptionalUtf8(std::move(data)); +} + +REGISTER_CONVERTER(SQL_C_CHAR, SQL_VARCHAR, EPrimitiveType::Utf8) { + builder.OptionalUtf8(std::move(data)); +} + +REGISTER_CONVERTER(SQL_C_CHAR, SQL_LONGVARCHAR, EPrimitiveType::Utf8) { + builder.OptionalUtf8(std::move(data)); +} + +// Binary types + +REGISTER_CONVERTER(SQL_C_BINARY, SQL_BINARY, EPrimitiveType::String) { + builder.OptionalString(std::move(data)); +} + +REGISTER_CONVERTER(SQL_C_BINARY, SQL_VARBINARY, EPrimitiveType::String) { + builder.OptionalString(std::move(data)); +} + +REGISTER_CONVERTER(SQL_C_BINARY, SQL_LONGVARBINARY, EPrimitiveType::String) { + builder.OptionalString(std::move(data)); +} + +#undef REGISTER_CONVERTER + +SQLRETURN ConvertParam(const TBoundParam& param, TParamValueBuilder& builder) { + auto converter = TConverterRegistry::GetInstance().GetConverter(param.ValueType, param.ParameterType); + if (!converter) { + return SQL_ERROR; + } + + converter->AddToBuilder(param, builder); + return SQL_SUCCESS; +} + +SQLRETURN ConvertColumn(TValueParser& parser, SQLSMALLINT targetType, SQLPOINTER targetValue, SQLLEN bufferLength, SQLLEN* strLenOrInd) { + if (parser.IsNull()) { + if (strLenOrInd) { + *strLenOrInd = SQL_NULL_DATA; + } + return SQL_SUCCESS; + } + + if (parser.GetKind() == TTypeParser::ETypeKind::Optional) { + parser.OpenOptional(); + SQLRETURN ret = ConvertColumn(parser, targetType, targetValue, bufferLength, strLenOrInd); + parser.CloseOptional(); + return ret; + } + + if (parser.GetKind() != TTypeParser::ETypeKind::Primitive) { + return SQL_ERROR; + } + + EPrimitiveType ydbType = parser.GetPrimitiveType(); + + switch (targetType) { + case SQL_C_SLONG: + case SQL_C_LONG: + { + int32_t v = 0; + switch (ydbType) { + case EPrimitiveType::Int16: v = static_cast(parser.GetInt16()); break; + case EPrimitiveType::Uint16: v = static_cast(parser.GetUint16()); break; + case EPrimitiveType::Int8: v = static_cast(parser.GetInt8()); break; + case EPrimitiveType::Uint8: v = static_cast(parser.GetUint8()); break; + case EPrimitiveType::Int32: v = static_cast(parser.GetInt32()); break; + case EPrimitiveType::Uint32: v = static_cast(parser.GetUint32()); break; + case EPrimitiveType::Int64: v = static_cast(parser.GetInt64()); break; + case EPrimitiveType::Uint64: v = static_cast(parser.GetUint64()); break; + case EPrimitiveType::Bool: v = parser.GetBool() ? 1 : 0; break; + default: return SQL_ERROR; + } + if (targetValue) { + *reinterpret_cast(targetValue) = v; + } + if (strLenOrInd) { + *strLenOrInd = sizeof(int32_t); + } + return SQL_SUCCESS; + } + case SQL_C_SBIGINT: + { + SQLBIGINT v = 0; + switch (ydbType) { + case EPrimitiveType::Int64: v = parser.GetInt64(); break; + case EPrimitiveType::Uint64: v = static_cast(parser.GetUint64()); break; + case EPrimitiveType::Int32: v = static_cast(parser.GetInt32()); break; + case EPrimitiveType::Uint32: v = static_cast(parser.GetUint32()); break; + default: return SQL_ERROR; + } + if (targetValue) { + *reinterpret_cast(targetValue) = v; + } + if (strLenOrInd) { + *strLenOrInd = sizeof(SQLBIGINT); + } + return SQL_SUCCESS; + } + case SQL_C_DOUBLE: + { + double v = 0.0; + switch (ydbType) { + case EPrimitiveType::Double: v = parser.GetDouble(); break; + case EPrimitiveType::Float: v = parser.GetFloat(); break; + default: return SQL_ERROR; + } + if (targetValue) { + *reinterpret_cast(targetValue) = v; + } + if (strLenOrInd) { + *strLenOrInd = sizeof(double); + } + return SQL_SUCCESS; + } + case SQL_C_CHAR: + { + std::string str; + switch (ydbType) { + case EPrimitiveType::Utf8: str = parser.GetUtf8(); break; + case EPrimitiveType::String: str = parser.GetString(); break; + case EPrimitiveType::Json: str = parser.GetJson(); break; + case EPrimitiveType::JsonDocument: str = parser.GetJsonDocument(); break; + default: return SQL_ERROR; + } + SQLLEN len = str.size(); + if (targetValue && bufferLength > 0) { + SQLLEN copyLen = std::min(len, bufferLength - 1); + memcpy(targetValue, str.data(), copyLen); + reinterpret_cast(targetValue)[copyLen] = 0; + } + if (strLenOrInd) { + *strLenOrInd = len; + } + return SQL_SUCCESS; + } + case SQL_C_BIT: + { + char v = parser.GetBool() ? 1 : 0; + if (targetValue) { + *reinterpret_cast(targetValue) = v; + } + if (strLenOrInd) { + *strLenOrInd = sizeof(char); + } + return SQL_SUCCESS; + } + default: + return SQL_ERROR; + } +} + +} // namespace NYdb +} // namespace NOdbc diff --git a/odbc/src/utils/convert.h b/odbc/src/utils/convert.h new file mode 100644 index 0000000000..9b8140665e --- /dev/null +++ b/odbc/src/utils/convert.h @@ -0,0 +1,17 @@ +#pragma once + +#include "bindings.h" + +#include + +#include +#include + +namespace NYdb { +namespace NOdbc { + +SQLRETURN ConvertParam(const TBoundParam& param, TParamValueBuilder& builder); +SQLRETURN ConvertColumn(TValueParser& parser, SQLSMALLINT targetType, SQLPOINTER targetValue, SQLLEN bufferLength, SQLLEN* strLenOrInd); + +} // namespace NYdb +} // namespace NOdbc diff --git a/odbc/src/utils/cursor.cpp b/odbc/src/utils/cursor.cpp new file mode 100644 index 0000000000..fbd10588ab --- /dev/null +++ b/odbc/src/utils/cursor.cpp @@ -0,0 +1,119 @@ +#include "cursor.h" + +#include "convert.h" +#include "types.h" + +namespace NYdb { +namespace NOdbc { + +class TExecCursor : public ICursor { +public: + TExecCursor(IBindingFiller* bindingFiller, NQuery::TExecuteQueryIterator iterator) + : BindingFiller_(bindingFiller) + , Iterator_(std::move(iterator)) + {} + + bool Fetch() override { + while (true) { + if (ResultSetParser_) { + if (ResultSetParser_->TryNextRow()) { + BindingFiller_->FillBoundColumns(); + return true; + } + ResultSetParser_.reset(); + } + auto part = Iterator_.ReadNext().ExtractValueSync(); + if (part.EOS()) { + return false; + } + if (!part.IsSuccess()) { + return false; + } + if (part.HasResultSet()) { + ResultSetParser_ = std::make_unique(part.ExtractResultSet()); + } + } + return false; + } + + SQLRETURN GetData(SQLUSMALLINT columnNumber, SQLSMALLINT targetType, + SQLPOINTER targetValue, SQLLEN bufferLength, SQLLEN* strLenOrInd) override { + if (!ResultSetParser_) { + return SQL_NO_DATA; + } + if (columnNumber < 1 || columnNumber > ResultSetParser_->ColumnsCount()) { + return SQL_ERROR; + } + return ConvertColumn(ResultSetParser_->ColumnParser(columnNumber - 1), targetType, targetValue, bufferLength, strLenOrInd); + } + + const std::vector& GetColumnMeta() const override { + return Columns_; + } + +private: + // void GetNextPart() { + // auto part = Iterator_.ReadNext().ExtractValueSync(); + // while (!part.EOS() && part.IsSuccess() && !part.HasResultSet()) { + // part = Iterator_.ReadNext().ExtractValueSync(); + // } + // Part_ = std::move(part); + // } + + IBindingFiller* BindingFiller_; + NQuery::TExecuteQueryIterator Iterator_; + // std::optional Part_; + std::unique_ptr ResultSetParser_; + std::vector Columns_; +}; + +class TVirtualCursor : public ICursor { +public: + TVirtualCursor(IBindingFiller* bindingFiller, const std::vector& columns, const TTable& table) + : BindingFiller_(bindingFiller) + , Columns_(columns) + , Table_(table) + {} + + bool Fetch() override { + Cursor_++; + if (Cursor_ >= static_cast(Table_.size())) { + return false; + } + BindingFiller_->FillBoundColumns(); + return true; + } + + SQLRETURN GetData(SQLUSMALLINT columnNumber, SQLSMALLINT targetType, + SQLPOINTER targetValue, SQLLEN bufferLength, SQLLEN* strLenOrInd) override { + if (Cursor_ >= static_cast(Table_.size())) { + return SQL_NO_DATA; + } + if (Cursor_ < 0 || columnNumber < 1 || columnNumber > Columns_.size()) { + return SQL_ERROR; + } + TValueParser parser{Table_[Cursor_][columnNumber - 1]}; + return ConvertColumn(parser, targetType, targetValue, bufferLength, strLenOrInd); + } + + const std::vector& GetColumnMeta() const override { + return Columns_; + } + +private: + IBindingFiller* BindingFiller_; + std::vector Columns_; + TTable Table_; + int64_t Cursor_ = -1; +}; + +std::unique_ptr CreateExecCursor(IBindingFiller* bindingFiller, NQuery::TExecuteQueryIterator iterator) { + return std::make_unique(bindingFiller, std::move(iterator)); +} + +std::unique_ptr CreateVirtualCursor(IBindingFiller* bindingFiller, const std::vector& columns, const TTable& table) { + return std::make_unique(bindingFiller, columns, table); +} + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/utils/cursor.h b/odbc/src/utils/cursor.h new file mode 100644 index 0000000000..e4b13ed521 --- /dev/null +++ b/odbc/src/utils/cursor.h @@ -0,0 +1,37 @@ +#pragma once + +#include "bindings.h" + +#include + +#include + +#include +#include + +namespace NYdb { +namespace NOdbc { + +struct TColumnMeta { + std::string Name; + SQLSMALLINT SqlType; + SQLULEN Size; + SQLSMALLINT Nullable; +}; + +using TTable = std::vector>; + +class ICursor { +public: + virtual ~ICursor() = default; + virtual bool Fetch() = 0; + virtual SQLRETURN GetData(SQLUSMALLINT columnNumber, SQLSMALLINT targetType, + SQLPOINTER targetValue, SQLLEN bufferLength, SQLLEN* strLenOrInd) = 0; + virtual const std::vector& GetColumnMeta() const = 0; +}; + +std::unique_ptr CreateExecCursor(IBindingFiller* bindingFiller, NYdb::NQuery::TExecuteQueryIterator iterator); +std::unique_ptr CreateVirtualCursor(IBindingFiller* bindingFiller, const std::vector& columns, const TTable& table); + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/utils/types.cpp b/odbc/src/utils/types.cpp new file mode 100644 index 0000000000..ce5ead462c --- /dev/null +++ b/odbc/src/utils/types.cpp @@ -0,0 +1,70 @@ +#include "types.h" + +namespace NYdb { +namespace NOdbc { + +SQLSMALLINT GetTypeId(const TType& type) { + // TODO: implement + return 0; +} + +SQLSMALLINT IsNullable(const TType& type) { + TTypeParser typeParser(type); + if (typeParser.GetKind() == TTypeParser::ETypeKind::Optional || typeParser.GetKind() == TTypeParser::ETypeKind::Null) { + return SQL_NULLABLE; + } + + return SQL_NO_NULLS; +} + +std::optional GetDecimalDigits(const TType& type) { + TTypeParser typeParser(type); + if (typeParser.GetKind() != TTypeParser::ETypeKind::Primitive) { + return std::nullopt; + } + + switch (typeParser.GetPrimitive()) { + case EPrimitiveType::Int64: + return 64; + case EPrimitiveType::Uint64: + return 64; + case EPrimitiveType::Int32: + return 32; + case EPrimitiveType::Uint32: + return 32; + case EPrimitiveType::Int16: + return 16; + case EPrimitiveType::Uint16: + return 16; + case EPrimitiveType::Int8: + return 8; + case EPrimitiveType::Uint8: + return 8; + default: + return std::nullopt; + } +} + +std::optional GetRadix(const TType& type) { + TTypeParser typeParser(type); + if (typeParser.GetKind() != TTypeParser::ETypeKind::Primitive) { + return std::nullopt; + } + + switch (typeParser.GetPrimitive()) { + case EPrimitiveType::Int64: + case EPrimitiveType::Uint64: + case EPrimitiveType::Int32: + case EPrimitiveType::Uint32: + case EPrimitiveType::Int16: + case EPrimitiveType::Uint16: + case EPrimitiveType::Int8: + case EPrimitiveType::Uint8: + return 10; + default: + return std::nullopt; + } +} + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/utils/types.h b/odbc/src/utils/types.h new file mode 100644 index 0000000000..3f48170290 --- /dev/null +++ b/odbc/src/utils/types.h @@ -0,0 +1,17 @@ +#pragma once + +#include + +#include + +namespace NYdb { +namespace NOdbc { + +SQLSMALLINT GetTypeId(const TType& type); +SQLSMALLINT IsNullable(const TType& type); + +std::optional GetDecimalDigits(const TType& type); +std::optional GetRadix(const TType& type); + +} // namespace NOdbc +} // namespace NYdb diff --git a/odbc/src/utils/util.cpp b/odbc/src/utils/util.cpp new file mode 100644 index 0000000000..9097ce80db --- /dev/null +++ b/odbc/src/utils/util.cpp @@ -0,0 +1,12 @@ +#include "util.h" + +namespace NYdb::NOdbc { + +std::string GetString(SQLCHAR* str, SQLSMALLINT length) { + if (length == SQL_NTS) { + return std::string(reinterpret_cast(str)); + } + return std::string(reinterpret_cast(str), length); +} + +} // namespace NYdb::NOdbc diff --git a/odbc/src/utils/util.h b/odbc/src/utils/util.h new file mode 100644 index 0000000000..b17fe2c235 --- /dev/null +++ b/odbc/src/utils/util.h @@ -0,0 +1,12 @@ +#pragma once + +#include +#include + +#include + +namespace NYdb::NOdbc { + +std::string GetString(SQLCHAR* str, SQLSMALLINT length); + +} // namespace NYdb::NOdbc diff --git a/odbc/tests/CMakeLists.txt b/odbc/tests/CMakeLists.txt new file mode 100644 index 0000000000..729c6ee077 --- /dev/null +++ b/odbc/tests/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(integration) +add_subdirectory(unit) diff --git a/odbc/tests/integration/CMakeLists.txt b/odbc/tests/integration/CMakeLists.txt new file mode 100644 index 0000000000..e1aad9d391 --- /dev/null +++ b/odbc/tests/integration/CMakeLists.txt @@ -0,0 +1,4 @@ +add_odbc_test(NAME odbc-basic_it + SOURCES + basic_it.cpp +) diff --git a/odbc/tests/integration/basic_it.cpp b/odbc/tests/integration/basic_it.cpp new file mode 100644 index 0000000000..b4c7078ac4 --- /dev/null +++ b/odbc/tests/integration/basic_it.cpp @@ -0,0 +1,123 @@ +#include + +#include +#include + +#include + + +#define CHECK_ODBC_OK(rc, handle, type) \ + ASSERT_TRUE((rc) == SQL_SUCCESS || (rc) == SQL_SUCCESS_WITH_INFO) << GetOdbcError(handle, type) + +std::string GetOdbcError(SQLHANDLE handle, SQLSMALLINT type) { + SQLCHAR sqlState[6], message[256]; + SQLINTEGER nativeError; + SQLSMALLINT textLength; + SQLRETURN rc = SQLGetDiagRec(type, handle, 1, sqlState, &nativeError, message, sizeof(message), &textLength); + if (rc == SQL_SUCCESS || rc == SQL_SUCCESS_WITH_INFO) { + return std::string((char*)sqlState) + ": " + (char*)message; + } + return "Unknown ODBC error"; +} + +const char* kConnStr = "Driver=" ODBC_DRIVER_PATH ";Endpoint=localhost:2136;Database=/local;"; + +TEST(OdbcBasic, SimpleQuery) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env), SQL_SUCCESS); + ASSERT_EQ(SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0), SQL_SUCCESS); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc), SQL_SUCCESS); + CHECK_ODBC_OK(SQLDriverConnect(dbc, nullptr, (SQLCHAR*)kConnStr, SQL_NTS, nullptr, 0, nullptr, SQL_DRIVER_COMPLETE), dbc, SQL_HANDLE_DBC); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + // Simple query + CHECK_ODBC_OK(SQLExecDirect(stmt, (SQLCHAR*)"SELECT 1 AS one, 'abc' AS str", SQL_NTS), stmt, SQL_HANDLE_STMT); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + + SQLINTEGER ival = 0; + char sval[16] = {0}; + SQLLEN ival_ind = 0, sval_ind = 0; + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_LONG, &ival, 0, &ival_ind), SQL_SUCCESS); + ASSERT_EQ(SQLGetData(stmt, 2, SQL_C_CHAR, sval, sizeof(sval), &sval_ind), SQL_SUCCESS); + ASSERT_EQ(ival, 1); + ASSERT_STREQ(sval, "abc"); + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(OdbcBasic, ParameterizedQuery) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env), SQL_SUCCESS); + ASSERT_EQ(SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0), SQL_SUCCESS); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc), SQL_SUCCESS); + CHECK_ODBC_OK(SQLDriverConnect(dbc, nullptr, (SQLCHAR*)kConnStr, SQL_NTS, nullptr, 0, nullptr, SQL_DRIVER_COMPLETE), dbc, SQL_HANDLE_DBC); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + SQLCHAR query[] = R"( + DECLARE $p1 AS Int32?; + SELECT $p1 + 10 AS res; + )"; + + // Parameterized query + CHECK_ODBC_OK(SQLPrepare(stmt, query, SQL_NTS), stmt, SQL_HANDLE_STMT); + SQLINTEGER param = 5; + CHECK_ODBC_OK(SQLBindParameter(stmt, 1, SQL_PARAM_INPUT, SQL_C_LONG, SQL_INTEGER, 0, 0, ¶m, 0, nullptr), stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLExecute(stmt), stmt, SQL_HANDLE_STMT); + + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + + SQLINTEGER res = 0; + SQLLEN res_ind = 0; + ASSERT_EQ(SQLGetData(stmt, 1, SQL_C_LONG, &res, 0, &res_ind), SQL_SUCCESS); + ASSERT_EQ(res, 15); + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} + +TEST(OdbcBasic, ColumnBinding) { + SQLHENV env; + SQLHDBC dbc; + SQLHSTMT stmt; + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_ENV, SQL_NULL_HANDLE, &env), SQL_SUCCESS); + ASSERT_EQ(SQLSetEnvAttr(env, SQL_ATTR_ODBC_VERSION, (void*)SQL_OV_ODBC3, 0), SQL_SUCCESS); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_DBC, env, &dbc), SQL_SUCCESS); + CHECK_ODBC_OK(SQLDriverConnect(dbc, nullptr, (SQLCHAR*)kConnStr, SQL_NTS, nullptr, 0, nullptr, SQL_DRIVER_COMPLETE), dbc, SQL_HANDLE_DBC); + ASSERT_EQ(SQLAllocHandle(SQL_HANDLE_STMT, dbc, &stmt), SQL_SUCCESS); + + SQLCHAR query_ddl[] = R"( + DROP TABLE IF EXISTS test_bind; + CREATE TABLE test_bind (id Int32, name Text, PRIMARY KEY (id)); + )"; + + SQLCHAR query[] = R"( + UPSERT INTO test_bind (id, name) VALUES (1, 'foo'), (2, 'bar'); + SELECT id, name FROM test_bind ORDER BY id; + )"; + + CHECK_ODBC_OK(SQLExecDirect(stmt, query_ddl, SQL_NTS), stmt, SQL_HANDLE_STMT); + CHECK_ODBC_OK(SQLExecDirect(stmt, query, SQL_NTS), stmt, SQL_HANDLE_STMT); + + SQLINTEGER id = 0; + char name[16] = {0}; + SQLLEN id_ind = 0, name_ind = 0; + ASSERT_EQ(SQLBindCol(stmt, 1, SQL_C_LONG, &id, 0, &id_ind), SQL_SUCCESS); + ASSERT_EQ(SQLBindCol(stmt, 2, SQL_C_CHAR, name, sizeof(name), &name_ind), SQL_SUCCESS); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + ASSERT_EQ(id, 1); + ASSERT_STREQ(name, "foo"); + ASSERT_EQ(SQLFetch(stmt), SQL_SUCCESS); + ASSERT_EQ(id, 2); + ASSERT_STREQ(name, "bar"); + SQLFreeHandle(SQL_HANDLE_STMT, stmt); + SQLDisconnect(dbc); + SQLFreeHandle(SQL_HANDLE_DBC, dbc); + SQLFreeHandle(SQL_HANDLE_ENV, env); +} diff --git a/odbc/tests/unit/CMakeLists.txt b/odbc/tests/unit/CMakeLists.txt new file mode 100644 index 0000000000..d1eac19961 --- /dev/null +++ b/odbc/tests/unit/CMakeLists.txt @@ -0,0 +1,10 @@ +add_ydb_test(NAME odbc-convert_ut GTEST + SOURCES + convert_ut.cpp + LINK_LIBRARIES + yutil + api-protos + ydb-odbc + LABELS + unit +) diff --git a/odbc/tests/unit/convert_ut.cpp b/odbc/tests/unit/convert_ut.cpp new file mode 100644 index 0000000000..f4bad34a36 --- /dev/null +++ b/odbc/tests/unit/convert_ut.cpp @@ -0,0 +1,130 @@ +#include +#undef BOOL + +#include + +#include + +#include + +#include + +using namespace NYdb::NOdbc; +using namespace NYdb; + +template +void CheckProto(const T& value, const std::string& expected) { + std::string protoStr; + google::protobuf::TextFormat::PrintToString(value, &protoStr); + ASSERT_EQ(protoStr, expected); +} + +TEST(OdbcConvert, Int64ToYdb) { + SQLBIGINT v = 42; + TBoundParam param{ + 1, // ParamNumber + SQL_PARAM_INPUT, // InputOutputType + SQL_C_SBIGINT, // ValueType + SQL_BIGINT, // ParameterType + 0, 0, // ColumnSize, DecimalDigits + &v, // ParameterValuePtr + sizeof(v), // BufferLength + nullptr // StrLenOrIndPtr + }; + + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); + CheckProto(value->GetType().GetProto(), "optional_type {\n item {\n type_id: INT64\n }\n}\n"); + CheckProto(value->GetProto(), "int64_value: 42\n"); +} + +TEST(OdbcConvert, Uint64ToYdb) { + SQLUBIGINT v = 123; + TBoundParam param{ + 1, SQL_PARAM_INPUT, SQL_C_UBIGINT, SQL_BIGINT, 0, 0, &v, sizeof(v), nullptr + }; + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); + CheckProto(value->GetType().GetProto(), "optional_type {\n item {\n type_id: UINT64\n }\n}\n"); + CheckProto(value->GetProto(), "uint64_value: 123\n"); +} + +TEST(OdbcConvert, DoubleToYdb) { + SQLDOUBLE v = 3.14; + TBoundParam param{ + 1, SQL_PARAM_INPUT, SQL_C_DOUBLE, SQL_DOUBLE, 0, 0, &v, sizeof(v), nullptr + }; + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); + CheckProto(value->GetType().GetProto(), "optional_type {\n item {\n type_id: DOUBLE\n }\n}\n"); + CheckProto(value->GetProto(), "double_value: 3.14\n"); +} + +TEST(OdbcConvert, StringToYdbUtf8) { + const char* str = "hello"; + SQLLEN len = 5; + TBoundParam param{ + 1, SQL_PARAM_INPUT, SQL_C_CHAR, SQL_VARCHAR, 0, 0, (SQLPOINTER)str, len, nullptr + }; + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); + CheckProto(value->GetType().GetProto(), "optional_type {\n item {\n type_id: UTF8\n }\n}\n"); + CheckProto(value->GetProto(), "text_value: \"hello\"\n"); +} + +TEST(OdbcConvert, StringToYdbBinary) { + const char* str = "bin\x01\x02"; + SQLLEN len = 5; + TBoundParam param{ + 1, SQL_PARAM_INPUT, SQL_C_BINARY, SQL_BINARY, 0, 0, (SQLPOINTER)str, len, nullptr + }; + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); + CheckProto(value->GetType().GetProto(), "optional_type {\n item {\n type_id: STRING\n }\n}\n"); + CheckProto(value->GetProto(), "bytes_value: \"bin\\001\\002\"\n"); +} + +TEST(OdbcConvert, Int64NullToYdb) { + SQLBIGINT v = 42; + SQLLEN nullInd = SQL_NULL_DATA; + TBoundParam param{ + 1, SQL_PARAM_INPUT, SQL_C_SBIGINT, SQL_BIGINT, 0, 0, &v, sizeof(v), &nullInd + }; + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); + CheckProto(value->GetType().GetProto(), "optional_type {\n item {\n type_id: INT64\n }\n}\n"); + CheckProto(value->GetProto(), "null_flag_value: NULL_VALUE\n"); +} + +TEST(OdbcConvert, StringNullToYdb) { + const char* str = "test"; + SQLLEN nullInd = SQL_NULL_DATA; + TBoundParam param{ + 1, SQL_PARAM_INPUT, SQL_C_CHAR, SQL_VARCHAR, 0, 0, (SQLPOINTER)str, 4, &nullInd + }; + TParamsBuilder paramsBuilder; + ConvertParam(param, paramsBuilder.AddParam("$p1")); + auto params = paramsBuilder.Build(); + auto value = params.GetValue("$p1"); + ASSERT_TRUE(value); + CheckProto(value->GetType().GetProto(), "optional_type {\n item {\n type_id: UTF8\n }\n}\n"); + CheckProto(value->GetProto(), "null_flag_value: NULL_VALUE\n"); +}