Skip to content

Commit 9092e7c

Browse files
committed
feat: method to replace all parameters
1 parent 2d4afe4 commit 9092e7c

File tree

2 files changed

+71
-44
lines changed

2 files changed

+71
-44
lines changed

docs/tutorials/parameters.ipynb

Lines changed: 48 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@
1616
},
1717
{
1818
"cell_type": "code",
19-
"execution_count": null,
2019
"metadata": {},
21-
"outputs": [],
2220
"source": [
2321
"import copy\n",
2422
"\n",
@@ -30,7 +28,9 @@
3028
" compute_taxes_and_transfers,\n",
3129
" create_synthetic_data,\n",
3230
")"
33-
]
31+
],
32+
"outputs": [],
33+
"execution_count": null
3434
},
3535
{
3636
"cell_type": "markdown",
@@ -56,12 +56,12 @@
5656
},
5757
{
5858
"cell_type": "code",
59-
"execution_count": null,
6059
"metadata": {},
61-
"outputs": [],
6260
"source": [
6361
"environment = PolicyEnvironment.for_date(\"2020\")"
64-
]
62+
],
63+
"outputs": [],
64+
"execution_count": null
6565
},
6666
{
6767
"cell_type": "markdown",
@@ -74,12 +74,12 @@
7474
},
7575
{
7676
"cell_type": "code",
77-
"execution_count": null,
7877
"metadata": {},
79-
"outputs": [],
8078
"source": [
8179
"print(*environment.params.keys(), sep=\"\\n\")"
82-
]
80+
],
81+
"outputs": [],
82+
"execution_count": null
8383
},
8484
{
8585
"cell_type": "markdown",
@@ -90,12 +90,12 @@
9090
},
9191
{
9292
"cell_type": "code",
93-
"execution_count": null,
9493
"metadata": {},
95-
"outputs": [],
9694
"source": [
9795
"print(*environment.params[\"kindergeld\"].keys(), sep=\"\\n\")"
98-
]
96+
],
97+
"outputs": [],
98+
"execution_count": null
9999
},
100100
{
101101
"cell_type": "markdown",
@@ -116,12 +116,12 @@
116116
},
117117
{
118118
"cell_type": "code",
119-
"execution_count": null,
120119
"metadata": {},
121-
"outputs": [],
122120
"source": [
123121
"environment.params[\"kindergeld\"][\"kindergeld\"]"
124-
]
122+
],
123+
"outputs": [],
124+
"execution_count": null
125125
},
126126
{
127127
"cell_type": "markdown",
@@ -136,12 +136,12 @@
136136
},
137137
{
138138
"cell_type": "code",
139-
"execution_count": null,
140139
"metadata": {},
141-
"outputs": [],
142140
"source": [
143141
"policy_params_new = copy.deepcopy(environment.params)"
144-
]
142+
],
143+
"outputs": [],
144+
"execution_count": null
145145
},
146146
{
147147
"cell_type": "markdown",
@@ -152,14 +152,14 @@
152152
},
153153
{
154154
"cell_type": "code",
155-
"execution_count": null,
156155
"metadata": {},
157-
"outputs": [],
158156
"source": [
159157
"# Loop through policy paramaters to add the special child bonus.\n",
160158
"for n in policy_params_new[\"kindergeld\"][\"kindergeld\"]:\n",
161159
" policy_params_new[\"kindergeld\"][\"kindergeld\"][n] += 20"
162-
]
160+
],
161+
"outputs": [],
162+
"execution_count": null
163163
},
164164
{
165165
"cell_type": "markdown",
@@ -170,12 +170,12 @@
170170
},
171171
{
172172
"cell_type": "code",
173-
"execution_count": null,
174173
"metadata": {},
175-
"outputs": [],
176174
"source": [
177175
"policy_params_new[\"kindergeld\"][\"kindergeld\"]"
178-
]
176+
],
177+
"outputs": [],
178+
"execution_count": null
179179
},
180180
{
181181
"cell_type": "markdown",
@@ -195,9 +195,7 @@
195195
},
196196
{
197197
"cell_type": "code",
198-
"execution_count": null,
199198
"metadata": {},
200-
"outputs": [],
201199
"source": [
202200
"data = create_synthetic_data(\n",
203201
" n_adults=2,\n",
@@ -224,7 +222,9 @@
224222
" \"sum_ges_rente_priv_rente_m\"\n",
225223
"]\n",
226224
"data.head()"
227-
]
225+
],
226+
"outputs": [],
227+
"execution_count": null
228228
},
229229
{
230230
"attachments": {},
@@ -238,9 +238,7 @@
238238
},
239239
{
240240
"cell_type": "code",
241-
"execution_count": null,
242241
"metadata": {},
243-
"outputs": [],
244242
"source": [
245243
"kindergeld_status_quo = compute_taxes_and_transfers(\n",
246244
" data=data,\n",
@@ -249,7 +247,9 @@
249247
")\n",
250248
"\n",
251249
"kindergeld_status_quo[[\"kindergeld_m_hh\"]]"
252-
]
250+
],
251+
"outputs": [],
252+
"execution_count": null
253253
},
254254
{
255255
"cell_type": "markdown",
@@ -260,18 +260,18 @@
260260
},
261261
{
262262
"cell_type": "code",
263-
"execution_count": null,
264263
"metadata": {},
265-
"outputs": [],
266264
"source": [
267-
"environment_new = PolicyEnvironment(environment.functions, policy_params_new)\n",
265+
"environment_new = environment.replace_all_parameters(policy_params_new)\n",
268266
"kindergeld_new = compute_taxes_and_transfers(\n",
269267
" data=data,\n",
270268
" environment=environment_new,\n",
271269
" targets=\"kindergeld_m_hh\",\n",
272270
")\n",
273271
"kindergeld_new[[\"kindergeld_m_hh\"]]"
274-
]
272+
],
273+
"outputs": [],
274+
"execution_count": null
275275
},
276276
{
277277
"cell_type": "markdown",
@@ -284,20 +284,18 @@
284284
},
285285
{
286286
"cell_type": "code",
287-
"execution_count": null,
288287
"metadata": {},
289-
"outputs": [],
290288
"source": [
291289
"# Group data by household id and sum the gross monthly income.\n",
292290
"total_income_m_hh = data.groupby(\"hh_id\")[\"bruttolohn_m\"].sum()\n",
293291
"total_income_m_hh.tail(10)"
294-
]
292+
],
293+
"outputs": [],
294+
"execution_count": null
295295
},
296296
{
297297
"cell_type": "code",
298-
"execution_count": null,
299298
"metadata": {},
300-
"outputs": [],
301299
"source": [
302300
"# Create DataFrame with relevant columns for plotting.\n",
303301
"df = pd.DataFrame()\n",
@@ -306,7 +304,9 @@
306304
"df[\"hh_id\"] = data[\"hh_id\"]\n",
307305
"df = df.drop_duplicates(\"hh_id\").set_index(\"hh_id\")\n",
308306
"df[\"Income (per household)\"] = total_income_m_hh"
309-
]
307+
],
308+
"outputs": [],
309+
"execution_count": null
310310
},
311311
{
312312
"cell_type": "markdown",
@@ -317,9 +317,7 @@
317317
},
318318
{
319319
"cell_type": "code",
320-
"execution_count": null,
321320
"metadata": {},
322-
"outputs": [],
323321
"source": [
324322
"fig = px.line(\n",
325323
" data_frame=df,\n",
@@ -331,7 +329,9 @@
331329
" xaxis_title=\"Monthly gross income in € (per household)\",\n",
332330
" yaxis_title=\"Kindergeld in € per month\",\n",
333331
")"
334-
]
332+
],
333+
"outputs": [],
334+
"execution_count": null
335335
},
336336
{
337337
"cell_type": "markdown",
@@ -354,6 +354,11 @@
354354
"name": "python",
355355
"nbconvert_exporter": "python",
356356
"pygments_lexer": "ipython3"
357+
},
358+
"kernelspec": {
359+
"name": "python3",
360+
"language": "python",
361+
"display_name": "Python 3 (ipykernel)"
357362
}
358363
},
359364
"nbformat": 4,

src/_gettsim/policy_environment.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,6 @@ def upsert_functions(
109109
The policy environment with the new functions.
110110
"""
111111
new_functions = {**self._functions}
112-
113112
for function in functions:
114113
f = (
115114
function
@@ -124,6 +123,29 @@ def upsert_functions(
124123

125124
return result
126125

126+
def replace_all_parameters(
127+
self, params: dict[str, Any]
128+
):
129+
"""
130+
Replace all parameters of the policy environment. Note that this
131+
method does not modify the current policy environment but returns a new one.
132+
133+
Parameters
134+
----------
135+
params:
136+
The new parameters.
137+
138+
Returns
139+
-------
140+
new_environment:
141+
The policy environment with the new parameters.
142+
"""
143+
result = object.__new__(PolicyEnvironment)
144+
result._functions = self._functions # noqa: SLF001
145+
result._params = params # noqa: SLF001
146+
147+
return result
148+
127149

128150
def set_up_policy_environment(date: datetime.date | str | int) -> PolicyEnvironment:
129151
"""

0 commit comments

Comments
 (0)