Skip to content

Commit c575b53

Browse files
committed
customized function to get response for sagemaker
1 parent c7ceebc commit c575b53

File tree

6 files changed

+23
-5
lines changed

6 files changed

+23
-5
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "refchecker"
3-
version = "0.2.8"
3+
version = "0.2.9"
44
description = "RefChecker provides automatic checking pipeline for detecting fine-grained hallucinations generated by Large Language Models."
55
authors = [
66
"Xiangkun Hu <xiangkhu@amazon.com>",

refchecker/checker/checker_base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def check(
5151
joint_check_num: int = 5,
5252
sagemaker_client=None,
5353
sagemaker_params=None,
54+
sagemaker_get_response_func=None,
5455
**kwargs
5556
):
5657
"""
@@ -98,6 +99,7 @@ def check(
9899
joint_check_num=joint_check_num,
99100
sagemaker_client=sagemaker_client,
100101
sagemaker_params=sagemaker_params,
102+
sagemaker_get_response_func=sagemaker_get_response_func,
101103
**kwargs
102104
)
103105
if merge_psg:
@@ -138,7 +140,8 @@ def check(
138140
questions=[inp[3] for inp in input_flattened],
139141
is_joint=False,
140142
sagemaker_client=sagemaker_client,
141-
sagemaker_params=sagemaker_params
143+
sagemaker_params=sagemaker_params,
144+
sagemaker_get_response_func=sagemaker_get_response_func
142145
)
143146

144147
ret = [[x] + y for x, y in zip(ret, input_ids)]

refchecker/checker/llm_checker.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def _check(
4848
joint_check_num: int = 5,
4949
sagemaker_client=None,
5050
sagemaker_params=None,
51+
sagemaker_get_response_func=None,
5152
**kwargs
5253
):
5354
"""
@@ -128,6 +129,7 @@ def _check(
128129
api_base=self.api_base,
129130
sagemaker_client=sagemaker_client,
130131
sagemaker_params=sagemaker_params,
132+
sagemaker_get_response_func=sagemaker_get_response_func,
131133
**kwargs
132134
)
133135

@@ -208,6 +210,7 @@ def _check(
208210
api_base=self.api_base,
209211
sagemaker_client=sagemaker_client,
210212
sagemaker_params=sagemaker_params,
213+
sagemaker_get_response_func=sagemaker_get_response_func,
211214
**kwargs
212215
)
213216

refchecker/extractor/extractor_base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def extract(
1818
max_new_tokens=500,
1919
sagemaker_client=None,
2020
sagemaker_params=None,
21+
sagemaker_get_response_func=None,
2122
**kwargs
2223
):
2324
if self.claim_format == 'triplet':
@@ -27,6 +28,7 @@ def extract(
2728
max_new_tokens=max_new_tokens,
2829
sagemaker_client=sagemaker_client,
2930
sagemaker_params=sagemaker_params,
31+
sagemaker_get_response_func=sagemaker_get_response_func,
3032
**kwargs
3133
)
3234
elif self.claim_format == 'subsentence':
@@ -36,6 +38,7 @@ def extract(
3638
max_new_tokens=max_new_tokens,
3739
sagemaker_client=sagemaker_client,
3840
sagemaker_params=sagemaker_params,
41+
sagemaker_get_response_func=sagemaker_get_response_func,
3942
**kwargs
4043
)
4144
return result
@@ -47,6 +50,7 @@ def extract_claim_triplets(
4750
max_new_tokens=500,
4851
sagemaker_client=None,
4952
sagemaker_params=None,
53+
sagemaker_get_response_func=None,
5054
**kwargs
5155
):
5256
raise NotImplementedError
@@ -58,6 +62,7 @@ def extract_subsentence_claims(
5862
max_new_tokens=500,
5963
sagemaker_client=None,
6064
sagemaker_params=None,
65+
sagemaker_get_response_func=None,
6166
**kwargs
6267
):
6368
raise NotImplementedError

refchecker/extractor/llm_extractor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def extract_subsentence_claims(
3232
max_new_tokens=500,
3333
sagemaker_client=None,
3434
sagemaker_params=None,
35+
sagemaker_get_response_func=None,
3536
**kwargs
3637
):
3738
"""Extract subsentence claims from the response text.
@@ -76,6 +77,7 @@ def extract_subsentence_claims(
7677
api_base=self.api_base,
7778
sagemaker_client=sagemaker_client,
7879
sagemaker_params=sagemaker_params,
80+
sagemaker_get_response_func=sagemaker_get_response_func,
7981
**kwargs
8082
)
8183

@@ -103,6 +105,7 @@ def extract_claim_triplets(
103105
max_new_tokens=500,
104106
sagemaker_client=None,
105107
sagemaker_params=None,
108+
sagemaker_get_response_func=None,
106109
**kwargs
107110
):
108111
"""Extract KG triplets from the response text.
@@ -150,6 +153,7 @@ def extract_claim_triplets(
150153
api_base=self.api_base,
151154
sagemaker_client=sagemaker_client,
152155
sagemaker_params=sagemaker_params,
156+
sagemaker_get_response_func=sagemaker_get_response_func,
153157
**kwargs
154158
)
155159

refchecker/utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def get_model_batch_response(
7272
api_base=None,
7373
sagemaker_client=None,
7474
sagemaker_params=None,
75+
sagemaker_get_response_func=None,
7576
**kwargs
7677
):
7778
"""
@@ -120,9 +121,11 @@ def get_model_batch_response(
120121
),
121122
ContentType="application/json",
122123
)
123-
124-
r = json.loads(r['Body'].read().decode('utf8'))
125-
response = r['outputs'][0]
124+
if sagemaker_get_response_func is not None:
125+
response = sagemaker_get_response_func(r)
126+
else:
127+
r = json.loads(r['Body'].read().decode('utf8'))
128+
response = r['outputs'][0]
126129
response_list.append(response)
127130
return response_list
128131
else:

0 commit comments

Comments
 (0)