Skip to content

Commit fb1e382

Browse files
committed
docs: add StepVisualizationWrapper and showcase its usage
1 parent 37781fe commit fb1e382

File tree

4 files changed

+434
-77
lines changed

4 files changed

+434
-77
lines changed

docs/api/wrappers.md

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@ JaxARC provides three types of wrappers:
88
flatten)
99
- **Observation Wrappers**: Add channels to observations (input grid, answer,
1010
clipboard, context)
11+
- **Visualization Wrappers**: Enhance rendering capabilities (step visualization)
1112

1213
## Action Wrappers
1314

1415
### PointActionWrapper
1516

1617
```{eval-rst}
17-
.. autoclass:: jaxarc.PointActionWrapper
18+
.. autoclass:: jaxarc.wrappers.PointActionWrapper
1819
:members:
1920
:undoc-members:
2021
:show-inheritance:
@@ -23,7 +24,7 @@ JaxARC provides three types of wrappers:
2324
### BboxActionWrapper
2425

2526
```{eval-rst}
26-
.. autoclass:: jaxarc.BboxActionWrapper
27+
.. autoclass:: jaxarc.wrappers.BboxActionWrapper
2728
:members:
2829
:undoc-members:
2930
:show-inheritance:
@@ -32,7 +33,7 @@ JaxARC provides three types of wrappers:
3233
### FlattenActionWrapper
3334

3435
```{eval-rst}
35-
.. autoclass:: jaxarc.FlattenActionWrapper
36+
.. autoclass:: jaxarc.wrappers.FlattenActionWrapper
3637
:members:
3738
:undoc-members:
3839
:show-inheritance:
@@ -43,7 +44,7 @@ JaxARC provides three types of wrappers:
4344
### InputGridObservationWrapper
4445

4546
```{eval-rst}
46-
.. autoclass:: jaxarc.InputGridObservationWrapper
47+
.. autoclass:: jaxarc.wrappers.InputGridObservationWrapper
4748
:members:
4849
:undoc-members:
4950
:show-inheritance:
@@ -52,7 +53,7 @@ JaxARC provides three types of wrappers:
5253
### AnswerObservationWrapper
5354

5455
```{eval-rst}
55-
.. autoclass:: jaxarc.AnswerObservationWrapper
56+
.. autoclass:: jaxarc.wrappers.AnswerObservationWrapper
5657
:members:
5758
:undoc-members:
5859
:show-inheritance:
@@ -61,7 +62,7 @@ JaxARC provides three types of wrappers:
6162
### ClipboardObservationWrapper
6263

6364
```{eval-rst}
64-
.. autoclass:: jaxarc.ClipboardObservationWrapper
65+
.. autoclass:: jaxarc.wrappers.ClipboardObservationWrapper
6566
:members:
6667
:undoc-members:
6768
:show-inheritance:
@@ -70,7 +71,18 @@ JaxARC provides three types of wrappers:
7071
### ContextualObservationWrapper
7172

7273
```{eval-rst}
73-
.. autoclass:: jaxarc.ContextualObservationWrapper
74+
.. autoclass:: jaxarc.wrappers.ContextualObservationWrapper
75+
:members:
76+
:undoc-members:
77+
:show-inheritance:
78+
```
79+
80+
## Visualization Wrappers
81+
82+
### StepVisualizationWrapper
83+
84+
```{eval-rst}
85+
.. autoclass:: jaxarc.wrappers.StepVisualizationWrapper
7486
:members:
7587
:undoc-members:
7688
:show-inheritance:
@@ -80,7 +92,7 @@ JaxARC provides three types of wrappers:
8092

8193
```python
8294
from jaxarc import make
83-
from jaxarc import PointActionWrapper, InputGridObservationWrapper
95+
from jaxarc.wrappers import PointActionWrapper, InputGridObservationWrapper
8496

8597
# Create base environment
8698
env, env_params = make("Mini")

docs/tutorials/using-wrappers.ipynb

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,21 @@
3939
"name": "stderr",
4040
"output_type": "stream",
4141
"text": [
42-
"\u001b[32m2025-11-03 22:12:09.413\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mjaxarc.utils.dataset_manager\u001b[0m:\u001b[36mvalidate_dataset\u001b[0m:\u001b[36m212\u001b[0m - \u001b[34m\u001b[1mDataset validation passed: /Users/aadam/workspace/JaxARC/data/MiniARC\u001b[0m\n",
43-
"\u001b[32m2025-11-03 22:12:09.414\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mjaxarc.utils.dataset_manager\u001b[0m:\u001b[36mensure_dataset_available\u001b[0m:\u001b[36m81\u001b[0m - \u001b[34m\u001b[1mDataset 'MiniARC' found at /Users/aadam/workspace/JaxARC/data/MiniARC\u001b[0m\n",
44-
"\u001b[32m2025-11-03 22:12:09.417\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mjaxarc.parsers.mini_arc\u001b[0m:\u001b[36m_validate_grid_constraints\u001b[0m:\u001b[36m104\u001b[0m - \u001b[1mMiniARC parser configured with optimal 5x5 grid constraints\u001b[0m\n",
45-
"\u001b[32m2025-11-03 22:12:09.419\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mjaxarc.parsers.mini_arc\u001b[0m:\u001b[36m_scan_available_tasks\u001b[0m:\u001b[36m131\u001b[0m - \u001b[1mFound 149 tasks in MiniARC dataset (lazy loading - tasks loaded on-demand, optimized for 5x5 grids)\u001b[0m\n",
46-
"\u001b[32m2025-11-03 22:12:09.420\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mjaxarc.parsers.mini_arc\u001b[0m:\u001b[36m_load_task_from_disk\u001b[0m:\u001b[36m171\u001b[0m - \u001b[34m\u001b[1mLoaded MiniARC task 'Most_Common_color_l6ab0lf3xztbyxsu3p' from disk\u001b[0m\n",
47-
"\u001b[32m2025-11-03 22:12:09.816\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mjaxarc.parsers.base_parser\u001b[0m:\u001b[36m_log_parsing_stats\u001b[0m:\u001b[36m479\u001b[0m - \u001b[34m\u001b[1mTask Most_Common_color_l6ab0lf3xztbyxsu3p: 3 train pairs, 1 test pairs, max grid size: 5x5\u001b[0m\n",
48-
"\u001b[32m2025-11-03 22:12:09.817\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mjaxarc.utils.task_manager\u001b[0m:\u001b[36mget_global_task_manager\u001b[0m:\u001b[36m236\u001b[0m - \u001b[34m\u001b[1mCreated global task ID manager\u001b[0m\n",
49-
"\u001b[32m2025-11-03 22:12:09.817\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mjaxarc.utils.task_manager\u001b[0m:\u001b[36mregister_task\u001b[0m:\u001b[36m72\u001b[0m - \u001b[34m\u001b[1mRegistered task 'Most_Common_color_l6ab0lf3xztbyxsu3p' with index 0\u001b[0m\n"
42+
"\u001b[32m2025-11-18 22:47:09.240\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mjaxarc.utils.dataset_manager\u001b[0m:\u001b[36mvalidate_dataset\u001b[0m:\u001b[36m212\u001b[0m - \u001b[34m\u001b[1mDataset validation passed: /Users/aadam/workspace/JaxARC/data/MiniARC\u001b[0m\n",
43+
"\u001b[32m2025-11-18 22:47:09.240\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mjaxarc.utils.dataset_manager\u001b[0m:\u001b[36mensure_dataset_available\u001b[0m:\u001b[36m81\u001b[0m - \u001b[34m\u001b[1mDataset 'MiniARC' found at /Users/aadam/workspace/JaxARC/data/MiniARC\u001b[0m\n",
44+
"\u001b[32m2025-11-18 22:47:09.243\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mjaxarc.parsers.mini_arc\u001b[0m:\u001b[36m_validate_grid_constraints\u001b[0m:\u001b[36m104\u001b[0m - \u001b[1mMiniARC parser configured with optimal 5x5 grid constraints\u001b[0m\n",
45+
"\u001b[32m2025-11-18 22:47:09.245\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mjaxarc.parsers.mini_arc\u001b[0m:\u001b[36m_scan_available_tasks\u001b[0m:\u001b[36m131\u001b[0m - \u001b[1mFound 149 tasks in MiniARC dataset (lazy loading - tasks loaded on-demand, optimized for 5x5 grids)\u001b[0m\n",
46+
"\u001b[32m2025-11-18 22:47:09.246\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mjaxarc.parsers.mini_arc\u001b[0m:\u001b[36m_load_task_from_disk\u001b[0m:\u001b[36m171\u001b[0m - \u001b[34m\u001b[1mLoaded MiniARC task 'Most_Common_color_l6ab0lf3xztbyxsu3p' from disk\u001b[0m\n",
47+
"\u001b[32m2025-11-18 22:47:09.240\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mjaxarc.utils.dataset_manager\u001b[0m:\u001b[36mensure_dataset_available\u001b[0m:\u001b[36m81\u001b[0m - \u001b[34m\u001b[1mDataset 'MiniARC' found at /Users/aadam/workspace/JaxARC/data/MiniARC\u001b[0m\n",
48+
"\u001b[32m2025-11-18 22:47:09.243\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mjaxarc.parsers.mini_arc\u001b[0m:\u001b[36m_validate_grid_constraints\u001b[0m:\u001b[36m104\u001b[0m - \u001b[1mMiniARC parser configured with optimal 5x5 grid constraints\u001b[0m\n",
49+
"\u001b[32m2025-11-18 22:47:09.245\u001b[0m | \u001b[1mINFO \u001b[0m | \u001b[36mjaxarc.parsers.mini_arc\u001b[0m:\u001b[36m_scan_available_tasks\u001b[0m:\u001b[36m131\u001b[0m - \u001b[1mFound 149 tasks in MiniARC dataset (lazy loading - tasks loaded on-demand, optimized for 5x5 grids)\u001b[0m\n",
50+
"\u001b[32m2025-11-18 22:47:09.246\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mjaxarc.parsers.mini_arc\u001b[0m:\u001b[36m_load_task_from_disk\u001b[0m:\u001b[36m171\u001b[0m - \u001b[34m\u001b[1mLoaded MiniARC task 'Most_Common_color_l6ab0lf3xztbyxsu3p' from disk\u001b[0m\n",
51+
"\u001b[32m2025-11-18 22:47:09.658\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mjaxarc.parsers.base_parser\u001b[0m:\u001b[36m_log_parsing_stats\u001b[0m:\u001b[36m479\u001b[0m - \u001b[34m\u001b[1mTask Most_Common_color_l6ab0lf3xztbyxsu3p: 3 train pairs, 1 test pairs, max grid size: 5x5\u001b[0m\n",
52+
"\u001b[32m2025-11-18 22:47:09.658\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mjaxarc.utils.task_manager\u001b[0m:\u001b[36mget_global_task_manager\u001b[0m:\u001b[36m236\u001b[0m - \u001b[34m\u001b[1mCreated global task ID manager\u001b[0m\n",
53+
"\u001b[32m2025-11-18 22:47:09.659\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mjaxarc.utils.task_manager\u001b[0m:\u001b[36mregister_task\u001b[0m:\u001b[36m72\u001b[0m - \u001b[34m\u001b[1mRegistered task 'Most_Common_color_l6ab0lf3xztbyxsu3p' with index 0\u001b[0m\n",
54+
"\u001b[32m2025-11-18 22:47:09.658\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mjaxarc.parsers.base_parser\u001b[0m:\u001b[36m_log_parsing_stats\u001b[0m:\u001b[36m479\u001b[0m - \u001b[34m\u001b[1mTask Most_Common_color_l6ab0lf3xztbyxsu3p: 3 train pairs, 1 test pairs, max grid size: 5x5\u001b[0m\n",
55+
"\u001b[32m2025-11-18 22:47:09.658\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mjaxarc.utils.task_manager\u001b[0m:\u001b[36mget_global_task_manager\u001b[0m:\u001b[36m236\u001b[0m - \u001b[34m\u001b[1mCreated global task ID manager\u001b[0m\n",
56+
"\u001b[32m2025-11-18 22:47:09.659\u001b[0m | \u001b[34m\u001b[1mDEBUG \u001b[0m | \u001b[36mjaxarc.utils.task_manager\u001b[0m:\u001b[36mregister_task\u001b[0m:\u001b[36m72\u001b[0m - \u001b[34m\u001b[1mRegistered task 'Most_Common_color_l6ab0lf3xztbyxsu3p' with index 0\u001b[0m\n"
5057
]
5158
},
5259
{
@@ -117,13 +124,17 @@
117124
"Action keys: ['operation', 'row', 'col']\n",
118125
"\n",
119126
"Initial observation shape: (5, 5, 1)\n",
127+
"\n",
128+
"Initial observation shape: (5, 5, 1)\n",
129+
"Point action executed: {'operation': 2, 'row': 2, 'col': 3}\n",
130+
"Reward: -0.005\n",
120131
"Point action executed: {'operation': 2, 'row': 2, 'col': 3}\n",
121132
"Reward: -0.005\n"
122133
]
123134
}
124135
],
125136
"source": [
126-
"from jaxarc.envs import PointActionWrapper\n",
137+
"from jaxarc.wrappers import PointActionWrapper\n",
127138
"\n",
128139
"# Wrap environment\n",
129140
"point_env = PointActionWrapper(env)\n",
@@ -176,7 +187,7 @@
176187
}
177188
],
178189
"source": [
179-
"from jaxarc.envs import BboxActionWrapper\n",
190+
"from jaxarc.wrappers import BboxActionWrapper\n",
180191
"\n",
181192
"# Wrap environment\n",
182193
"bbox_env = BboxActionWrapper(env)\n",
@@ -226,7 +237,7 @@
226237
}
227238
],
228239
"source": [
229-
"from jaxarc.envs import FlattenActionWrapper\n",
240+
"from jaxarc.wrappers import FlattenActionWrapper\n",
230241
"\n",
231242
"# Wrap environment\n",
232243
"# Using PointActionWrapper here to reduce the action space size for demonstration\n",
@@ -277,12 +288,16 @@
277288
"+ AnswerObservationWrapper: (5, 5, 3)\n",
278289
"+ ClipboardObservationWrapper: (5, 5, 4)\n",
279290
"\n",
291+
"Total channels so far: 4\n",
292+
"+ AnswerObservationWrapper: (5, 5, 3)\n",
293+
"+ ClipboardObservationWrapper: (5, 5, 4)\n",
294+
"\n",
280295
"Total channels so far: 4\n"
281296
]
282297
}
283298
],
284299
"source": [
285-
"from jaxarc.envs import (\n",
300+
"from jaxarc.wrappers import (\n",
286301
" AnswerObservationWrapper,\n",
287302
" ClipboardObservationWrapper,\n",
288303
" InputGridObservationWrapper,\n",
@@ -346,7 +361,7 @@
346361
}
347362
],
348363
"source": [
349-
"from jaxarc.envs import ContextualObservationWrapper\n",
364+
"from jaxarc.wrappers import ContextualObservationWrapper\n",
350365
"\n",
351366
"# Add 3 demonstration pairs as context\n",
352367
"env_with_context = ContextualObservationWrapper(env_with_clipboard, num_context_pairs=3)\n",
@@ -429,8 +444,10 @@
429444
"| `AnswerObservationWrapper` | Add answer grid channel | Training with supervision |\n",
430445
"| `ClipboardObservationWrapper` | Add clipboard channel | Copy-paste operations |\n",
431446
"| `ContextualObservationWrapper` | Add demonstration pairs | Few-shot learning, pattern recognition |\n",
447+
"| **Visualization Wrappers** | | |\n",
448+
"| `StepVisualizationWrapper` | Enable detailed SVG rendering | Debugging agent actions and transitions |\n",
432449
"\n",
433-
"Wrappers enhance environment usability without altering core logic. They enable flexible action formats and richer observations, facilitating effective agent training and evaluation."
450+
"Wrappers enhance environment usability without altering core logic. They enable flexible action formats, richer observations, and better visualization, facilitating effective agent training and evaluation."
434451
]
435452
}
436453
],

0 commit comments

Comments
 (0)