|
| 1 | +//***************************************************************************** |
| 2 | +// Copyright 2025 Intel Corporation |
| 3 | +// |
| 4 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +// you may not use this file except in compliance with the License. |
| 6 | +// You may obtain a copy of the License at |
| 7 | +// |
| 8 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +// |
| 10 | +// Unless required by applicable law or agreed to in writing, software |
| 11 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +// See the License for the specific language governing permissions and |
| 14 | +// limitations under the License. |
| 15 | +//***************************************************************************** |
| 16 | +#include <fstream> |
| 17 | + |
| 18 | +#pragma warning(push) |
| 19 | +#pragma warning(disable : 4005 4309 6001 6385 6386 6326 6011 6246 4456)
| 20 | +#pragma GCC diagnostic push |
| 21 | +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" |
| 22 | +#include "mediapipe/framework/calculator_framework.h" |
| 23 | +#include "mediapipe/framework/port/canonical_errors.h" |
| 24 | +#pragma GCC diagnostic pop |
| 25 | +#pragma warning(pop) |
| 26 | + |
| 27 | +#include "../http_payload.hpp" |
| 28 | +#include "../logging.hpp" |
| 29 | + |
| 30 | +#pragma warning(push) |
| 31 | +#pragma warning(disable : 6001 4324 6385 6386) |
| 32 | +#include "absl/strings/escaping.h" |
| 33 | +#include "absl/strings/str_cat.h" |
| 34 | +#pragma warning(pop) |
| 35 | + |
| 36 | +#define DR_WAV_IMPLEMENTATION |
| 37 | +#include "dr_wav.h" |
| 38 | +#include "openvino/genai/whisper_pipeline.hpp" |
| 39 | +#include "openvino/genai/speech_generation/text2speech_pipeline.hpp" |
| 40 | + |
| 41 | +#ifdef _WIN32 |
| 42 | +# include <fcntl.h> |
| 43 | +# include <io.h> |
| 44 | +#endif |
| 45 | + |
| 46 | +using namespace ovms; |
| 47 | + |
| 48 | +namespace mediapipe { |
| 49 | + |
| 50 | +// using SpeechPipelinesMap = std::unordered_map<std::string, std::shared_ptr<SpeechPipelines>>; |
| 51 | + |
| 52 | + |
| 53 | +const std::string SPEECH_SESSION_SIDE_PACKET_TAG = "SPEECH_NODE_RESOURCES"; |
| 54 | + |
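|  | +// Whisper expects 16 kHz input; read_wav() below rejects other sample rates and downmixes stereo to mono.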
| 55 | +#define COMMON_SAMPLE_RATE 16000 |
| 56 | + |
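|  | +// Returns true when the buffer starts with a RIFF/WAVE header whose declared chunk size matches the buffer length.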
| 57 | +bool is_wav_buffer(const std::string& buf) {
| 58 | + // RIFF ref: https://en.wikipedia.org/wiki/Resource_Interchange_File_Format |
| 59 | + // WAV ref: https://www.mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/WAVE.html |
| 60 | + if (buf.size() < 12 || buf.substr(0, 4) != "RIFF" || buf.substr(8, 4) != "WAVE") { |
| 61 | + return false; |
| 62 | + } |
| 63 | + |
| 64 | + uint32_t chunk_size = *reinterpret_cast<const uint32_t*>(buf.data() + 4); |
| 65 | + if (chunk_size + 8 != buf.size()) { |
| 66 | + return false; |
| 67 | + } |
| 68 | + |
| 69 | + return true; |
| 70 | +} |
| 71 | + |
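|  | +// Decodes an in-memory WAV buffer into the mono float PCM (normalized to [-1, 1]) expected by WhisperPipeline.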
| 72 | +ov::genai::RawSpeechInput read_wav(const std::string_view& wav_data) {
| 73 | +    drwav wav;
| 74 | +
| 75 | +    // The stdin/file/ffmpeg input paths from the original whisper sample are not needed here:
| 76 | +    // audio always arrives as an in-memory buffer extracted from the multipart request.
| 77 | +    OPENVINO_ASSERT(drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr), "Failed to parse WAV data from request buffer");
| 109 | + if (wav.channels != 1 && wav.channels != 2) { |
| 110 | + drwav_uninit(&wav); |
| 111 | + throw std::runtime_error("WAV file must be mono or stereo"); |
| 112 | + } |
| 113 | + |
| 114 | + if (wav.sampleRate != COMMON_SAMPLE_RATE) { |
| 115 | + drwav_uninit(&wav); |
| 116 | +        throw std::runtime_error("WAV file must be " + std::to_string(COMMON_SAMPLE_RATE / 1000) + " kHz");
| 117 | + } |
| 118 | + |
| 119 | +    // drwav already knows the exact frame count; the buffer-size heuristic from the whisper
| 120 | +    // sample over-counts here because wav_data still includes the RIFF header bytes.
| 121 | +    const uint64_t n = wav.totalPCMFrameCount;
| 122 | +    const uint16_t channels = wav.channels;  // copied before drwav_uninit() tears down the decoder
| 123 | +    std::vector<int16_t> pcm16;
| 124 | +    pcm16.resize(n * channels);
| 125 | +    drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
| 126 | +    drwav_uninit(&wav);
| 127 | +    // convert to mono, float
| 128 | + std::vector<float> pcmf32; |
| 129 | + pcmf32.resize(n); |
| 130 | +    if (channels == 1) {
| 131 | + for (uint64_t i = 0; i < n; i++) { |
| 132 | + pcmf32[i] = float(pcm16[i]) / 32768.0f; |
| 133 | + } |
| 134 | + } else { |
| 135 | + for (uint64_t i = 0; i < n; i++) { |
| 136 | + pcmf32[i] = float(pcm16[2 * i] + pcm16[2 * i + 1]) / 65536.0f; |
| 137 | + } |
| 138 | + } |
| 139 | + |
| 140 | + return pcmf32; |
| 141 | +} |
| 142 | + |
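|  | +// Looks up a multipart file field by name; returns std::nullopt when the field is absent or empty.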
| 143 | +std::variant<absl::Status, std::optional<std::string_view>> getFileFromPayload(const ovms::MultiPartParser& parser, const std::string& keyName) { |
| 144 | + std::string_view value = parser.getFileContentByFieldName(keyName); |
| 145 | + if (value.empty()) { |
| 146 | + return std::nullopt; |
| 147 | + } |
| 148 | + return value; |
| 149 | +} |
| 150 | + |
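|  | +// Helpers for unwrapping std::variant<absl::Status, T> results: propagate the error status or bind the value.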
| 151 | +#define SET_OR_RETURN(TYPE, NAME, RHS) \ |
| 152 | + auto NAME##_OPT = RHS; \ |
| 153 | + RETURN_IF_HOLDS_STATUS(NAME##_OPT) \ |
| 154 | + auto NAME = std::get<TYPE>(NAME##_OPT); |
| 155 | + |
| 156 | +#define RETURN_IF_HOLDS_STATUS(NAME) \ |
| 157 | + if (std::holds_alternative<absl::Status>(NAME)) { \ |
| 158 | + return std::get<absl::Status>(NAME); \ |
| 159 | + } |
| 160 | + |
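|  | +// Calculator backing the audio endpoints: /v3/audio/transcriptions (speech-to-text via
|  | +// ov::genai::WhisperPipeline) and /v3/audio/speech (text-to-speech via ov::genai::Text2SpeechPipeline).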
| 161 | +class SpeechCalculator : public CalculatorBase { |
| 162 | + static const std::string INPUT_TAG_NAME; |
| 163 | + static const std::string OUTPUT_TAG_NAME; |
| 164 | + |
| 165 | +public: |
| 166 | + static absl::Status GetContract(CalculatorContract* cc) { |
| 167 | + RET_CHECK(!cc->Inputs().GetTags().empty()); |
| 168 | + RET_CHECK(!cc->Outputs().GetTags().empty()); |
| 169 | + cc->Inputs().Tag(INPUT_TAG_NAME).Set<ovms::HttpPayload>(); |
| 170 | + // cc->InputSidePackets().Tag(IMAGE_GEN_SESSION_SIDE_PACKET_TAG).Set<SpeechPipelinesMap>(); // TODO: template? |
| 171 | + cc->Outputs().Tag(OUTPUT_TAG_NAME).Set<std::string>(); |
| 172 | + return absl::OkStatus(); |
| 173 | + } |
| 174 | + |
| 175 | + absl::Status Close(CalculatorContext* cc) final { |
| 176 | +        SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "SpeechCalculator [Node: {}] Close", cc->NodeName());
| 177 | + return absl::OkStatus(); |
| 178 | + } |
| 179 | + |
| 180 | + absl::Status Open(CalculatorContext* cc) final { |
| 181 | + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "SpeechCalculator [Node: {}] Open start", cc->NodeName()); |
| 182 | + return absl::OkStatus(); |
| 183 | + } |
| 184 | + |
| 185 | + absl::Status Process(CalculatorContext* cc) final { |
| 186 | + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "SpeechCalculator [Node: {}] Process start", cc->NodeName()); |
| 187 | + |
| 188 | + // ImageGenerationPipelinesMap pipelinesMap = cc->InputSidePackets().Tag(SPEECH_SESSION_SIDE_PACKET_TAG).Get<SpeechPipelinesMap>(); |
| 189 | + // auto it = pipelinesMap.find(cc->NodeName()); |
| 190 | + // RET_CHECK(it != pipelinesMap.end()) << "Could not find initialized Speech node named: " << cc->NodeName(); |
| 191 | + // auto pipe = it->second; |
| 192 | + |
| 193 | + auto payload = cc->Inputs().Tag(INPUT_TAG_NAME).Get<ovms::HttpPayload>(); |
| 194 | + |
| 196 | +        std::unique_ptr<std::string> output;  // response body: transcription text or WAV bytes
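|  | +        // Route by request URI: transcription consumes a multipart WAV upload, speech synthesis a JSON body.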
| 197 | + if (absl::StartsWith(payload.uri, "/v3/audio/transcriptions")) { |
| 198 | + if (payload.multipartParser->hasParseError()) |
| 199 | + return absl::InvalidArgumentError("Failed to parse multipart data"); |
| 200 | + |
| 201 | + SET_OR_RETURN(std::optional<std::string_view>, file, getFileFromPayload(*payload.multipartParser, "file")); |
| 202 | +            if (!file.has_value()) {
| 203 | +                return absl::InvalidArgumentError("Required 'file' field is missing or empty in the multipart body");
| 204 | +            }
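|  | +            // Model path and device are hardcoded for now; the commented-out side-packet code above
|  | +            // suggests pipelines are eventually meant to be preloaded per node rather than built per request.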
| 205 | + ov::genai::WhisperPipeline pipeline("/models/audio/transcriptions", "CPU"); |
| 206 | + ov::genai::WhisperGenerationConfig config = pipeline.get_generation_config(); |
| 207 | + // 'task' and 'language' parameters are supported for multilingual models only |
| 208 | + config.language = "<|en|>"; // can switch to <|zh|> for Chinese language |
| 209 | + config.task = "transcribe"; |
| 210 | + config.return_timestamps = true; |
| 211 | + ov::genai::RawSpeechInput raw_speech = read_wav(file.value()); |
| 212 | + output = std::make_unique<std::string>(pipeline.generate(raw_speech)); |
| 213 | +        } else if (absl::StartsWith(payload.uri, "/v3/audio/speech")) {
| 214 | + if (payload.parsedJson->HasParseError()) |
| 215 | + return absl::InvalidArgumentError("Failed to parse JSON"); |
| 216 | + |
| 217 | + if (!payload.parsedJson->IsObject()) { |
| 218 | + return absl::InvalidArgumentError("JSON body must be an object"); |
| 219 | + } |
| 220 | + auto inputIt = payload.parsedJson->FindMember("input"); |
| 221 | + if (inputIt == payload.parsedJson->MemberEnd()) { |
| 222 | + return absl::InvalidArgumentError("input field is missing in JSON body"); |
| 223 | + } |
| 224 | + if (!inputIt->value.IsString()) { |
| 225 | + return absl::InvalidArgumentError("input field is not a string"); |
| 226 | + } |
| 227 | + ov::genai::Text2SpeechPipeline pipeline("/models/audio/speech", "CPU"); |
| 229 | + auto gen_speech = pipeline.generate(inputIt->value.GetString()); |
| 230 | + drwav_data_format format; |
| 231 | + format.container = drwav_container_riff; |
| 232 | + format.format = DR_WAVE_FORMAT_IEEE_FLOAT; |
| 233 | + format.channels = 1; |
| 234 | +            format.sampleRate = COMMON_SAMPLE_RATE;  // assumes the TTS model always outputs 16 kHz audio
| 235 | + format.bitsPerSample = gen_speech.speeches[0].get_element_type().bitwidth(); |
| 236 | + |
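|  | +            // Serialize the generated waveform into an in-memory WAV container so the raw bytes
|  | +            // can be returned directly as the HTTP response body.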
| 237 | + drwav wav; |
| 238 | + void* ppData; |
| 239 | + size_t pDataSize; |
| 240 | + OPENVINO_ASSERT(drwav_init_memory_write(&wav, &ppData, &pDataSize, &format, nullptr), |
| 241 | + "Failed to initialize WAV writer"); |
| 242 | + auto waveform_size = gen_speech.speeches[0].get_size(); |
| 243 | +            size_t total_frames = waveform_size / format.channels;  // drwav expects a frame count (samples per channel)
| 244 | + auto waveform_ptr = gen_speech.speeches[0].data<const float>(); |
| 245 | + |
| 246 | +            drwav_uint64 frames_written = drwav_write_pcm_frames(&wav, total_frames, waveform_ptr);
| 247 | +            OPENVINO_ASSERT(frames_written == total_frames, "Failed to write all PCM frames");
| 248 | +
| 249 | +            // drwav_uninit() patches the RIFF/data chunk sizes, so finalize before copying the buffer out.
| 250 | +            drwav_uninit(&wav);
| 251 | +            SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Generated speech waveform with {} samples", waveform_size);
| 252 | +            output = std::make_unique<std::string>(reinterpret_cast<char*>(ppData), pDataSize);
| 253 | +            drwav_free(ppData, nullptr);  // free the buffer allocated by drwav_init_memory_write()
| 254 | +        } else {
| 255 | + return absl::InvalidArgumentError(absl::StrCat("Unsupported URI: ", payload.uri)); |
| 256 | + } |
| 257 | + |
| 261 | + cc->Outputs().Tag(OUTPUT_TAG_NAME).Add(output.release(), cc->InputTimestamp()); |
| 262 | + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "SpeechCalculator [Node: {}] Process end", cc->NodeName()); |
| 263 | + |
| 264 | + return absl::OkStatus(); |
| 265 | + } |
| 266 | +}; |
| 267 | + |
| 268 | +const std::string SpeechCalculator::INPUT_TAG_NAME{"HTTP_REQUEST_PAYLOAD"}; |
| 269 | +const std::string SpeechCalculator::OUTPUT_TAG_NAME{"HTTP_RESPONSE_PAYLOAD"}; |
| 270 | + |
| 271 | +REGISTER_CALCULATOR(SpeechCalculator); |
| 272 | + |
| 273 | +} // namespace mediapipe |