Skip to content

Commit 29ddead

Browse files
authored
Combine local and online LLM nodes into a single `LLM` node, support multiple named inputs, and add `batch_flags` support to Warp (#511)
1 parent 19293d1 commit 29ddead

File tree

3 files changed

+69
-9
lines changed

3 files changed

+69
-9
lines changed

lazyllm/engine/engine.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .node_meta_hook import NodeMetaHook
88
import inspect
99
import functools
10+
from itertools import repeat
1011
import copy
1112
from abc import ABC, abstractclassmethod
1213
from enum import Enum
@@ -346,8 +347,16 @@ def make_diverter(nodes: List[dict]):
346347

347348

348349
@NodeConstructor.register('Warp', subitems=['nodes', 'resources'])
def make_warp(nodes: List[dict], edges: Optional[List[dict]] = None, resources: Optional[List[dict]] = None,
              batch_flags: Optional[List[int]] = None):
    """Build a Warp node from a sub-graph description.

    Args:
        nodes: node descriptions of the inner sub-graph.
        edges: edge descriptions of the inner sub-graph (defaults to none).
        resources: resource descriptions of the inner sub-graph (defaults to none).
        batch_flags: one truthy/falsy flag per positional input. Truthy means the
            input is a batch to be warped element-wise; falsy means it is a scalar
            broadcast (via ``itertools.repeat``) across the batch. When fewer than
            two flags are given, inputs are passed through unchanged.

    Returns:
        A ``lazyllm.warp`` flow, optionally prefixed with a pipeline stage that
        regroups the mixed batch/scalar inputs into per-element packages.
    """
    # Use None sentinels instead of mutable list defaults to avoid shared state
    # across calls.
    wp = lazyllm.warp(make_graph(nodes, edges or [], resources or [], enable_server=False))
    if batch_flags and len(batch_flags) > 1:
        # At least one input must be a real batch; otherwise every argument would
        # be an endless `repeat` and zip(*args) below would never terminate.
        assert any(batch_flags), 'batch_flags must mark at least one input as a batch'

        def transform(*args):
            # Broadcast scalar inputs so they pair with every batch element.
            args = [a if b else repeat(a) for a, b in zip(args, batch_flags)]
            # Regroup column-wise: one package of inputs per batch element.
            args = [lazyllm.package(a) for a in zip(*args)]
            return args
        wp = lazyllm.pipeline(transform, wp)
    return wp
351360

352361

353362
@NodeConstructor.register('Loop', subitems=['nodes', 'resources'])
@@ -723,6 +732,30 @@ def make_online_llm(source: str, base_model: Optional[str] = None, prompt: Optio
723732
api_key=api_key, secret_key=secret_key).prompt(prompt, history=history)
724733

725734

735+
class LLM(lazyllm.ModuleBase):
    """Adapter that lets a wrapped LLM module accept several positional inputs.

    When more than one key is configured, the positional arguments are zipped
    with the keys into a single dict payload for the underlying module;
    otherwise the single positional argument is forwarded untouched.
    """

    def __init__(self, m: lazyllm.ModuleBase, keys: Optional[List[str]] = None):
        super().__init__()
        self._m = m
        self._keys = keys

    def forward(self, *args, **kw):
        keys = self._keys
        if keys and len(keys) > 1:
            # Multi-input mode: each positional arg must match a configured key.
            assert len(args) == len(keys)
            payload = dict(zip(keys, args))
            return self._m(payload, **kw)
        # Single-input mode: pass the lone argument straight through.
        assert len(args) == 1
        return self._m(*args, **kw)
748+
749+
750+
@NodeConstructor.register('LLM')
def make_llm(kw: dict):
    """Build an LLM node, dispatching to a local or online backend.

    Pops ``type`` ('local' or 'online') and optional ``keys`` from *kw* (the
    dict is mutated); the remaining entries are forwarded to the backend
    constructor, and the result is wrapped in :class:`LLM` for multi-input
    support.
    """
    # Renamed from `type` to avoid shadowing the builtin; the error message is
    # kept identical for callers that match on it.
    llm_type: str = kw.pop('type')
    keys: Optional[List[str]] = kw.pop('keys', None)
    assert llm_type in ('local', 'online'), f'Invalid type {llm_type} given'
    if llm_type == 'local': return LLM(make_local_llm(**kw), keys)
    elif llm_type == 'online': return LLM(make_online_llm(**kw), keys)
757+
758+
726759
class STT(lazyllm.Module):
727760
def __init__(self, base_model: Union[str, lazyllm.TrainableModule]):
728761
super().__init__()

tests/basic_tests/test_engine.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,32 @@ def test_engine_warp(self):
289289
assert engine.run(gid, 2, 3, 4, 5) == (4, 9, 16, 25)
290290
assert "prompt_tokens" in self.get_last_report()
291291

292+
    def test_engine_warp_transform(self):
        # Inner sub-graph: a single Code node summing three ints per element.
        nodes = [dict(id='1', kind='Code', name='code', args=dict(
            code='def sum(x: int, y: int, z: int): return x + y + z'))]
        edges = [dict(iid='__start__', oid='1'), dict(iid='1', oid='__end__')]

        # Wrap the sub-graph in a Warp node. batch_flags marks inputs 1 and 3
        # as batches and input 2 as a scalar broadcast to every element.
        nodes = [
            dict(
                id="2",
                kind="Warp",
                name="warp",
                args=dict(
                    nodes=nodes,
                    edges=edges,
                    batch_flags=[True, False, True],
                    _lazyllm_enable_report=True,
                ),
            )
        ]
        edges = [dict(iid='__start__', oid='2'), dict(iid='2', oid='__end__')]

        engine = LightEngine()
        engine.set_report_url(self.report_url)
        gid = engine.start(nodes, edges)
        # Element-wise sums: (2+1+1, 3+1+2, 4+1+3, 5+1+1) with the scalar 1
        # broadcast across the two batch inputs.
        assert engine.run(gid, [2, 3, 4, 5], 1, [1, 2, 3, 1]) == (4, 6, 8, 7)
        # Reporting was enabled via _lazyllm_enable_report, so usage stats
        # should have been posted to the report URL.
        assert "prompt_tokens" in self.get_last_report()
317+
292318
def test_engine_formatter(self):
293319
nodes = [dict(id='1', kind='Formatter', name='f1', args=dict(ftype='python', rule='[:]'))]
294320
edges = [dict(iid='__start__', oid='1'), dict(iid='1', oid='__end__')]

tests/charge_tests/test_engine.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -232,20 +232,21 @@ def test_stream_and_hostory(self):
232232
builtin_history = [['水的沸点是多少?', '您好,我的答案是:水的沸点在标准大气压下是100摄氏度。'],
233233
['世界上最大的动物是什么?', '您好,我的答案是:蓝鲸是世界上最大的动物。'],
234234
['人一天需要喝多少水?', '您好,我的答案是:一般建议每天喝8杯水,大约2升。']]
235-
nodes = [dict(id='1', kind='OnlineLLM', name='m1', args=dict(source='glm', stream=True, prompt=dict(
235+
nodes = [dict(id='1', kind='LLM', name='m1', args=dict(source='glm', stream=True, type='online', keys=['query'],
236+
prompt=dict(
236237
system='请将我的问题翻译成中文。请注意,请直接输出翻译后的问题,不要反问和发挥',
237238
user='问题: {query} \n, 翻译:'))),
238-
dict(id='2', kind='OnlineLLM', name='m2',
239-
args=dict(source='glm', stream=True,
239+
dict(id='2', kind='LLM', name='m2',
240+
args=dict(source='glm', stream=True, type='online',
240241
prompt=dict(system='请参考历史对话,回答问题,并保持格式不变。', user='{query}'))),
241-
dict(id='3', kind='JoinFormatter', name='join', args=dict(type='to_dict', names=['query', 'answer'])),
242-
dict(id='4', kind='OnlineLLM', stream=False, name='m3',
243-
args=dict(source='glm', history=builtin_history, prompt=dict(
242+
dict(id='3', kind='LLM', stream=False, name='m3',
243+
args=dict(source='glm', type='online', keys=['query', 'answer'], history=builtin_history,
244+
prompt=dict(
244245
system='你是一个问答机器人,会根据用户的问题作出回答。',
245246
user='请结合历史对话和本轮的问题,总结我们的全部对话。本轮情况如下:\n {query}, 回答: {answer}')))]
246247
engine = LightEngine()
247248
gid = engine.start(nodes, edges=[['__start__', '1'], ['1', '2'], ['1', '3'], ['2', '3'],
248-
['3', '4'], ['4', '__end__']], _history_ids=['2', '4'])
249+
['3', '__end__']], _history_ids=['2', '3'])
249250
history = [['雨后为什么会有彩虹?', '您好,我的答案是:雨后阳光通过水滴发生折射和反射形成了彩虹。'],
250251
['月亮会发光吗?', '您好,我的答案是:月亮本身不会发光,它反射太阳光。'],
251252
['一年有多少天', '您好,我的答案是:一年有365天,闰年有366天。']]

0 commit comments

Comments
 (0)