Skip to content
This repository was archived by the owner on Sep 11, 2024. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 110 additions & 42 deletions astrodbkit/astrocat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@
import astropy.table as at
import astropy.coordinates as coord
import datetime
from sklearn.cluster import DBSCAN
from collections import Counter
from scipy.stats import norm
# from scipy.stats import norm
from astroquery.vizier import Vizier
from astroquery.xmatch import XMatch
from sklearn.externals import joblib
from astropy.coordinates import SkyCoord
import pandas as pd
from bokeh.plotting import ColumnDataSource, figure, output_file, show
from bokeh.io import output_notebook, show
from sklearn.cluster import DBSCAN
from sklearn.externals import joblib

Vizier.ROW_LIMIT = -1

Expand All @@ -33,7 +36,7 @@ def __init__(self, name='Test'):
The name of the database
"""
self.name = name
self.catalog = pd.DataFrame(columns=('id','ra','dec','flag','datasets'))
self.sources = pd.DataFrame(columns=('id', 'ra', 'dec', 'flag', 'datasets'))
self.n_sources = 0
self.history = "{}: Database created".format(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
self.catalogs = {}
Expand All @@ -49,7 +52,40 @@ def info(self):
"""
print(self.history)

def add_source(self, ra, dec, flag='', radius=10*q.arcsec):
def plot(self, cat_name, x, y, **kwargs):
    """
    Plot the named columns of the given attribute if the value is a pandas.DataFrame

    Parameters
    ----------
    cat_name: str
        The attribute name
    x: str
        The name of the column to plot on the x-axis
    y: str
        The name of the column to plot on the y-axis
    **kwargs
        Additional keyword arguments passed to bokeh.plotting.figure
        (previously accepted but silently ignored)
    """
    # Get the attribute (guard clause: bail out early on a bad name)
    if not (isinstance(cat_name, str) and hasattr(self, cat_name)):
        print('No attribute named', cat_name)
        return
    attr = getattr(self, cat_name)

    # Make sure the attribute is a DataFrame (use the public alias,
    # not the private pd.core.frame path)
    if not isinstance(attr, pd.DataFrame):
        print(cat_name, 'is not a Pandas DataFrame!')
        return

    # Make sure the requested columns actually exist, otherwise bokeh
    # would render an empty plot with no explanation
    missing = [col for col in (x, y) if col not in attr.columns]
    if missing:
        print('No column(s) named', missing, 'in', cat_name)
        return

    # Build and display the scatter plot
    ds = ColumnDataSource(attr)
    myPlot = figure(**kwargs)
    myPlot.xaxis.axis_label = x
    myPlot.yaxis.axis_label = y
    myPlot.circle(x, y, source=ds)
    show(myPlot, notebook_handle=True)

def add_source(self, ra, dec, flag='', radius=10*q.arcsec, catalogs={}):
"""
Add a source to the catalog manually and find data in existing catalogs

Expand All @@ -63,9 +99,12 @@ def add_source(self, ra, dec, flag='', radius=10*q.arcsec):
A flag for the source
radius: float
The cross match radius for the list of catalogs
catalogs: dict
Additional catalogs to search, e.g.
catalogs={'TMASS':{'cat_loc':'II/246/out', 'id_col':'id', 'ra_col':'RAJ2000', 'dec_col':'DEJ2000'}}
"""
# Get the id
id = int(len(self.catalog)+1)
# Set the id
self.n_sources += 1

# Check the coordinates
ra = ra.to(q.deg)
Expand All @@ -74,10 +113,16 @@ def add_source(self, ra, dec, flag='', radius=10*q.arcsec):

# Search the catalogs for this source
for cat_name,params in self.catalogs.items():
self.Vizier_query(params['cat_loc'], cat_name, ra, dec, radius, ra_col=params['ra_col'], dec_col=params['dec_col'], append=True, group=False)
self.Vizier_query(params['cat_loc'], cat_name, ra, dec, radius, ra_col=params['ra_col'], dec_col=params['dec_col'], append=True, force_id=self.n_sources, group=False)

# Search additional catalogs
for cat_name,params in catalogs.items():
if cat_name not in self.catalogs:
self.Vizier_query(params['cat_loc'], cat_name, ra, dec, radius, ra_col=params['ra_col'], dec_col=params['dec_col'], force_id=self.n_sources, group=False)

# Add the source to the catalog
self.catalog = self.catalog.append([id, ra.value, dec.value, flag, datasets], ignore_index=True)
new_cat = pd.DataFrame([[self.n_sources, ra.value, dec.value, flag, datasets]], columns=self.sources.columns)
self.sources = self.sources.append(new_cat, ignore_index=True)

def delete_source(self, id):
"""
Expand All @@ -89,18 +134,18 @@ def delete_source(self, id):
The id of the source in the catalog
"""
# Set the index
self.catalog.set_index('id')
self.sources.set_index('id')

# Exclude the unwanted source
self.catalog = self.catalog[self.catalog.id!=id]
self.sources = self.sources[self.sources.id!=id]

# Remove the records from the catalogs
for cat_name in self.catalogs:
new_cat = getattr(self, cat_name)[getattr(self, cat_name).source_id!=id]
print('{} records removed from {} catalog'.format(int(len(getattr(self, cat_name))-len(new_cat)), cat_name))
setattr(self, cat_name, new_cat)

def ingest_data(self, data, cat_name, id_col, ra_col='_RAJ2000', dec_col='_DEJ2000', cat_loc='', append=False, count=-1):
def ingest_data(self, data, cat_name, id_col, ra_col='_RAJ2000', dec_col='_DEJ2000', cat_loc='', append=False, delimiter='\t', force_id='', count=-1):
"""
Ingest a data file and regroup sources

Expand All @@ -120,6 +165,8 @@ def ingest_data(self, data, cat_name, id_col, ra_col='_RAJ2000', dec_col='_DEJ20
The location of the original catalog data
append: bool
Append the catalog rather than replace
force_id: int
Assigns a specific id in the catalog
count: int
The number of table rows to add
(This is mainly for testing purposes)
Expand All @@ -133,7 +180,7 @@ def ingest_data(self, data, cat_name, id_col, ra_col='_RAJ2000', dec_col='_DEJ20

if isinstance(data, str):
cat_loc = cat_loc or data
data = pd.read_csv(data, sep='\t', comment='#', engine='python')[:count]
data = pd.read_csv(data, sep=delimiter, comment='#', engine='python')[:count]

elif isinstance(data, pd.core.frame.DataFrame):
cat_loc = cat_loc or type(data)
Expand Down Expand Up @@ -167,7 +214,7 @@ def ingest_data(self, data, cat_name, id_col, ra_col='_RAJ2000', dec_col='_DEJ20
data.insert(0,'catID', ['{}_{}'.format(cat_name,n+1) for n in range(last,last+len(data))])
data.insert(0,'dec_corr', data['dec'])
data.insert(0,'ra_corr', data['ra'])
data.insert(0,'source_id', np.nan)
data.insert(0,'source_id', force_id or np.nan)

print('Ingesting {} rows from {} catalog...'.format(len(data),cat_name))

Expand All @@ -185,7 +232,7 @@ def ingest_data(self, data, cat_name, id_col, ra_col='_RAJ2000', dec_col='_DEJ20
except AttributeError:
print("No catalog named '{}'. Set 'append=False' to create it.".format(cat_name))

def inventory(self, source_id):
def inventory(self, source_id, return_inventory=False):
"""
Look at the inventory for a given source

Expand All @@ -203,15 +250,30 @@ def inventory(self, source_id):
print('Please enter an integer between 1 and',self.n_sources)

else:

print('Source:')
print(at.Table.from_pandas(self.catalog[self.catalog['id']==source_id]).pprint())

# Empty inventory
inv = {}

# Add the record from the source table
inv['source'] = at.Table.from_pandas(self.sources[self.sources['id']==source_id])

for cat_name in self.catalogs:
cat = getattr(self, cat_name)
rows = cat[cat['source_id']==source_id]
if not rows.empty:
print('\n{}:'.format(cat_name))
at.Table.from_pandas(rows).pprint()
inv[cat_name] = at.Table.from_pandas(rows)

if return_inventory:

# Return the data
return inv

else:

# Print out the data in each catalog
for cat_name, data in inv.items():
print('\n',cat_name,':')
data.pprint()

def _catalog_check(self, cat_name, append=False):
"""
Expand Down Expand Up @@ -262,7 +324,7 @@ def SDSS_spectra_query(self, cat_name, ra, dec, radius, group=True, **kwargs):
if self._catalog_check(cat_name):

# Prep the current catalog as an astropy.QTable
tab = at.Table.from_pandas(self.catalog)
tab = at.Table.from_pandas(self.sources)

# Cone search Vizier
print("Searching SDSS for sources within {} of ({}, {}). Please be patient...".format(viz_cat, radius, ra, dec))
Expand All @@ -280,7 +342,7 @@ def SDSS_spectra_query(self, cat_name, ra, dec, radius, group=True, **kwargs):
if len(self.catalogs)>1 and group:
self.group_sources(self.xmatch_radius)

def Vizier_query(self, viz_cat, cat_name, ra, dec, radius, ra_col='RAJ2000', dec_col='DEJ2000', columns=["**"], append=False, group=True, **kwargs):
def Vizier_query(self, viz_cat, cat_name, ra, dec, radius, ra_col='RAJ2000', dec_col='DEJ2000', columns=["**", "+_r"], append=False, force_id='', group=True, nrows=-1, **kwargs):
"""
Use astroquery to search a catalog for sources within a search cone

Expand All @@ -304,6 +366,8 @@ def Vizier_query(self, viz_cat, cat_name, ra, dec, radius, ra_col='RAJ2000', dec
The list of columns to pass to astroquery
append: bool
Append the catalog rather than replace
force_id: int
Assigns a specific id in the catalog
"""
# Verify the cat_name
if self._catalog_check(cat_name, append=append):
Expand All @@ -312,24 +376,28 @@ def Vizier_query(self, viz_cat, cat_name, ra, dec, radius, ra_col='RAJ2000', dec
print("Searching {} for sources within {} of ({}, {}). Please be patient...".format(viz_cat, radius, ra, dec))
crds = coord.SkyCoord(ra=ra, dec=dec, frame='icrs')
V = Vizier(columns=columns, **kwargs)
V.ROW_LIMIT = -1
V.ROW_LIMIT = nrows

try:
data = V.query_region(crds, radius=radius, catalog=viz_cat)[0]

# Add the link to original record
data['record'] = ['http://vizier.u-strasbg.fr/viz-bin/VizieR-5?-ref=VIZ5b17f9660734&-out.add=.&-source={}&recno={}'.format(viz_cat,n+1) for n in range(len(data))]

except:
print("No data found in {} within {} of ({}, {}).".format(viz_cat, radius, ra, dec))
return

# Ingest the data
self.ingest_data(data, cat_name, 'id', ra_col=ra_col, dec_col=dec_col, cat_loc=viz_cat, append=append)
self.ingest_data(data, cat_name, 'id', ra_col=ra_col, dec_col=dec_col, cat_loc=viz_cat, append=append, force_id=force_id)

# Regroup
if len(self.catalogs)>1 and group:
self.group_sources(self.xmatch_radius)

def Vizier_xmatch(self, viz_cat, cat_name, ra_col='_RAJ2000', dec_col='_DEJ2000', radius='', group=True):
"""
Use astroquery to pull in and cross match a catalog with sources in self.catalog
Use astroquery to pull in and cross match a catalog with sources in self.sources

Parameters
----------
Expand All @@ -341,7 +409,7 @@ def Vizier_xmatch(self, viz_cat, cat_name, ra_col='_RAJ2000', dec_col='_DEJ2000'
The matching radius
"""
# Make sure sources have been grouped
if self.catalog.empty:
if self.sources.empty:
print('Please run group_sources() before cross matching.')
return

Expand All @@ -351,7 +419,7 @@ def Vizier_xmatch(self, viz_cat, cat_name, ra_col='_RAJ2000', dec_col='_DEJ2000'
viz_cat = "vizier:{}".format(viz_cat)

# Prep the current catalog as an astropy.QTable
tab = at.Table.from_pandas(self.catalog)
tab = at.Table.from_pandas(self.sources)

# Crossmatch with Vizier
print("Cross matching {} sources with {} catalog. Please be patient...".format(len(tab), viz_cat))
Expand Down Expand Up @@ -413,12 +481,12 @@ def group_sources(self, radius='', plot=False):
unique_coords = np.asarray([np.mean(coords[source_ids==id], axis=0) for id in list(set(source_ids))])

# Generate a source catalog
self.catalog = pd.DataFrame(columns=('id','ra','dec','flag','datasets'))
self.catalog['id'] = unique_source_ids
self.catalog[['ra','dec']] = unique_coords
self.catalog['flag'] = [None]*len(unique_source_ids)
# self.catalog['flag'] = ['d{}'.format(i) if i>1 else '' for i in Counter(source_ids).values()]
self.catalog['datasets'] = Counter(source_ids).values()
self.sources = pd.DataFrame(columns=('id','ra','dec','flag','datasets'))
self.sources['id'] = unique_source_ids
self.sources[['ra','dec']] = unique_coords
self.sources['flag'] = [None]*len(unique_source_ids)
# self.sources['flag'] = ['d{}'.format(i) if i>1 else '' for i in Counter(source_ids).values()]
self.sources['datasets'] = Counter(source_ids).values()

# Update history
self.history += "\n{}: Catalog grouped with radius {} arcsec.".format(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"), self.xmatch_radius)
Expand Down Expand Up @@ -495,7 +563,7 @@ def load(self, path):
DB = joblib.load(path)

# Load the attributes
self.catalog = DB.catalog
self.sources = DB.catalog
self.n_sources = DB.n_sources
self.name = DB.name
self.history = DB.history
Expand Down Expand Up @@ -535,11 +603,11 @@ def correct_offsets(self, cat_name, truth='ACS'):
else:

# First, remove any previous catalog correction
self.catalog.loc[self.catalog['cat_name']==cat_name, 'ra_corr'] = self.catalog.loc[self.catalog['cat_name']==cat_name, '_RAJ2000']
self.catalog.loc[self.catalog['cat_name']==cat_name, 'dec_corr'] = self.catalog.loc[self.catalog['cat_name']==cat_name, '_DEJ2000']
self.sources.loc[self.sources['cat_name']==cat_name, 'ra_corr'] = self.sources.loc[self.sources['cat_name']==cat_name, '_RAJ2000']
self.sources.loc[self.sources['cat_name']==cat_name, 'dec_corr'] = self.sources.loc[self.sources['cat_name']==cat_name, '_DEJ2000']

# Copy the catalog
onc_gr = self.catalog.copy()
onc_gr = self.sources.copy()

# restrict to one-to-one matches, sort by oncID so that matches are paired
o2o_new = onc_gr.loc[(onc_gr['oncflag'].str.contains('o')) & (onc_gr['cat_name'] == cat_name) ,:].sort_values('oncID')
Expand Down Expand Up @@ -582,8 +650,8 @@ def correct_offsets(self, cat_name, truth='ACS'):

# Update the coordinates of the appropriate sources
print('Shifting {} sources by {}" in RA and {}" in Dec...'.format(cat_name,mu_ra,mu_dec))
self.catalog.loc[self.catalog['cat_name']==cat_name, 'ra_corr'] += mu_ra
self.catalog.loc[self.catalog['cat_name']==cat_name, 'dec_corr'] += mu_dec
self.sources.loc[self.sources['cat_name']==cat_name, 'ra_corr'] += mu_ra
self.sources.loc[self.sources['cat_name']==cat_name, 'dec_corr'] += mu_dec

# Update history
now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
Expand Down Expand Up @@ -625,7 +693,7 @@ def default_rename_columns(cat_name):
defaults = {'2MASS':{'JD':'epoch', 'Qflg':'flags', 'Jmag':'2MASS.J', 'Hmag':'2MASS.H', 'Kmag':'2MASS.Ks', 'e_Jmag':'2MASS.J_unc', 'e_Hmag':'2MASS.H_unc', 'e_Kmag':'2MASS.Ks_unc'},
'WISE':{'qph':'flags', 'W1mag':'WISE.W1', 'W2mag':'WISE.W2', 'W3mag':'WISE.W3', 'W4mag':'WISE.W4', 'e_W1mag':'WISE.W1_unc', 'e_W2mag':'WISE.W2_unc', 'e_W3mag':'WISE.W3_unc', 'e_W4mag':'WISE.W4_unc'},
'SDSS':{'ObsDate':'epoch', 'flags':'oflags', 'Q':'flags', 'umag':'SDSS.u', 'gmag':'SDSS.g', 'rmag':'SDSS.r', 'imag':'SDSS.i', 'zmag':'SDSS.z', 'e_umag':'SDSS.u_unc', 'e_gmag':'SDSS.g_unc', 'e_rmag':'SDSS.r_unc', 'e_imag':'SDSS.i_unc', 'e_zmag':'SDSS.z_unc'},
'TGAS':{'Epoch':'epoch', 'Plx':'parallax', 'e_Plx':'parallax_unc'}}
'GAIA':{'Epoch':'epoch', 'Plx':'parallax', 'e_Plx':'parallax_unc', 'Gmag':'Gaia.G', 'e_Gmag':'Gaia.G_unc', 'BPmag':'Gaia.BP', 'e_BPmag':'Gaia.BP_unc', 'RPmag':'Gaia.RP', 'e_RPmag':'Gaia.RP_unc'}}

return defaults[cat_name]

Expand All @@ -646,7 +714,7 @@ def default_column_fill(cat_name):
defaults = {'2MASS':{'publication_shortname':'Cutr03', 'telescope_id':2, 'instrument_id':5, 'system_id':2},
'WISE':{'publication_shortname':'Cutr13', 'telescope_id':3, 'instrument_id':6, 'system_id':2},
'SDSS':{'publication_shortname':'Alam15', 'telescope_id':6, 'instrument_id':9, 'system_id':2},
'TGAS':{'publication_shortname':'Gaia16', 'telescope_id':4, 'instrument_id':7, 'system_id':1}}
'GAIA':{'publication_shortname':'Gaia18', 'telescope_id':4, 'instrument_id':7, 'system_id':1}}

return defaults[cat_name]

12 changes: 10 additions & 2 deletions astrodbkit/astrodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def create_database(dbpath, schema='', overwrite=True):

# Load the schema if given
if schema:
os.system("cat {} | sqlite3 {}".format(schema,dbpath))
os.system("cat {} | sqlite3 {}".format(schema, dbpath))

# Otherwise just make an empty SOURCES table
else:
Expand Down Expand Up @@ -326,8 +326,16 @@ def add_data(self, data, table, delimiter='|', bands='', clean_up=True, rename_c
# Rename columns
if isinstance(rename_columns,str):
rename_columns = astrocat.default_rename_columns(rename_columns)

try_again = []
for input_name,new_name in rename_columns.items():
data.rename_column(input_name,new_name)
try:
data.rename_column(input_name,new_name)
except KeyError:
try_again.append(input_name)

for input_name in try_again:
data.rename_column(input_name,rename_columns[input_name])

# Add column fills
if isinstance(column_fill,str):
Expand Down
Loading