Skip to content

Commit 025dfde

Browse files
qiuyizcopybara-github
authored andcommitted
Add string parameter to metadata conversion support.
PiperOrigin-RevId: 573243721
1 parent 6edfdd0 commit 025dfde

File tree

2 files changed

+141
-0
lines changed

2 files changed

+141
-0
lines changed
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright 2023 Google LLC.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
"""Converter utils for parameters for free-form strings."""
18+
19+
from typing import Sequence
20+
import copy
21+
import json
22+
import attrs
23+
from vizier import pyvizier as vz
24+
25+
_METADATA_VERSION = '0.0.1a'
26+
PROMPT_TUNING_NS = 'prompt_tuning'
27+
28+
29+
@attrs.define
30+
class PromptTuningConfig:
31+
"""Variables and utils for configuring prompt tuning."""
32+
33+
default_prompts: dict[str, str] = attrs.field(factory=dict)
34+
35+
def augment_problem(
36+
self, problem: vz.ProblemStatement
37+
) -> vz.ProblemStatement:
38+
"""Augments problem statement to enable for prompt tuning."""
39+
for k, v in self.default_prompts.items():
40+
problem.search_space.root.add_categorical_param(k, [v], default_value=v)
41+
problem.metadata.ns(PROMPT_TUNING_NS)['version'] = _METADATA_VERSION
42+
return problem
43+
44+
def to_prompt_trials(self, trials: Sequence[vz.Trial]) -> Sequence[vz.Trial]:
45+
"""Converts from metadata to string valued parameters."""
46+
prompt_trials = copy.deepcopy(trials)
47+
for trial in prompt_trials:
48+
prompt_values = json.loads(trial.metadata.ns(PROMPT_TUNING_NS)['values'])
49+
for k in self.default_prompts.keys():
50+
if k in prompt_values:
51+
trial.parameters[k] = prompt_values[k]
52+
return prompt_trials
53+
54+
def to_valid_suggestions(
55+
self, suggestions: Sequence[vz.TrialSuggestion]
56+
) -> Sequence[vz.TrialSuggestion]:
57+
"""Returns the features array with dimension: (n_trials, n_features)."""
58+
valid_suggestions = copy.deepcopy(suggestions)
59+
for suggestion in valid_suggestions:
60+
prompt_values = {}
61+
for k, default_value in self.default_prompts.items():
62+
prompt_values[k] = suggestion.parameters[k].value
63+
suggestion.parameters[k] = default_value
64+
suggestion.metadata.ns(PROMPT_TUNING_NS)['values'] = json.dumps(
65+
prompt_values
66+
)
67+
suggestion.metadata.ns(PROMPT_TUNING_NS)['version'] = _METADATA_VERSION
68+
return valid_suggestions
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright 2023 Google LLC.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import json
18+
19+
from vizier import pyvizier as vz
20+
from vizier.pyvizier.converters import string_converters
21+
22+
from absl.testing import absltest
23+
24+
25+
class StringConvertersTest(absltest.TestCase):
26+
27+
def setUp(self):
28+
super().setUp()
29+
self.default_prompts = {'prompt1': 'def1', 'prompt2': 'def2'}
30+
self.config = string_converters.PromptTuningConfig(
31+
default_prompts=self.default_prompts
32+
)
33+
34+
def test_augment_problem(self):
35+
problem = vz.ProblemStatement()
36+
tuning_problem = self.config.augment_problem(problem)
37+
for k, v in self.default_prompts.items():
38+
pconfig = tuning_problem.search_space.get(k)
39+
self.assertCountEqual(pconfig.feasible_values, [v])
40+
self.assertEqual(pconfig.default_value, v)
41+
42+
def test_prompt_trials(self):
43+
trial = vz.Trial(parameters={'int': 3, 'float': 1.2, 'cat': 'test'})
44+
trial.metadata.ns(string_converters.PROMPT_TUNING_NS)['values'] = (
45+
json.dumps({'prompt1': 'test1', 'prompt2': 'test2'})
46+
)
47+
results = self.config.to_prompt_trials([trial])
48+
self.assertLen(results, 1)
49+
50+
self.assertStartsWith(results[0].parameters['prompt1'].value, 'test1')
51+
self.assertEqual(results[0].parameters['prompt2'].value, 'test2')
52+
53+
def test_valid_suggestions(self):
54+
problem = vz.ProblemStatement()
55+
tuning_problem = self.config.augment_problem(problem)
56+
suggestion = vz.TrialSuggestion(
57+
parameters={'prompt1': 'test1', 'prompt2': 'test2'}
58+
)
59+
valid_suggestions = self.config.to_valid_suggestions([suggestion])
60+
self.assertLen(valid_suggestions, 1)
61+
valid_suggestion = valid_suggestions[0]
62+
self.assertTrue(
63+
tuning_problem.search_space.contains(valid_suggestion.parameters)
64+
)
65+
66+
# Test the reverse conversion retrieves original parameters.
67+
trial = valid_suggestion.to_trial().complete(vz.Measurement())
68+
prompt_trial = self.config.to_prompt_trials([trial])[0]
69+
self.assertCountEqual(prompt_trial.parameters, suggestion.parameters)
70+
71+
72+
if __name__ == '__main__':
73+
absltest.main()

0 commit comments

Comments
 (0)