Skip to content

Commit ab37c03

Browse files
authored
Add FileIO-level allowed_paths check (#831)
# Motivation

<!-- Why is this change necessary? -->

# Content

<!-- Please include a summary of the change -->

# Testing

<!-- How was the change tested? -->

# Please check the following before marking your PR as ready for review

- [ ] I have added tests for my changes
- [ ] I have updated the documentation or added new documentation as needed
1 parent 8031e77 commit ab37c03

File tree

4 files changed

+128
-10
lines changed

4 files changed

+128
-10
lines changed

src/codegen/sdk/codebase/codebase_context.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,6 @@ def __init__(
159159

160160
# =====[ __init__ attributes ]=====
161161
self.projects = projects
162-
self.io = io or FileIO()
163162
context = projects[0]
164163
self.node_classes = get_node_classes(context.programming_language)
165164
self.config = config or CodebaseConfig()
@@ -169,6 +168,11 @@ def __init__(
169168
self.full_path = os.path.join(self.repo_path, context.base_path) if context.base_path else self.repo_path
170169
self.codeowners_parser = context.repo_operator.codeowners_parser
171170
self.base_url = context.repo_operator.base_url
171+
if not self.config.allow_external:
172+
# TODO: Fix this to be more robust with multiple projects
173+
self.io = io or FileIO(allowed_paths=[Path(self.repo_path).resolve()])
174+
else:
175+
self.io = io or FileIO()
172176
# =====[ computed attributes ]=====
173177
self.transaction_manager = TransactionManager()
174178
self._autocommit = AutoCommit(self)
@@ -188,6 +192,13 @@ def __init__(
188192
logger.warning("WARNING: The codebase is using an unsupported language!")
189193
logger.warning("Some features may not work as expected. Advanced static analysis will be disabled but simple file IO will still work.")
190194

195+
# Assert config assertions
196+
# External import resolution must be enabled if syspath is enabled
197+
if self.config.py_resolve_syspath:
198+
if not self.config.allow_external:
199+
msg = "allow_external must be set to True when py_resolve_syspath is enabled"
200+
raise ValueError(msg)
201+
191202
# Build the graph
192203
if not self.config.exp_lazy_graph and self.config.use_pink != PinkMode.ALL_FILES:
193204
self.build_graph(context.repo_operator)

src/codegen/sdk/codebase/io/file_io.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,33 @@ class FileIO(IO):
1111
"""IO implementation that writes files to disk, and tracks pending changes."""
1212

1313
files: dict[Path, bytes]
14+
allowed_paths: list[Path] | None
1415

15-
def __init__(self):
16+
def __init__(self, allowed_paths: list[Path] | None = None):
1617
self.files = {}
18+
self.allowed_paths = allowed_paths
19+
20+
def _verify_path(self, path: Path) -> None:
21+
if self.allowed_paths is not None:
22+
if not any(path.resolve().is_relative_to(p.resolve()) for p in self.allowed_paths):
23+
msg = f"Path {path.resolve()} is not within allowed paths {self.allowed_paths}"
24+
raise BadWriteError(msg)
1725

1826
def write_bytes(self, path: Path, content: bytes) -> None:
27+
self._verify_path(path)
1928
self.files[path] = content
2029

2130
def read_bytes(self, path: Path) -> bytes:
31+
self._verify_path(path)
2232
if path in self.files:
2333
return self.files[path]
2434
else:
2535
return path.read_bytes()
2636

2737
def save_files(self, files: set[Path] | None = None) -> None:
2838
to_save = set(filter(lambda f: f in files, self.files)) if files is not None else self.files.keys()
39+
for path in to_save:
40+
self._verify_path(path)
2941
with ThreadPoolExecutor() as exec:
3042
exec.map(lambda path: path.write_bytes(self.files[path]), to_save)
3143
if files is None:
@@ -40,12 +52,15 @@ def check_changes(self) -> None:
4052
self.files.clear()
4153

4254
def delete_file(self, path: Path) -> None:
55+
self._verify_path(path)
4356
self.untrack_file(path)
4457
if path.exists():
4558
path.unlink()
4659

4760
def untrack_file(self, path: Path) -> None:
61+
self._verify_path(path)
4862
self.files.pop(path, None)
4963

5064
def file_exists(self, path: Path) -> bool:
65+
self._verify_path(path)
5166
return path.exists()

src/codegen/sdk/core/codebase.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -217,13 +217,6 @@ def __init__(
217217

218218
self._pink_codebase = codegen_sdk_pink.Codebase(self.repo_path)
219219

220-
# Assert config assertions
221-
# External import resolution must be enabled if syspath is enabled
222-
if self.ctx.config.py_resolve_syspath:
223-
if not self.ctx.config.allow_external:
224-
msg = "allow_external must be set to True when py_resolve_syspath is enabled"
225-
raise ValueError(msg)
226-
227220
@noapidoc
228221
def __str__(self) -> str:
229222
return f"<Codebase(name={self.name}, language={self.language}, path={self.repo_path})>"

tests/unit/codegen/sdk/io/test_file_io.py

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22

3-
from codegen.sdk.codebase.io.file_io import FileIO
3+
from codegen.sdk.codebase.io.file_io import BadWriteError, FileIO
44

55

66
@pytest.fixture
@@ -61,3 +61,102 @@ def test_delete_file(file_io, tmp_path):
6161

6262
assert not test_file.exists()
6363
assert test_file not in file_io.files
64+
65+
66+
def test_read_and_write_bounded(file_io, tmp_path):
67+
allowed_dir = tmp_path / "allowed"
68+
file_io.allowed_paths = [allowed_dir]
69+
70+
allowed_file = allowed_dir / "test.txt"
71+
content = b"test content"
72+
73+
file_io.write_bytes(allowed_file, content)
74+
assert file_io.read_bytes(allowed_file) == content
75+
76+
with pytest.raises(BadWriteError) as exc_info:
77+
bad_file = tmp_path / "test.txt"
78+
file_io.write_bytes(bad_file, content)
79+
80+
assert "is not within allowed paths" in str(exc_info.value)
81+
82+
with pytest.raises(BadWriteError) as exc_info:
83+
bad_file_2 = allowed_dir / ".." / "test2.txt"
84+
file_io.write_bytes(bad_file_2, content)
85+
86+
assert "is not within allowed paths" in str(exc_info.value)
87+
88+
89+
def test_read_bounded(file_io, tmp_path):
90+
allowed_dir = tmp_path / "allowed"
91+
allowed_dir.mkdir(exist_ok=True)
92+
file_io.allowed_paths = [allowed_dir]
93+
94+
allowed_file = allowed_dir / "test.txt"
95+
content = b"test content"
96+
allowed_file.write_bytes(content)
97+
98+
assert file_io.read_bytes(allowed_file) == content
99+
100+
with pytest.raises(BadWriteError) as exc_info:
101+
bad_file = tmp_path / "test.txt"
102+
bad_file.write_bytes(content)
103+
file_io.read_bytes(bad_file)
104+
105+
assert "is not within allowed paths" in str(exc_info.value)
106+
107+
with pytest.raises(BadWriteError) as exc_info:
108+
bad_file_2 = allowed_dir / ".." / "test2.txt"
109+
bad_file_2.write_bytes(content)
110+
file_io.read_bytes(bad_file_2)
111+
112+
assert "is not within allowed paths" in str(exc_info.value)
113+
114+
115+
def test_delete_file_bounded(file_io, tmp_path):
116+
allowed_dir = tmp_path / "allowed"
117+
allowed_dir.mkdir(exist_ok=True)
118+
file_io.allowed_paths = [allowed_dir]
119+
120+
allowed_file = allowed_dir / "test.txt"
121+
allowed_file.write_bytes(b"test content")
122+
123+
file_io.delete_file(allowed_file)
124+
125+
with pytest.raises(BadWriteError) as exc_info:
126+
bad_file = tmp_path / "test.txt"
127+
bad_file.write_bytes(b"test content")
128+
file_io.delete_file(bad_file)
129+
130+
assert "is not within allowed paths" in str(exc_info.value)
131+
132+
with pytest.raises(BadWriteError) as exc_info:
133+
bad_file_2 = allowed_dir / ".." / "test2.txt"
134+
bad_file_2.write_bytes(b"test content")
135+
file_io.delete_file(bad_file_2)
136+
137+
assert "is not within allowed paths" in str(exc_info.value)
138+
139+
140+
def test_file_exists_bounded(file_io, tmp_path):
141+
allowed_dir = tmp_path / "allowed"
142+
allowed_dir.mkdir(exist_ok=True)
143+
file_io.allowed_paths = [allowed_dir]
144+
145+
allowed_file = allowed_dir / "test.txt"
146+
allowed_file.write_bytes(b"test content")
147+
148+
assert file_io.file_exists(allowed_file)
149+
150+
with pytest.raises(BadWriteError) as exc_info:
151+
bad_file = tmp_path / "test.txt"
152+
bad_file.write_bytes(b"test content")
153+
file_io.file_exists(bad_file)
154+
155+
assert "is not within allowed paths" in str(exc_info.value)
156+
157+
with pytest.raises(BadWriteError) as exc_info:
158+
bad_file_2 = allowed_dir / ".." / "test2.txt"
159+
bad_file_2.write_bytes(b"test content")
160+
file_io.file_exists(bad_file_2)
161+
162+
assert "is not within allowed paths" in str(exc_info.value)

0 commit comments

Comments
 (0)