
Commit 7bed9ad

more tests, and some cleanup of the Time object
general cleanup of Dataset.from_netCDF
1 parent 3743cff commit 7bed9ad

File tree: 5 files changed (+84 / -47 lines changed)


gridded/depth.py

Lines changed: 7 additions & 5 deletions
@@ -926,15 +926,17 @@ def from_netCDF(cls,
         typs = cls.sd_types + cls.ld_types
         available_to_create = [typ._can_create_from_netCDF(grid_file=dg, data_file=ds) for typ in typs]
         if not any(available_to_create):
-            warnings.warn('''Unable to automatically determine depth system so
-                          reverting to surface-only mode. Manually check the
-                          (depth_object).surface_index attribute and set it
-                          to the appropriate array index for your model data''', RuntimeWarning)
+            warnings.warn("Unable to automatically determine depth system so "
+                          "reverting to surface-only mode. Manually check the "
+                          "(depth_object).surface_index attribute and set it "
+                          "to the appropriate array index for your model data",
+                          RuntimeWarning)
             return cls.surf_types[0].from_netCDF(data_file=ds, grid_file=dg, **kwargs)
         else:
             typ = typs[np.argmax(available_to_create)]
             if sum(available_to_create) > 1:
-                warnings.warn('''Multiple depth systems detected. Using the first one found: {0}'''.format(typ.__repr__), RuntimeWarning)
+                warnings.warn("Multiple depth systems detected. Using the first one found: "
+                              f"{typ!r}", RuntimeWarning)
             return typ.from_netCDF(filename=filename,
                                    dataset=dataset,
                                    data_file=data_file,
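
A note on the warning-message change (an inference from the diff, not something the commit message states): a triple-quoted string carries the source file's newlines and indentation into the warning text, while adjacent string literals are concatenated by Python into a single clean message. A minimal sketch of the difference, using a shortened version of the message:

    import warnings

    # Triple-quoted form: the emitted warning keeps the newlines and the
    # leading spaces that are present in the source file.
    msg_triple = '''Unable to automatically determine depth system so
                    reverting to surface-only mode.'''

    # Adjacent-literal form: the pieces are concatenated into one line.
    msg_concat = ("Unable to automatically determine depth system so "
                  "reverting to surface-only mode.")

    warnings.warn(msg_concat, RuntimeWarning)

The second warning also fixes a small bug: the old code passed typ.__repr__ (the bound method itself) to str.format(), so the message showed the method object rather than its result, whereas f"{typ!r}" actually applies repr() to typ.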

gridded/gridded.py

Lines changed: 42 additions & 25 deletions
@@ -24,7 +24,6 @@ class Dataset:
     An object that represent an entire complete dataset --
     a collection of Variables and the Grid that they are stored on.
     """
-
     def __init__(self,
                  ncfile=None,
                  grid=None,
@@ -68,35 +67,55 @@ def __init__(self,
             # raise ValueError("don't create from a file")
             warnings.warn("Creating a Dataset from a netcdfile directly is deprecated. "
                           "Please use Dataset.from_netCDF() instead. "
-                          "Or one of the utilities in gridded.io", DeprecationWarning)
+                          "Or use one of the utilities in gridded.io",
+                          DeprecationWarning)
 
         if ncfile is not None:
             if (grid is not None or
                     variables is not None or
                     attributes is not None):
                 raise ValueError("You can create a Dataset from a file, or from raw data"
                                  "but not both.")
-            self.nc_dataset = get_dataset(ncfile)
-            self.filename = self.nc_dataset.filepath()
-            self.grid = Grid.from_netCDF(filename=self.filename,
-                                         dataset=self.nc_dataset,
-                                         grid_topology=grid_topology)
-            self.variables = self._load_variables(self.nc_dataset)
-            self.attributes = get_dataset_attrs(self.nc_dataset)
-        else:  # no file passed in -- create from grid and variables
+            self._init_from_netCDF(ncfile)
+        else:  # Create from grid and variables -- this is what should usually happen.
             self.filename = None
             self.grid = grid
             self.variables = {} if variables is None else variables
             self.attributes = {} if attributes is None else attributes
 
+
+    def _init_from_netCDF(self,
+                          filename=None,
+                          grid_file=None,
+                          variable_files=None,
+                          grid_topology=None):
+        """
+        internal implementation -- users should call the .from_netCDF()
+        classmethod -- see its docstring for usage.
+
+        This is used to initialize a dataset from a netCDF file --
+        done this way, so it can be called from more than one place.
+        """
+        if (grid_file is not None) or (variable_files is not None):
+            raise NotImplementedError("Loading from separate netcdf files is not yet supported")
+
+        self.nc_dataset = get_dataset(filename)
+        self.filename = self.nc_dataset.filepath()
+        self.grid = Grid.from_netCDF(filename=self.filename,
+                                     dataset=self.nc_dataset,
+                                     grid_topology=grid_topology)
+        # fixme: this should load the depth and time, and then the variables.
+        self.variables = self._variables_from_netCDF(self.nc_dataset)
+        self.attributes = get_dataset_attrs(self.nc_dataset)
+
+
     @classmethod
     def from_netCDF(cls,
                     filename=None,
                     grid_file=None,
                     variable_files=None,
                     grid_topology=None):
         """
-
         NOTE: only loading from a single file is currently implemented.
         you can create a DATaset by hand, by loading the grid and
         variables separately, and then adding them
@@ -117,20 +136,13 @@ def from_netCDF(cls,
         :type grid_topology: mapping with keys of topology components and values are
                              variable names.
         """
-        if (grid_file is not None) or (variable_files is not None):
-            raise NotImplementedError("Loading from separate netcdf files is not yet supported")
-
-        # create an empty DAtaset:
+        # create an empty Dataset:
         ds = cls()
-
-        ds.nc_dataset = get_dataset(filename)
-        ds.filename = ds.nc_dataset.filepath()
-        ds.grid = Grid.from_netCDF(filename=ds.filename,
-                                   dataset=ds.nc_dataset,
-                                   grid_topology=grid_topology)
-        ds.variables = ds._load_variables(ds.nc_dataset)
-        ds.attributes = get_dataset_attrs(ds.nc_dataset)
-
+        # initialize it
+        ds._init_from_netCDF(filename,
+                             grid_file,
+                             variable_files,
+                             grid_topology)
         return ds
 
     def __getitem__(self, key):
@@ -147,10 +159,15 @@ def __str__(self):
         return descp
 
 
-    def _load_variables(self, ds):
+    def _variables_from_netCDF(self, ds):
         """
         load up the variables in the nc file
+
+        :param ds: initialized netCDF dataset
         """
+        # fixme: this needs work
+        # It *should* have already gotten the grid and depth and time,
+        # and then the variables can be loaded directly.
         variables = {}
         for k in ds.variables.keys():
             # find which netcdf variables are used to define the grid
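
The refactor above moves the netCDF loading into a single private helper so that both construction paths share it. A short usage sketch of the resulting API (the filename is hypothetical, for illustration only):

    from gridded import Dataset

    # Preferred: the classmethod constructor. It creates an empty Dataset
    # and then calls the shared _init_from_netCDF() helper.
    ds = Dataset.from_netCDF("some_model_output.nc")

    # Deprecated but still working: passing a file straight to __init__.
    # This now emits a DeprecationWarning and delegates to the same helper.
    old_style = Dataset("some_model_output.nc")

    print(ds.grid)
    print(list(ds.variables))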

gridded/tests/test_dataset.py

Lines changed: 25 additions & 10 deletions
@@ -6,30 +6,45 @@
 import netCDF4 as nc
 
 from gridded import Dataset
-from .utilities import get_test_file_dir
+from gridded.grids import Grid_S
+from .utilities import TEST_DATA
 
-test_dir = get_test_file_dir()
 
 # Need to hook this up to existing test data infrastructure
 # ... and add more infrastructure
 
-sample_sgrid_file = os.path.join(test_dir, 'staggered_sine_channel.nc')
-arakawa_c_file = os.path.join(test_dir, 'arakawa_c_test_grid.nc')
+sample_sgrid_file = TEST_DATA / 'staggered_sine_channel.nc'
+arakawa_c_file = TEST_DATA / 'arakawa_c_test_grid.nc'
 
 
 def test_load_sgrid():
-    """ tests you can intitilize an conforming sgrid file"""
-    sinusoid = Dataset(sample_sgrid_file)
+    """ tests you can initialize an conforming sgrid file"""
+    sinusoid = Dataset.from_netCDF(sample_sgrid_file)
+
+    assert isinstance(sinusoid.grid, Grid_S)
 
     assert True # just to make it a test
 
 
+def test_init_from_netcdf_file_directly():
+    """
+    This should raise a deprecation warning, but still work
+    """
+    with pytest.warns(DeprecationWarning):
+        gds = Dataset(arakawa_c_file)
+
+    print(gds.info)
+
+    assert isinstance(gds.grid, Grid_S)
+    assert len(gds.variables) == 6
+
+
 def test_info():
     """
     Make sure the info property is working
     This doesn't test much -- jsut tht it won't crash
     """
-    gds = Dataset(sample_sgrid_file)
+    gds = Dataset.from_netCDF(sample_sgrid_file)
 
     info = gds.info
 
@@ -40,15 +55,15 @@ def test_info():
     assert "attributes:" in info
 
 def test_get_variable_by_attribute_one_there():
-    gds = Dataset(arakawa_c_file)
+    gds = Dataset.from_netCDF(arakawa_c_file)
 
     vars = gds.get_variables_by_attribute('long_name', 'v-momentum component')
 
     assert len(vars) == 1
     assert vars[0].attributes['long_name'] == 'v-momentum component'
 
 def test_get_variable_by_attribute_multiple():
-    gds = Dataset(arakawa_c_file)
+    gds = Dataset.from_netCDF(arakawa_c_file)
 
     vars = gds.get_variables_by_attribute('units', 'meter second-1')
 
@@ -62,7 +77,7 @@ def test_get_variable_by_attribute_not_there():
     """
     This should return an empty list
    """
-    gds = Dataset(arakawa_c_file)
+    gds = Dataset.from_netCDF(arakawa_c_file)
 
     var = gds.get_variables_by_attribute('some_junk', 'more_junk')
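
The test module now builds its data paths with pathlib instead of os.path.join(): TEST_DATA is a pathlib.Path, so the / operator joins path components. A small illustration of the pattern (the file name is taken from the test above):

    from pathlib import Path

    TEST_DATA = Path(__file__).parent / "test_data"   # equivalent to the definition in utilities.py

    # Path objects compose with "/" rather than os.path.join()
    sample_sgrid_file = TEST_DATA / "staggered_sine_channel.nc"

    print(sample_sgrid_file.name)     # staggered_sine_channel.nc
    print(sample_sgrid_file.suffix)   # .nc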

gridded/tests/utilities.py

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@
 
 HERE = Path(__file__).parent
 EXAMPLE_DATA = HERE / "example_data"
+TEST_DATA = HERE / "test_data"
 
 # # Files on PYGNOME server -- add them here as needed
 data_file_cache = pooch.create(

gridded/time.py

Lines changed: 9 additions & 7 deletions
@@ -168,15 +168,16 @@ def __init__(self,
             raise TimeSeriesError("Time sequence has duplicate entries")
         super(Time, self).__init__(*args, **kwargs)
 
-    @classmethod
-    def locate_time_var_from_var(cls, datavar):
+
+    @staticmethod
+    def locate_time_var_from_var(datavar):
         if hasattr(datavar, 'time') and datavar.time in datavar._grp.dimensions.keys():
             varname = datavar.time
         else:
             varname = datavar.dimensions[0] if 'time' in datavar.dimensions[0] else None
-
+
         return varname
-
+
 
     @classmethod
     def from_netCDF(cls,
@@ -432,14 +433,15 @@ def time_in_bounds(self, time):
     def valid_time(self, time):
         """
         Raises a OutOfTimeRangeError if time is not within the bounds of the timeseries
+
+        :param time: a datetime object that you want to check.
         """
         # if time < self.min_time or time > self.max_time:
         if not self.time_in_bounds(time):
-            raise OutOfTimeRangeError(f'time specified ({time.strftime('%c')}) is not within the bounds of '
-                                      f'({self.min_time.strftime('%c')} to {self.max_time.strftime('%c')})'
+            raise OutOfTimeRangeError(f'time specified: ({time.isoformat()}) is not within the bounds of '
+                                      f'({self.min_time.isoformat()} to {self.max_time.isoformat()})'
                                       )
 
-
     def index_of(self, time, extrapolate=False):
         '''
         Returns the index of the provided time with respect to the time intervals in the file.
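
Two observations on this file (inferences from the diff, not stated in the commit message): locate_time_var_from_var never used cls, so @staticmethod is the natural decorator for it. And a likely reason for dropping strftime('%c') in favor of isoformat(): before Python 3.12 (PEP 701), an f-string cannot reuse its own quote character inside a replacement field, so the removed lines would be a SyntaxError on older interpreters, and isoformat() also avoids the locale-dependent '%c' format. A minimal sketch:

    from datetime import datetime

    t = datetime(2024, 7, 1, 12, 30)

    # On Python < 3.12 the following is a SyntaxError, because '%c' reuses
    # the single quote that delimits the f-string:
    #     msg = f'time specified ({t.strftime('%c')}) is not within the bounds'

    # isoformat() needs no nested quotes and no locale-dependent format:
    msg = f'time specified: ({t.isoformat()}) is not within the bounds'
    print(msg)   # time specified: (2024-07-01T12:30:00) is not within the bounds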
