Skip to content

Commit ee2e7c3

Browse files
committed
feat: update discover functions to return fastapi app
1 parent 4aa15bd commit ee2e7c3

File tree

1 file changed

+27
-6
lines changed

1 file changed

+27
-6
lines changed

src/fastapi_cli/discover.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,27 +45,34 @@ class ModuleData:
4545
def get_module_data_from_path(path: Path) -> ModuleData:
4646
use_path = path.resolve()
4747
module_path = use_path
48+
4849
if use_path.is_file() and use_path.stem == "__init__":
4950
module_path = use_path.parent
51+
5052
module_paths = [module_path]
5153
extra_sys_path = module_path.parent
54+
5255
for parent in module_path.parents:
5356
init_path = parent / "__init__.py"
57+
5458
if init_path.is_file():
5559
module_paths.insert(0, parent)
5660
extra_sys_path = parent.parent
5761
else:
5862
break
5963

6064
module_str = ".".join(p.stem for p in module_paths)
65+
6166
return ModuleData(
6267
module_import_str=module_str,
6368
extra_sys_path=extra_sys_path.resolve(),
6469
module_paths=module_paths,
6570
)
6671

6772

68-
def get_app_name(*, mod_data: ModuleData, app_name: Union[str, None] = None) -> str:
73+
def get_app_name(
74+
*, mod_data: ModuleData, app_name: Union[str, None] = None
75+
) -> tuple[str, FastAPI]:
6976
try:
7077
mod = importlib.import_module(mod_data.module_import_str)
7178
except (ImportError, ValueError) as e:
@@ -74,32 +81,41 @@ def get_app_name(*, mod_data: ModuleData, app_name: Union[str, None] = None) ->
7481
"Ensure all the package directories have an [blue]__init__.py[/blue] file"
7582
)
7683
raise
84+
7785
if not FastAPI: # type: ignore[truthy-function]
7886
raise FastAPICLIException(
7987
"Could not import FastAPI, try running 'pip install fastapi'"
8088
) from None
89+
8190
object_names = dir(mod)
8291
object_names_set = set(object_names)
92+
8393
if app_name:
8494
if app_name not in object_names_set:
8595
raise FastAPICLIException(
8696
f"Could not find app name {app_name} in {mod_data.module_import_str}"
8797
)
98+
8899
app = getattr(mod, app_name)
100+
89101
if not isinstance(app, FastAPI):
90102
raise FastAPICLIException(
91103
f"The app name {app_name} in {mod_data.module_import_str} doesn't seem to be a FastAPI app"
92104
)
93-
return app_name
105+
106+
return app_name, app
107+
94108
for preferred_name in ["app", "api"]:
95109
if preferred_name in object_names_set:
96110
obj = getattr(mod, preferred_name)
97111
if isinstance(obj, FastAPI):
98-
return preferred_name
112+
return preferred_name, obj
113+
99114
for name in object_names:
100115
obj = getattr(mod, name)
101116
if isinstance(obj, FastAPI):
102-
return name
117+
return name, obj
118+
103119
raise FastAPICLIException("Could not find FastAPI app in module, try using --app")
104120

105121

@@ -108,6 +124,7 @@ class ImportData:
108124
app_name: str
109125
module_data: ModuleData
110126
import_string: str
127+
fastapi_app: FastAPI
111128

112129

113130
def get_import_data(
@@ -121,12 +138,16 @@ def get_import_data(
121138

122139
if not path.exists():
123140
raise FastAPICLIException(f"Path does not exist {path}")
141+
124142
mod_data = get_module_data_from_path(path)
125143
sys.path.insert(0, str(mod_data.extra_sys_path))
126-
use_app_name = get_app_name(mod_data=mod_data, app_name=app_name)
144+
use_app_name, app = get_app_name(mod_data=mod_data, app_name=app_name)
127145

128146
import_string = f"{mod_data.module_import_str}:{use_app_name}"
129147

130148
return ImportData(
131-
app_name=use_app_name, module_data=mod_data, import_string=import_string
149+
app_name=use_app_name,
150+
module_data=mod_data,
151+
import_string=import_string,
152+
fastapi_app=app,
132153
)

0 commit comments

Comments
 (0)