Skip to content

Commit 9268737

Browse files
Merge pull request #209 from naplab/Improve-automatic-local-referencing
Improve automatic local referencing by improving robustness to grids and to the channel order.
2 parents 710e35e + cb55054 commit 9268737

File tree

2 files changed

+116
-22
lines changed

2 files changed

+116
-22
lines changed

naplib/preprocessing/rereference.py

Lines changed: 94 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def _rereference(data_arr, method='avg', return_ref=False):
116116
return data_rereferenced
117117

118118

119-
def make_contact_rereference_arr(channelnames, extent=None):
119+
def make_contact_rereference_arr(channelnames, extent=None, grid_sizes={}):
120120
"""
121121
Create grid which defines re-referencing scheme based on electrodes being on the same contact as
122122
each other.
@@ -128,13 +128,17 @@ def make_contact_rereference_arr(channelnames, extent=None):
128128
be alphanumeric, with any numbers only being on the right. 2) The numeric portion specifies a
129129
different electrode number, while the character portion in the left of the channelname specifies the
130130
contact name. E.g. ['RT1','RT2','RT3','Ls1','Ls2'] indicates two contacts, the first with 3 electrodes
131-
and the second with 2 electrodes. 3) Electrodes from the same contact must be contiguous.
131+
and the second with 2 electrodes.
132132
extent : int, optional, default=None
133133
If provided, then only contacts from the same group which are within ``extent`` electrodes away
134-
from each other (inclusive) are still grouped together. Only used if ``method='contact'``. For
135-
example, if ``extent=1``, only the nearest electrode on either side of a given electrode on the
136-
same contact is still grouped with it. For example, extent=1 produces the traditional local
137-
average reference scheme.
134+
from each other (inclusive) are still grouped together. For example, if ``extent=1``, only the
135+
nearest electrode on either side of a given electrode on the same contact is still grouped with it.
136+
This ``extent=1`` produces the traditional local average reference scheme.
137+
The default ``extent=None`` produces the traditional common average reference scheme.
138+
grid_sizes : dict, optional, default={}
139+
If provided, contains {'contact_name': (nrow, ncol)} values for any known ECoG grid sizes.
140+
E.g. {'GridA': (8, 16)} indicates that electrodes on contact 'GridA' are arranged in an 8 x 16 grid,
141+
which is needed to determine adjacent electrodes for local average referencing with ``extent >= 1``.
138142
139143
Returns
140144
-------
@@ -145,18 +149,89 @@ def make_contact_rereference_arr(channelnames, extent=None):
145149
--------
146150
rereference
147151
"""
148-
contact_arrays = pd.Series([x.rstrip('0123456789') for x in channelnames])
149-
connections = np.zeros((len(contact_arrays),) * 2, dtype=float)
150-
for _, inds in contact_arrays.groupby(contact_arrays):
151-
for i in inds.index:
152-
connections[i, inds.index] = 1.0
152+
def _find_adjacent_numbers(a, b, number, extent):
153+
'''
154+
Used to determine electrodes for local averaging ECoG grid"
155+
'''
156+
# Validate if the number is within the valid range
157+
if number < 1 or number > a * b:
158+
raise ValueError("The number is outside the range of the grid.")
153159

154-
# remove longer than extent if desired
155-
if extent is not None:
156-
if extent < 1:
157-
raise ValueError(f'Invalid extent. Must be no less than 1 but got extent={extent}')
158-
connections *= np.tri(*connections.shape, k=extent)
159-
connections *= np.fliplr(np.flipud(np.tri(*connections.shape, k=extent)))
160-
connections = connections
161-
160+
# Calculate the row and column of the given number
161+
row = (number - 1) // b
162+
col = (number - 1) % b
163+
164+
# Find all adjacent numbers within the extent
165+
adjacent_numbers = []
166+
for dr in range(-extent, extent + 1): # Rows within the extent
167+
for dc in range(-extent, extent + 1): # Columns within the extent
168+
if dr == 0 and dc == 0:
169+
continue # Skip the number itself
170+
new_row, new_col = row + dr, col + dc
171+
if 0 <= new_row < a and 0 <= new_col < b:
172+
adjacent_num = new_row * b + new_col + 1
173+
adjacent_numbers.append(adjacent_num)
174+
175+
return np.array(adjacent_numbers, dtype=int)
176+
177+
connections = np.zeros((len(channelnames),) * 2, dtype=float)
178+
channelnames = np.array(channelnames)
179+
contact_arrays = np.array([x.rstrip('0123456789') for x in channelnames])
180+
contacts = np.unique(contact_arrays)
181+
# Determine the channel numbers on each contact
182+
ch_per_contact = {contact:[int(x.replace(contact,'')) for x in channelnames
183+
if x.rstrip('0123456789')==contact]
184+
for contact in contacts}
185+
186+
if extent is None:
187+
# Common average referencing per electrode array (ECoG grid or sEEG shank)
188+
# CAR will end up subtracting parts of channel ch from itself
189+
for contact in contacts:
190+
for ch in ch_per_contact[contact]:
191+
curr = np.where(channelnames==f'{contact}{ch}')[0]
192+
inds = np.where(contact_arrays==contact)[0]
193+
connections[curr,inds] = 1
194+
elif extent < 1:
195+
raise ValueError(f'Invalid extent. Must be no less than 1 but got extent={extent}')
196+
else:
197+
# Local average referencing within each electrode array
198+
# LAR will NOT subtract parts of channel ch from itself
199+
for contact in contacts:
200+
for ch in ch_per_contact[contact]:
201+
# Local referencing for ECoG grids
202+
if 'grid' in contact.lower():
203+
num_ch = len(ch_per_contact[contact])
204+
side = np.sqrt(num_ch)
205+
half_side = np.sqrt(num_ch/2)
206+
# Check grid_sizes dict
207+
if contact in grid_sizes:
208+
nrows, ncols = grid_sizes[contact]
209+
# Assume a square
210+
elif np.isclose(side, int(side)):
211+
nrows, ncols = side, side
212+
# Assume a 1 x 2 rectangle
213+
elif np.isclose(half_side, int(half_side)):
214+
nrows, ncols = half_side, half_side*2
215+
else:
216+
raise Exception(f'Cannot determine {contact} layout. Please include layout in `grid_sizes`')
217+
adjacent = _find_adjacent_numbers(nrows, ncols, ch, extent)
218+
curr = np.where(channelnames==f'{contact}{ch}')[0]
219+
inds = []
220+
for adj in adjacent:
221+
inds.append(np.where(channelnames==f'{contact}{adj}')[0])
222+
223+
# Local referencing for sEEG shanks and strips
224+
else:
225+
curr = np.where(channelnames==f'{contact}{ch}')[0]
226+
inds = []
227+
for cc in range(ch-extent, ch+extent+1):
228+
if cc != ch:
229+
inds.append(np.where(channelnames==f'{contact}{cc}')[0])
230+
231+
inds = np.concatenate(inds)
232+
if len(inds) < 1:
233+
print(f'{contact}{cc} has no re-references.')
234+
else:
235+
connections[curr,inds] = 1
236+
162237
return connections

tests/preprocessing/test_rereference.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,29 @@
44

55

66
def test_create_contact_rereference_arr():
7-
expected = np.array([[1,1,0,0],[1,1,0,0],[0,0,1,1],[0,0,1,1]])
8-
g = ['LT1','LT2','RT1','RT2']
9-
arr = make_contact_rereference_arr(g)
7+
expected = np.array([[0,1,0,0,0,0,0,0],
8+
[1,0,0,0,0,0,0,0],
9+
[0,0,0,1,0,0,0,0],
10+
[0,0,1,0,0,0,0,0],
11+
[0,0,0,0,0,1,1,1],
12+
[0,0,0,0,1,0,1,1],
13+
[0,0,0,0,1,1,0,1],
14+
[0,0,0,0,1,1,1,0],
15+
])
16+
expected1 = np.array([[1,1,0,0,0,0,0,0],
17+
[1,1,0,0,0,0,0,0],
18+
[0,0,1,1,0,0,0,0],
19+
[0,0,1,1,0,0,0,0],
20+
[0,0,0,0,1,1,1,1],
21+
[0,0,0,0,1,1,1,1],
22+
[0,0,0,0,1,1,1,1],
23+
[0,0,0,0,1,1,1,1],
24+
])
25+
g = ['LT1','LT2','GridA1','GridA2'] + [f'GridB{n}' for n in range(1,5)]
26+
arr = make_contact_rereference_arr(g, extent=1, grid_sizes={'GridA':(1,2)})
27+
arr1 = make_contact_rereference_arr(g)
1028
assert np.allclose(expected, arr)
29+
assert np.allclose(expected1, arr1)
1130

1231
def test_rereference_avg():
1332
arr = np.array([[1,1,0,0],[1,1,0,0],[0,0,1,1],[0,0,1,1]])

0 commit comments

Comments
 (0)