
Commit 0f95e9d

customized function for llm invoking
1 parent c575b53 commit 0f95e9d

6 files changed: +22 -80 lines changed


README.md

Lines changed: 3 additions & 3 deletions
@@ -70,15 +70,15 @@ You can first setup a [demo website](./demo/) and then use the web UI to try Ref
 ## 🚀 Quick Start
 
 ### Setup Environment
-First create a python environment using conda or virtualenv. Clone this repo and change path into the root directory. Then install:
+First create a python environment using conda or virtualenv. Then install:
 ```bash
-pip install -e .
+pip install refchecker
 python -m spacy download en_core_web_sm
 ```
 
 Install optional dependencies to use open source extractors (Mistral, Mixtral) or enable acceleration for RepCChecker.
 ```bash
-pip install -e .[open-extractor,repcex]
+pip install refchecker[open-extractor,repcex]
 ```
 
 ### Code Examples

refchecker/checker/checker_base.py

Lines changed: 3 additions & 9 deletions
@@ -49,9 +49,7 @@ def check(
         merge_psg: bool = True,
         is_joint: bool = False,
         joint_check_num: int = 5,
-        sagemaker_client=None,
-        sagemaker_params=None,
-        sagemaker_get_response_func=None,
+        custom_llm_api_func=None,
         **kwargs
     ):
         """
@@ -97,9 +95,7 @@ def check(
                 questions=batch_questions,
                 is_joint=True,
                 joint_check_num=joint_check_num,
-                sagemaker_client=sagemaker_client,
-                sagemaker_params=sagemaker_params,
-                sagemaker_get_response_func=sagemaker_get_response_func,
+                custom_llm_api_func=custom_llm_api_func,
                 **kwargs
             )
             if merge_psg:
@@ -139,9 +135,7 @@ def check(
                 responses=[inp[2] for inp in input_flattened],
                 questions=[inp[3] for inp in input_flattened],
                 is_joint=False,
-                sagemaker_client=sagemaker_client,
-                sagemaker_params=sagemaker_params,
-                sagemaker_get_response_func=sagemaker_get_response_func
+                custom_llm_api_func=custom_llm_api_func,
             )
 
         ret = [[x] + y for x, y in zip(ret, input_ids)]
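
With this commit, the three SageMaker-specific keyword arguments on `check()` collapse into a single `custom_llm_api_func` hook that is threaded through to the underlying LLM calls. Below is a minimal sketch of supplying such a hook; the `LLMChecker` import path, constructor argument, and the `batch_claims`/`batch_references`/`batch_responses` keyword names are assumptions for illustration; only `custom_llm_api_func` itself comes from this diff.

```python
# Sketch only: everything except custom_llm_api_func is an assumed API surface.
from refchecker import LLMChecker  # assumed import path

def my_llm_api(prompts):
    """Receive a list of prompt strings, return a list of response strings."""
    # Swap this stub for a call to whatever LLM service you actually use.
    return ["Entailment"] * len(prompts)  # canned output, for illustration only

checker = LLMChecker(model="gpt-4")  # constructor arguments are illustrative
results = checker.check(
    batch_claims=[[("Paris", "is capital of", "France")]],   # placeholder claims
    batch_references=["Paris is the capital city of France."],
    batch_responses=["Paris is the capital of France."],
    custom_llm_api_func=my_llm_api,  # new hook replacing the sagemaker_* arguments
)
print(results)
```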

refchecker/checker/llm_checker.py

Lines changed: 3 additions & 9 deletions
@@ -46,9 +46,7 @@ def _check(
         questions: List[str] = None,
         is_joint: bool = False,
         joint_check_num: int = 5,
-        sagemaker_client=None,
-        sagemaker_params=None,
-        sagemaker_get_response_func=None,
+        custom_llm_api_func=None,
         **kwargs
     ):
         """
@@ -127,9 +125,7 @@ def _check(
                 model=self.model,
                 max_new_tokens=joint_check_num * 10 + 100,
                 api_base=self.api_base,
-                sagemaker_client=sagemaker_client,
-                sagemaker_params=sagemaker_params,
-                sagemaker_get_response_func=sagemaker_get_response_func,
+                custom_llm_api_func=custom_llm_api_func,
                 **kwargs
             )
 
@@ -208,9 +204,7 @@ def _check(
                 model=self.model,
                 max_new_tokens=10,
                 api_base=self.api_base,
-                sagemaker_client=sagemaker_client,
-                sagemaker_params=sagemaker_params,
-                sagemaker_get_response_func=sagemaker_get_response_func,
+                custom_llm_api_func=custom_llm_api_func,
                 **kwargs
             )
 

refchecker/extractor/extractor_base.py

Lines changed: 5 additions & 15 deletions
@@ -16,29 +16,23 @@ def extract(
         batch_responses,
         batch_questions=None,
         max_new_tokens=500,
-        sagemaker_client=None,
-        sagemaker_params=None,
-        sagemaker_get_response_func=None,
+        custom_llm_api_func=None,
         **kwargs
     ):
         if self.claim_format == 'triplet':
             result = self.extract_claim_triplets(
                 batch_responses=batch_responses,
                 batch_questions=batch_questions,
                 max_new_tokens=max_new_tokens,
-                sagemaker_client=sagemaker_client,
-                sagemaker_params=sagemaker_params,
-                sagemaker_get_response_func=sagemaker_get_response_func,
+                custom_llm_api_func=custom_llm_api_func,
                 **kwargs
             )
         elif self.claim_format == 'subsentence':
             result = self.extract_subsentence_claims(
                 batch_responses=batch_responses,
                 batch_questions=batch_questions,
                 max_new_tokens=max_new_tokens,
-                sagemaker_client=sagemaker_client,
-                sagemaker_params=sagemaker_params,
-                sagemaker_get_response_func=sagemaker_get_response_func,
+                custom_llm_api_func=custom_llm_api_func,
                 **kwargs
             )
         return result
@@ -48,9 +42,7 @@ def extract_claim_triplets(
         batch_responses,
         batch_questions=None,
         max_new_tokens=500,
-        sagemaker_client=None,
-        sagemaker_params=None,
-        sagemaker_get_response_func=None,
+        custom_llm_api_func=None,
         **kwargs
     ):
         raise NotImplementedError
@@ -60,9 +52,7 @@ def extract_subsentence_claims(
         batch_responses,
         batch_questions=None,
         max_new_tokens=500,
-        sagemaker_client=None,
-        sagemaker_params=None,
-        sagemaker_get_response_func=None,
+        custom_llm_api_func=None,
        **kwargs
     ):
         raise NotImplementedError
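
The extractor base class gets the same single hook: `extract()` forwards `custom_llm_api_func` to whichever claim-format method a subclass implements. A hedged sketch of calling an extractor with it follows; the `LLMExtractor` class, its import path, and its constructor arguments are assumptions, while `batch_responses`, `batch_questions`, and `custom_llm_api_func` match the signature above.

```python
# Sketch only: the extractor class and constructor arguments are assumed.
from refchecker import LLMExtractor  # assumed import path

def my_llm_api(prompts):
    # Same contract as the checker hook: list of prompts in, list of responses out.
    return ['("Paris", "is capital of", "France")'] * len(prompts)  # canned output

extractor = LLMExtractor(claim_format="triplet")  # illustrative constructor
claims = extractor.extract(
    batch_responses=["Paris is the capital of France."],
    batch_questions=["What is the capital of France?"],
    custom_llm_api_func=my_llm_api,  # forwarded to extract_claim_triplets()
)
```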

refchecker/extractor/llm_extractor.py

Lines changed: 4 additions & 12 deletions
@@ -30,9 +30,7 @@ def extract_subsentence_claims(
         batch_responses,
         batch_questions=None,
         max_new_tokens=500,
-        sagemaker_client=None,
-        sagemaker_params=None,
-        sagemaker_get_response_func=None,
+        custom_llm_api_func=None,
         **kwargs
     ):
         """Extract subsentence claims from the response text.
@@ -75,9 +73,7 @@ def extract_subsentence_claims(
                 n_choices=1,
                 max_new_tokens=max_new_tokens,
                 api_base=self.api_base,
-                sagemaker_client=sagemaker_client,
-                sagemaker_params=sagemaker_params,
-                sagemaker_get_response_func=sagemaker_get_response_func,
+                custom_llm_api_func=custom_llm_api_func,
                 **kwargs
             )
 
@@ -103,9 +99,7 @@ def extract_claim_triplets(
         batch_responses,
         batch_questions=None,
         max_new_tokens=500,
-        sagemaker_client=None,
-        sagemaker_params=None,
-        sagemaker_get_response_func=None,
+        custom_llm_api_func=None,
         **kwargs
     ):
         """Extract KG triplets from the response text.
@@ -151,9 +145,7 @@ def extract_claim_triplets(
                 n_choices=1,
                 max_new_tokens=max_new_tokens,
                 api_base=self.api_base,
-                sagemaker_client=sagemaker_client,
-                sagemaker_params=sagemaker_params,
-                sagemaker_get_response_func=sagemaker_get_response_func,
+                custom_llm_api_func=custom_llm_api_func,
                 **kwargs
             )
 

refchecker/utils.py

Lines changed: 4 additions & 32 deletions
@@ -70,9 +70,7 @@ def get_model_batch_response(
     n_choices=1,
     max_new_tokens=500,
     api_base=None,
-    sagemaker_client=None,
-    sagemaker_params=None,
-    sagemaker_get_response_func=None,
+    custom_llm_api_func=None,
     **kwargs
 ):
     """
@@ -99,35 +97,9 @@
     """
     if not prompts or len(prompts) == 0:
         raise ValueError("Invalid input.")
-
-    if sagemaker_client is not None:
-        parameters = {
-            "max_new_tokens": max_new_tokens,
-            "temperature": temperature
-        }
-        if sagemaker_params is not None:
-            for k, v in sagemaker_params.items():
-                if k in parameters:
-                    parameters[k] = v
-        response_list = []
-        for prompt in prompts:
-            r = sagemaker_client.invoke_endpoint(
-                EndpointName=model,
-                Body=json.dumps(
-                    {
-                        "inputs": prompt,
-                        "parameters": parameters,
-                    }
-                ),
-                ContentType="application/json",
-            )
-            if sagemaker_get_response_func is not None:
-                response = sagemaker_get_response_func(r)
-            else:
-                r = json.loads(r['Body'].read().decode('utf8'))
-                response = r['outputs'][0]
-            response_list.append(response)
-        return response_list
+
+    if custom_llm_api_func is not None:
+        return custom_llm_api_func(prompts)
     else:
         message_list = []
         for prompt in prompts:
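
Dropping the SageMaker branch from `get_model_batch_response` means callers who still need it can recreate it on their side as a `custom_llm_api_func`. Below is a sketch that mirrors the removed code, assuming a `boto3` SageMaker runtime client and the same JSON request/response shape the old branch used.

```python
import json
import boto3  # assumed dependency, as in the removed SageMaker branch

def make_sagemaker_llm_func(endpoint_name, max_new_tokens=500, temperature=0):
    """Build a custom_llm_api_func that reproduces the removed SageMaker logic."""
    client = boto3.client("sagemaker-runtime")
    parameters = {"max_new_tokens": max_new_tokens, "temperature": temperature}

    def sagemaker_llm_func(prompts):
        # Contract expected by get_model_batch_response: list of prompts in,
        # list of response strings out.
        responses = []
        for prompt in prompts:
            r = client.invoke_endpoint(
                EndpointName=endpoint_name,
                Body=json.dumps({"inputs": prompt, "parameters": parameters}),
                ContentType="application/json",
            )
            payload = json.loads(r["Body"].read().decode("utf8"))
            responses.append(payload["outputs"][0])  # same parsing as the old branch
        return responses

    return sagemaker_llm_func
```

Passing `custom_llm_api_func=make_sagemaker_llm_func("my-endpoint")` to `get_model_batch_response`, or to any of the checker and extractor entry points above, restores the old behavior without keeping SageMaker-specific arguments in the library.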
