Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 79 additions & 53 deletions src/model_fit_api/app.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Literal, List, Dict
from typing import Literal, List, Dict, Optional

import numpy as np
import pandas as pd
Expand All @@ -8,58 +8,54 @@
from fastapi import FastAPI
from pydantic import BaseModel

models = [
"nugent-sn1a",
"nugent-sn91t",
"nugent-sn91bg",
"nugent-sn1bc",
"nugent-hyper",
"nugent-sn2n",
"nugent-sn2p",
"nugent-sn2l",
"salt2",
"salt3-nir",
"salt3",
"snf-2011fe",
"v19-1993j",
"v19-1998bw",
"v19-1999em",
"v19-2009ip",
]

app = FastAPI()


class Observation(BaseModel):
mjd: float
flux: float
fluxerr: float
zp: float = 8.9
mjd: Optional[float]
band: Optional[str]
flux: Optional[float]
fluxerr: Optional[float]
zp: Optional[float] = 8.9
zpsys: Literal["ab", "vega"] = "ab"
band: str


class Target(BaseModel):
light_curve: List[Observation]
ebv: float
t_min: float
t_max: float
count: int
name_model: str
ebv: Optional[float]
t_min: Optional[float]
t_max: Optional[float]
count: Optional[int]
name_model: Optional[str]
redshift: List[float]


class Model_data(BaseModel):
parameters: Dict[str, float]
name_model: str
zp: float
zpsys: str
band_list: List[str]
t_min: float
t_max: float
count: int

class Point(BaseModel):
time: float
flux: float
band: str
time: Optional[float]
difflux: Optional[float]
band: Optional[str]


class Result(BaseModel):
flux_jansky: List[Point]
class Parameters(BaseModel):
degrees_of_freedom: int
covariance: List[List[float]]
chi2: float
parameters: Dict[str, float]


class Flux(BaseModel):
flux_jansky: List[Point]


def fit(data, name_model, ebv, redshift):
Expand All @@ -72,36 +68,66 @@ def fit(data, name_model, ebv, redshift):
return summary, fitted_model


def get_flux_and_params(summary, data, fitted_model, t_min, t_max, count):
segment = np.linspace(t_min, t_max, count)
df = data.to_pandas()
def get_flux(data: Model_data):
dust = sncosmo.CCM89Dust()
fitted_model = sncosmo.Model(source=data.name_model, effects=[dust], effect_names=["mw"], effect_frames=["obs"])
fitted_model.set(**data.parameters)
segment = np.linspace(data.t_min, data.t_max, data.count)
points = []
for band in df["band"].unique():
predicts = fitted_model.bandflux(band, segment, df["zp"][0], df["zpsys"][0])
for band in data.band_list:
predicts = fitted_model.bandflux(band, segment, data.zp, data.zpsys)
points += [Point(time=time, flux=flux, band=band) for time, flux in zip(segment, predicts)]
return Result(
flux_jansky=points,
return Flux(
flux_jansky=points
)


def get_params(data: Target):
df = pd.DataFrame([dict(obs) for obs in data.light_curve])
table = Table.from_pandas(df)
summary, fitted_model = fit(table, data.name_model, data.ebv, data.redshift)
try: cov=summary.covariance.tolist()
except:
cov=[[]]
print('covariance is none')
return Parameters(
parameters=dict(zip(summary.param_names, summary.parameters)),
degrees_of_freedom=summary.ndof,
covariance=summary.covariance.tolist(),
covariance=cov,
chi2=summary.chisq,
)


def approximate(data: Target):
df = pd.DataFrame([obs.model_dump() for obs in data.light_curve])
table = Table.from_pandas(df)
summary, fitted_model = fit(table, data.name_model, data.ebv, data.redshift)
result = get_flux_and_params(summary, table, fitted_model, data.t_min, data.t_max, data.count)
return result


@app.post("/api/v1/sncosmo")
@app.post("/api/v1/sncosmo/fit")
async def sn_cosmo(data: Target):
"""Fit light curve with sncosmo."""
return approximate(data)
return get_params(data)


@app.post("/api/v1/sncosmo/get_bright")
async def sn_cosmo(data: Model_data):
"""Fit light curve with sncosmo."""
return get_flux(data)


@app.get("/api/v1/models")
async def models(data: Target):
async def models():
models = [
"nugent-sn1a",
"nugent-sn91t",
"nugent-sn91bg",
"nugent-sn1bc",
"nugent-hyper",
"nugent-sn2n",
"nugent-sn2p",
"nugent-sn2l",
"salt2",
"salt3-nir",
"salt3",
"snf-2011fe",
"v19-1993j",
"v19-1998bw",
"v19-1999em",
"v19-2009ip",
]
return {"models": models}
Loading