@@ -116,7 +116,7 @@ def _rereference(data_arr, method='avg', return_ref=False):
116
116
return data_rereferenced
117
117
118
118
119
- def make_contact_rereference_arr (channelnames , extent = None ):
119
+ def make_contact_rereference_arr (channelnames , extent = None , grid_sizes = {} ):
120
120
"""
121
121
Create grid which defines re-referencing scheme based on electrodes being on the same contact as
122
122
each other.
@@ -128,13 +128,17 @@ def make_contact_rereference_arr(channelnames, extent=None):
128
128
be alphanumeric, with any numbers only being on the right. 2) The numeric portion specifies a
129
129
different electrode number, while the character portion in the left of the channelname specifies the
130
130
contact name. E.g. ['RT1','RT2','RT3','Ls1','Ls2'] indicates two contacts, the first with 3 electrodes
131
- and the second with 2 electrodes. 3) Electrodes from the same contact must be contiguous.
131
+ and the second with 2 electrodes.
132
132
extent : int, optional, default=None
133
133
If provided, then only contacts from the same group which are within ``extent`` electrodes away
134
- from each other (inclusive) are still grouped together. Only used if ``method='contact'``. For
135
- example, if ``extent=1``, only the nearest electrode on either side of a given electrode on the
136
- same contact is still grouped with it. For example, extent=1 produces the traditional local
137
- average reference scheme.
134
+ from each other (inclusive) are still grouped together. For example, if ``extent=1``, only the
135
+ nearest electrode on either side of a given electrode on the same contact is still grouped with it.
136
+ This ``extent=1`` produces the traditional local average reference scheme.
137
+ The default ``extent=None`` produces the traditional common average reference scheme.
138
+ grid_sizes : dict, optional, default={}
139
+ If provided, contains {'contact_name': (nrow, ncol)} values for any known ECoG grid sizes.
140
+ E.g. {'GridA': (8, 16)} indicates that electrodes on contact 'GridA' are arranged in an 8 x 16 grid,
141
+ which is needed to determine adjacent electrodes for local average referencing with ``extent >= 1``.
138
142
139
143
Returns
140
144
-------
@@ -145,18 +149,89 @@ def make_contact_rereference_arr(channelnames, extent=None):
145
149
--------
146
150
rereference
147
151
"""
148
- contact_arrays = pd .Series ([x .rstrip ('0123456789' ) for x in channelnames ])
149
- connections = np .zeros ((len (contact_arrays ),) * 2 , dtype = float )
150
- for _ , inds in contact_arrays .groupby (contact_arrays ):
151
- for i in inds .index :
152
- connections [i , inds .index ] = 1.0
152
+ def _find_adjacent_numbers (a , b , number , extent ):
153
+ '''
154
+ Used to determine electrodes for local averaging ECoG grid"
155
+ '''
156
+ # Validate if the number is within the valid range
157
+ if number < 1 or number > a * b :
158
+ raise ValueError ("The number is outside the range of the grid." )
153
159
154
- # remove longer than extent if desired
155
- if extent is not None :
156
- if extent < 1 :
157
- raise ValueError (f'Invalid extent. Must be no less than 1 but got extent={ extent } ' )
158
- connections *= np .tri (* connections .shape , k = extent )
159
- connections *= np .fliplr (np .flipud (np .tri (* connections .shape , k = extent )))
160
- connections = connections
161
-
160
+ # Calculate the row and column of the given number
161
+ row = (number - 1 ) // b
162
+ col = (number - 1 ) % b
163
+
164
+ # Find all adjacent numbers within the extent
165
+ adjacent_numbers = []
166
+ for dr in range (- extent , extent + 1 ): # Rows within the extent
167
+ for dc in range (- extent , extent + 1 ): # Columns within the extent
168
+ if dr == 0 and dc == 0 :
169
+ continue # Skip the number itself
170
+ new_row , new_col = row + dr , col + dc
171
+ if 0 <= new_row < a and 0 <= new_col < b :
172
+ adjacent_num = new_row * b + new_col + 1
173
+ adjacent_numbers .append (adjacent_num )
174
+
175
+ return np .array (adjacent_numbers , dtype = int )
176
+
177
+ connections = np .zeros ((len (channelnames ),) * 2 , dtype = float )
178
+ channelnames = np .array (channelnames )
179
+ contact_arrays = np .array ([x .rstrip ('0123456789' ) for x in channelnames ])
180
+ contacts = np .unique (contact_arrays )
181
+ # Determine the channel numbers on each contact
182
+ ch_per_contact = {contact :[int (x .replace (contact ,'' )) for x in channelnames
183
+ if x .rstrip ('0123456789' )== contact ]
184
+ for contact in contacts }
185
+
186
+ if extent is None :
187
+ # Common average referencing per electrode array (ECoG grid or sEEG shank)
188
+ # CAR will end up subtracting parts of channel ch from itself
189
+ for contact in contacts :
190
+ for ch in ch_per_contact [contact ]:
191
+ curr = np .where (channelnames == f'{ contact } { ch } ' )[0 ]
192
+ inds = np .where (contact_arrays == contact )[0 ]
193
+ connections [curr ,inds ] = 1
194
+ elif extent < 1 :
195
+ raise ValueError (f'Invalid extent. Must be no less than 1 but got extent={ extent } ' )
196
+ else :
197
+ # Local average referencing within each electrode array
198
+ # LAR will NOT subtract parts of channel ch from itself
199
+ for contact in contacts :
200
+ for ch in ch_per_contact [contact ]:
201
+ # Local referencing for ECoG grids
202
+ if 'grid' in contact .lower ():
203
+ num_ch = len (ch_per_contact [contact ])
204
+ side = np .sqrt (num_ch )
205
+ half_side = np .sqrt (num_ch / 2 )
206
+ # Check grid_sizes dict
207
+ if contact in grid_sizes :
208
+ nrows , ncols = grid_sizes [contact ]
209
+ # Assume a square
210
+ elif np .isclose (side , int (side )):
211
+ nrows , ncols = side , side
212
+ # Assume a 1 x 2 rectangle
213
+ elif np .isclose (half_side , int (half_side )):
214
+ nrows , ncols = half_side , half_side * 2
215
+ else :
216
+ raise Exception (f'Cannot determine { contact } layout. Please include layout in `grid_sizes`' )
217
+ adjacent = _find_adjacent_numbers (nrows , ncols , ch , extent )
218
+ curr = np .where (channelnames == f'{ contact } { ch } ' )[0 ]
219
+ inds = []
220
+ for adj in adjacent :
221
+ inds .append (np .where (channelnames == f'{ contact } { adj } ' )[0 ])
222
+
223
+ # Local referencing for sEEG shanks and strips
224
+ else :
225
+ curr = np .where (channelnames == f'{ contact } { ch } ' )[0 ]
226
+ inds = []
227
+ for cc in range (ch - extent , ch + extent + 1 ):
228
+ if cc != ch :
229
+ inds .append (np .where (channelnames == f'{ contact } { cc } ' )[0 ])
230
+
231
+ inds = np .concatenate (inds )
232
+ if len (inds ) < 1 :
233
+ print (f'{ contact } { cc } has no re-references.' )
234
+ else :
235
+ connections [curr ,inds ] = 1
236
+
162
237
return connections
0 commit comments