diff --git a/src/litgen/internal/adapt_function_params/apply_all_adapters.py b/src/litgen/internal/adapt_function_params/apply_all_adapters.py index d5466062..08932883 100644 --- a/src/litgen/internal/adapt_function_params/apply_all_adapters.py +++ b/src/litgen/internal/adapt_function_params/apply_all_adapters.py @@ -161,7 +161,12 @@ def _make_adapted_lambda_code_end(adapted_function: AdaptedFunction, lambda_adap # Fill auto_r_equal_or_void _fn_return_type = adapted_function.cpp_adapted_function.str_full_return_type() - auto_r_equal_or_void = "auto lambda_result = " if _fn_return_type != "void" else "" + _return_referenced = False + + if hasattr(adapted_function.cpp_element(), "return_type"): + _return_referenced = '&' in adapted_function.cpp_element().return_type.modifiers + + auto_r_equal_or_void = ("auto" + ("&" if _return_referenced else "") + " lambda_result = ") if _fn_return_type != "void" else "" # Fill function_or_lambda_to_call if adapted_function.lambda_to_call is not None: diff --git a/src/litgen/tests/internal/adapt_function_params/adapt_function_test.py b/src/litgen/tests/internal/adapt_function_params/adapt_function_test.py index fdfbedb0..be9669ba 100644 --- a/src/litgen/tests/internal/adapt_function_params/adapt_function_test.py +++ b/src/litgen/tests/internal/adapt_function_params/adapt_function_test.py @@ -2,10 +2,14 @@ from dataclasses import dataclass from typing import Optional +from codemanip import code_utils + import srcmlcpp from srcmlcpp import srcmlcpp_main from srcmlcpp.cpp_types import CppFunctionDecl +import litgen +from litgen import litgen_generator @dataclass class AdaptedFunction2(CppFunctionDecl): @@ -31,3 +35,44 @@ def test_inherit(): cpp_function = srcmlcpp_main.code_first_child_of_type(options, CppFunctionDecl, code) assert isinstance(cpp_function, CppFunctionDecl) _ = AdaptedFunction2(cpp_function, "Foo") + + +def test_lambda_correctly_returns_reference(): + """ + Test for check the "auto& lambda_result" has a reference + """ + options = litgen.LitgenOptions() + code = """ + class MyClass { + public: + MyClass& setArray(const uint8_t arr[20]) { + memcpy(_signature, arr, sizeof(arr)); + return *this; + } + private: + uint8_t _arr[20]; + }; + """ + + generated_code = litgen.generate_code(options, code) + code_utils.assert_are_codes_equal( + generated_code.pydef_code, + """ + auto pyClassMyClass = + py::class_ + (m, "MyClass", "") + .def(py::init<>()) // implicit default constructor + .def("set_array", + [](MyClass & self, const std::array& arr) -> MyClass & + { + auto setArray_adapt_fixed_size_c_arrays = [&self](const std::array& arr) -> MyClass & + { + auto& lambda_result = self.setArray(arr.data()); + return lambda_result; + }; + + return setArray_adapt_fixed_size_c_arrays(arr); + }, py::arg("arr")) + ; + """, + ) \ No newline at end of file