Add cnn model #813
Conversation
num_conv_layers=5,
conv_filters=[5, 10, 20, 40, 60],
kernel_size=3,
image_size=(8, 9, 22),  # dimensions of the example image
Would be neat to add a property to the ImageDefinition that contains the resulting image dimension. E.g. ImageDefinition.shape
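A minimal sketch of what such a property could look like, assuming the ImageDefinition receives its image dimensions at construction (the `_image_size` attribute is hypothetical):

from typing import Tuple

class ImageDefinition:
    """Sketch; only the suggested property is shown."""

    def __init__(self, image_size: Tuple[int, int, int]) -> None:
        self._image_size = image_size  # e.g. (8, 9, 22)

    @property
    def shape(self) -> Tuple[int, int, int]:
        """Dimensions of the resulting image."""
        return self._image_size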
| """CNN-specific modules, for performing the main learnable operations.""" | ||
|
|
||
| from .cnn import CNN | ||
| from .theos_muonE_upgoing import TheosMuonEUpgoing |
.theos_muonE_upgoing breaks with the snake-case convention. Do we need "theos" in there? It's very jargony; credit can be given in the associated docstring instead of the module name.
| """Initialize the Lightning CNN signal classifier (LCSC). | ||
| Args: | ||
| num_input_features (int): Number of input features. |
Great with the detailed argument descriptions, but they break the existing conventions used in the library. The types should not be repeated within the docstring itself, as our documentation tooling automatically adds them to the compiled documentation based on the type hints in the code.
I.e.
num_input_features (int): Number of input features.
should be
num_input_features: Number of input features.
You can see the docstring for DynEdge here and the resulting documentation here
| """ | ||
| super().__init__(nb_inputs=num_input_features, nb_outputs=out_put_dim) | ||
|
|
||
| # Check input parameters |
There's quite a bit of parsing in the init here. Looks like you're doing two things: checking incompatible arguments (raising errors) and parsing the acceptable arguments for subsequent use in the layer building. You could instead move this logic into one or more private methods that are used in the init function - this will improve the readability greatly. For example:
def __init__(self, param1: type, param2: type):
    """Docstring."""
    # Check and parse input parameters
    filters, kernel_sizes, padding, ... = self._parse_conv_arguments(param1=param1, param2=param2)
    pooling_sizes, pooling_strides, ... = self._parse_pooling_arguments(param1=param1, param2=param2)
    # Set Convolution Layers
    self._set_conv_layers(
        filters=filters,
        kernel_sizes=kernel_sizes,
        ...,
        pooling_sizes=pooling_sizes,
    )
    # Set Linear layers
    self.flatten = torch.nn.Flatten()
    self.fc1 = torch.nn.Linear(latent_dim, num_fc_neurons)
    self.fc2 = torch.nn.Linear(num_fc_neurons, out_put_dim)
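For illustration, one of those private methods could look roughly like this (a sketch; argument names are illustrative, and `List`, `Tuple`, `Union` come from `typing`):

def _parse_conv_arguments(
    self,
    num_conv_layers: int,
    conv_filters: List[int],
    kernel_size: Union[int, List[int]],
) -> Tuple[List[int], List[int]]:
    """Check and parse convolution-related arguments."""
    # Raise on incompatible arguments
    if len(conv_filters) != num_conv_layers:
        raise ValueError(
            f"Expected {num_conv_layers} filter counts, got {len(conv_filters)}."
        )
    # Broadcast a scalar kernel size to one entry per layer
    if isinstance(kernel_size, int):
        kernel_sizes = [kernel_size] * num_conv_layers
    else:
        kernel_sizes = kernel_size
    return conv_filters, kernel_sizes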
def forward(self, data: Data) -> torch.Tensor:
    """Forward pass of the LCSC."""
    assert len(data.x) == 1, "Only Main Array image is supported for LCSC"
This assertion checks that a single image is produced by the image representation as opposed to multiple, not that a specific image representation is used, e.g. "main array".
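For example, the message could state what is actually being checked (a sketch):

assert len(data.x) == 1, (
    "LCSC expects a single image from the image representation, "
    f"but got {len(data.x)} images."
)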
https://github.com/AlexHarn)
Intended to be used with the IceCube 86 image containing
only the Main Array image.
Is it correctly understood that this method will work with any single-image representation, but that the method and default parameters were tested and selected based on IceCube simulation and a particular representation that utilizes the main array only? If so, I think adjusting this sentence would be wise.
Theo Glauchs thesis (chapter 5.3):
https://mediatum.ub.tum.de/node?id=1584755
NOTE: number of pulses per cluster is not mentioned/used in the thesis
How is this supposed to be understood? Do you mean that introducing this within the method is your own creation?
@@ -0,0 +1,411 @@
"""CNN used for muon energy reconstruction in IceCube.
src/graphnet/models/cnn/theos_muonE_upgoing.py breaks with the snake-case convention. Do we strictly need "theos" in the module name? Proper credits can be given in the module docstring.
class Conv3dBN(LightningModule):
    """The Conv3dBN module from Theos CNN model."""
Theos -> Theo Glauch
Consider adding a bit more detail to inform the reader of what this module is. E.g.
"""Implementation of the Conv3dBN image convolution module from Theo Glauch."""
class InceptionBlock4(LightningModule):
    """The inception_block4 module from Theos CNN model."""
Comments above apply here too.
class InceptionResnet(LightningModule):
    """The inception_resnet module from Theos CNN model."""
Comments from above apply here, too.
return x + self._scale * tmp


class TheosMuonEUpgoing(CNN):
I don't think this is the official name of the method, and to my knowledge, nothing within the method restricts it to upgoing events only. I would strongly suggest finding a more accessible name for the method. I believe it's more commonly known as the "DNN" within IceCube, no? You can use the docstring to provide further details on its origin. I.e. proper credits to Theo and his use of the method.
class TheosMuonEUpgoing(CNN):
    """The TheosMuonEUpgoing module."""

    def __init__(self, nb_inputs: int = 15, nb_outputs: int = 16) -> None:
Is there a good reason why the hyperparameters of the method are hardcoded?
If not, let's please make them arguments, as that will greatly increase the reusability of the method.
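A sketch of what that could look like (the extra hyperparameters shown here are hypothetical placeholders, not the actual ones from the PR; `Optional` and `List` come from `typing`):

def __init__(
    self,
    nb_inputs: int = 15,
    nb_outputs: int = 16,
    conv_filters: Optional[List[int]] = None,  # hypothetical hyperparameter
    dropout_rate: float = 0.1,  # hypothetical hyperparameter
) -> None:
    """Construct the CNN.

    Args:
        nb_inputs: Number of input features per pixel.
        nb_outputs: Output dimension of the model.
        conv_filters: Number of filters per convolutional block.
        dropout_rate: Dropout probability between blocks.
    """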
Args:
    dtype: data type used for node features. e.g. ´torch.float´
    string_label: Name of the feature corresponding
        to the DOM string number. Values Integers betweem 1 - 86
betweem -> between
self._include_upper_dc = include_upper_dc

# read mapping from parquet file
df = pd.read_parquet(IC86_CNN_MAPPING)
Would there be a way for us to compile the mapping at instantiation without relying on a file?
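For instance, the key-value store could be compiled from the detector's geometry table at instantiation (a sketch; `_grid_position_for_row` is a hypothetical helper that turns one geometry row into image indices):

import pandas as pd
from graphnet.models.detector.icecube import IceCube86

detector = IceCube86()
geometry = detector.geometry_table.reset_index(drop=False)
# _grid_position_for_row (hypothetical) maps one geometry row to a dict with
# keys "string", "dom_number", "mat_ax0", "mat_ax1", "mat_ax2"
records = [_grid_position_for_row(row) for _, row in geometry.iterrows()]
df = pd.DataFrame.from_records(records).set_index(["string", "dom_number"])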
self._mapping = df
super().__init__(pixel_feature_names=pixel_feature_names)

def _set_indeces(
_set_indices
self._sensor_number_label = sensor_number_label
self._pixel_feature_names = pixel_feature_names

self._set_indeces(
_set_indices
self._dom_number_label = dom_number_label
self._pixel_feature_names = pixel_feature_names

self._set_indeces(pixel_feature_names, dom_number_label, string_label)
_set_indices
self._mapping = df
super().__init__(pixel_feature_names=pixel_feature_names)

def _set_indeces(
_set_indices
    row[3],  # mat_ax1
] = batch_row_features[i]

# unqueeze to add dimension for batching
unqueeze -> unsqueeze
)

# data.x is expected to be a tensor with shape (N, F)
# where N is the number of nodes and F is the number of features.
Rows represent pixels, right?
match_indices = self._mapping.loc[
    zip(*string_dom_number.t().tolist())
][
    ["string", "dom_number", "mat_ax0", "mat_ax1", "mat_ax2"]
Looks like your method relies on a very specific set of column names in this file.
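One way to make that dependency explicit is to declare the required columns as a module-level constant and validate the file on load (a sketch):

REQUIRED_MAPPING_COLUMNS = ("string", "dom_number", "mat_ax0", "mat_ax1", "mat_ax2")

df = pd.read_parquet(IC86_CNN_MAPPING)
missing = set(REQUIRED_MAPPING_COLUMNS) - set(df.columns)
if missing:
    raise ValueError(f"Mapping file is missing required columns: {missing}")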
Hey @sevmag - thank you for implementing CNNs!! 🚀 I like the approach you've taken, and I think the PR is generally in pretty good shape. In addition to the specific comments above, I've been thinking that we can simplify the user experience and eliminate the need for new files by introducing a slight refactor of the "Pixelmapping," which changes the role it plays in the image representation. In essence, I propose that "Pixelmapping" (referred to as "GridDefinition" below) defines the number of images, their sizes, and a method for generating the key-value store(s) that is used to insert pixels into the grid(s) using the existing Detector classes. The functionality of generating grids and inserting pixels would be handled by the image representation. More details below. Could you take a look and let me know if this fits your use-case?

Preluding observations
These two observations essentially foresee the existence of two central arguments for image representations. I summarize my proposed scope of each below:

PixelDefinition

In pseudo-code, the ImageRepresentation could look something like this:

from typing import Optional, List, Dict, Union, Tuple, Any, Callable
from numpy.random import Generator
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data
from graphnet.models.data_representation import DataRepresentation
from graphnet.models.detector import Detector
from graphnet.models.graphs.nodes import NodeDefinition
class ImageRepresentation(DataRepresentation):
""" A base class for image representations in GraphNeT."""
def __init__(self,
pixel_definition: NodeDefinition,
grid_definition: GridDefinition,
input_feature_names: Optional[List[str]] = None,
dtype: Optional[torch.dtype] = torch.float,
perturbation_dict: Optional[Dict[str, float]] = None,
seed: Optional[Union[int, Generator]] = None,
add_inactive_sensors: bool = False,
sensor_mask: Optional[List[int]] = None,
string_mask: Optional[List[int]] = None,
repeat_labels: bool = False, ) -> None:
# Base class constructor
super().__init__(
detector=grid_definition.detector, # defines detector
input_feature_names=input_feature_names,
dtype=dtype,
perturbation_dict=perturbation_dict,
seed=seed,
add_inactive_sensors=add_inactive_sensors,
sensor_mask=sensor_mask,
string_mask=string_mask,
repeat_labels=repeat_labels,
)
self._pixel_definition = pixel_definition
self._grid_definition = grid_definition
self._pixel_mappings = grid_definition.mappings() # yields key-value store(s)
self._image_shapes = grid_definition.shape # Shape of image(s)
self._map_pixels_by = self._grid_definition.map_pixels_by
def forward( # type: ignore
self,
input_features: np.ndarray,
input_feature_names: List[str],
truth_dicts: Optional[List[Dict[str, Any]]] = None,
custom_label_functions: Optional[Dict[str, Callable[..., Any]]] = None,
loss_weight_column: Optional[str] = None,
loss_weight: Optional[float] = None,
loss_weight_default_value: Optional[float] = None,
data_path: Optional[str] = None,
) -> Data:
"""Construct graph as ´Data´ object.
Args:
input_features: Input features for graph construction.
Shape ´[num_rows, d]´
input_feature_names: name of each column. Shape ´[,d]´.
truth_dicts: Dictionary containing truth labels.
custom_label_functions: Custom label functions.
loss_weight_column: Name of column that holds loss weight.
Defaults to None.
loss_weight: Loss weight associated with event. Defaults to None.
loss_weight_default_value: default value for loss weight.
Used in instances where some events have
no pre-defined loss weight. Defaults to None.
data_path: Path to dataset data files. Defaults to None.
Returns:
graph
"""
# Process low-level pulses using base-class
data = super().forward(
input_features=input_features,
input_feature_names=input_feature_names,
truth_dicts=truth_dicts,
custom_label_functions=custom_label_functions,
loss_weight_column=loss_weight_column,
loss_weight=loss_weight,
loss_weight_default_value=loss_weight_default_value,
data_path=data_path,
)
# Transform pulses to pixels
x = self._pixel_definition(x = data.x)
# Map pixels to positions in image(s)
x = self._map_pixels_to_grid(x = x,
pixel_mappings = self._pixel_mappings,
image_shapes = self._image_shapes)
# Assign to Data
data.x = x
# other stuff..
return data
    def _map_pixels_to_grid(self,
                            x: torch.Tensor,
                            pixel_mappings: List[pd.DataFrame],
                            image_shapes: List[Tuple[int, ...]]) -> List[torch.Tensor]:
        """Insert unordered pixel values in `x`
        into empty image(s) with shape(s) `image_shapes` using the
        key-value store defined by `pixel_mappings`."""
# Check that the number of image shapes is equal to number of mappings
assert len(pixel_mappings) == len(image_shapes)
# Create and fill images with pixels
images = []
        # We assume the ordering is identical here
        for mapping, shape in zip(pixel_mappings, image_shapes):
empty_image = torch.zeros(size = shape)
filled_image = self._apply_map(empty_image = empty_image,
pixels = x,
mapping = mapping,
map_pixels_by = self._map_pixels_by)
# [F,D,H,W] -> [1, F, D, H, W] for 3D
# [F,D,H] -> [1, F, D, H] for 2D
filled_image = filled_image.unsqueeze(0)
images.append(filled_image)
return images
def _apply_map(self,
empty_image: torch.Tensor,
pixels: torch.Tensor,
mapping: pd.DataFrame,
map_pixels_by: List[int]) -> torch.Tensor:
"""
Insert values from `pixels` into `empty_image` at positions
identified by indexing `mapping` with columns `map_pixels_by` in `pixels`
`empty_image` can either be [F,D,H,W]-dimensional (3D) or [F,D,H] (2D)
where F denotes the number of pixel features.
"""
@property
def shape(self) -> List[Tuple[int]]:
return self._image_shapes
def _set_output_feature_names(
self, input_feature_names: List[str]
) -> List[str]:
"""Return ordered list of pixel feature names."""
        return self._pixel_definition.output_feature_names

Note, I didn't write out `_apply_map`. Given this structure, the GridDefinition could look something like this:

from abc import abstractmethod
from typing import Optional, List, Dict, Union, Tuple, Any, Callable
from numpy.random import Generator
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data
from graphnet.models import Model
from graphnet.models.detector import Detector
from graphnet.models.graphs.nodes import NodeDefinition
class GridDefinition(Model):
""" Base class for constructing image partitions in GraphNeT.
The image partitions define orthonormal grids from detector geometry."""
def __init__(self,
detector: Detector,
pixel_feature_names: List[str],
map_pixels_by: List[str]) -> None:
"""detector: Regular graphnet detector class that holds geometry
pixel_features: list of all available pixel features. Assumed to ordered.
map_pixels_by: sbuset of pixel_features to map by."""
super().__init__(name=__name__, class_name=self.__class__.__name__)
# Checks
assert isinstance(map_pixels_by, list)
assert isinstance(pixel_feature_names, list)
assert isinstance(detector, Detector)
self.detector = detector
self._pixel_features = pixel_feature_names
self._map_pixels_by = map_pixels_by
self._geometry_table = detector.geometry_table
@abstractmethod
def _generate_mappings(self,
geometry_table: pd.DataFrame,
map_pixels_by: List[str],
pixel_feature_names: List[str]) -> Union[List[pd.DataFrame], pd.DataFrame]:
"""Generate a single, or a list of, key-value stores that relates
a pixel position defined by `map_pixels_by` to a position in
the orthonormal grid using the detector geometry table.
The resulting key-value store is required to be an indexed
pd.DataFrame, and may use geometric detector features such as
`from graphnet.models.detector.icecube import IceCube86
detector = IceCube86() # or any other
# Natively indexed on xyz positions
geometry_table = detector.geometry_table.reset_index(drop = False)
unique_sensor_id = detector.sensor_id_column
unique_string_id = detector.string_id_column
unique_sensor_position = detector.xyz`
"""
        raise NotImplementedError
@abstractmethod
def _generate_shapes(self,
geometry_table: pd.DataFrame,
pixel_features: List[str],
map_pixels_by: List[str]) -> Union[Tuple[int],
List[Tuple[int]]]:
"""Generate the shape(s) of the image grid(s).
E.g. [(10, 5, 2,10), (256, 50, 10, 2)] """
        raise NotImplementedError
@property
def shape(self) -> Union[Tuple[int],List[Tuple[int]]]:
"""Return the shape(s) of the image(s)."""
if hasattr(self, '_shapes'):
return self._shapes
else:
self._shapes = self._generate_shapes(geometry_table = self._geometry_table,
pixel_features = self._pixel_features,
map_pixels_by= self._map_pixels_by)
return self._shapes
@property
def map_pixels_by(self) -> List[str]:
return self._map_pixels_by
@property
def mappings(self) -> Union[pd.DataFrame, List[pd.DataFrame]]:
"""Return the key-value stores that map a pixel to a point in the grid(s)."""
if hasattr(self, "_mappings"):
return self._mappings
else:
self._mappings = self._generate_mappings(geometry_table = self._geometry_table,
pixel_features = self._pixel_features,
map_pixels_by= self._map_pixels_by)
            return self._mappings

Within this formalism, your existing IC86 representation could look something like this:

from graphnet.models.detector import IceCube86
from typing import List, Tuple, Union, Dict
import pandas as pd
# Fixed 10x10 placement for strings 1..78 (from your generator)
_IC86_STRING_TO_AX01: Dict[int, Tuple[int, int]] = {
1:(9,4), 2:(9,5), 3:(9,6), 4:(9,7), 5:(9,8), 6:(9,9),
7:(8,3), 8:(8,4), 9:(8,5), 10:(8,6), 11:(8,7), 12:(8,8), 13:(8,9),
14:(7,2), 15:(7,3), 16:(7,4), 17:(7,5), 18:(7,6), 19:(7,7), 20:(7,8), 21:(7,9),
22:(6,1), 23:(6,2), 24:(6,3), 25:(6,4), 26:(6,5), 27:(6,6), 28:(6,7), 29:(6,8), 30:(6,9),
31:(5,0), 32:(5,1), 33:(5,2), 34:(5,3), 35:(5,4), 36:(5,5), 37:(5,6), 38:(5,7), 39:(5,8), 40:(5,9),
41:(4,0), 42:(4,1), 43:(4,2), 44:(4,3), 45:(4,4), 46:(4,5), 47:(4,6), 48:(4,7), 49:(4,8), 50:(4,9),
51:(3,0), 52:(3,1), 53:(3,2), 54:(3,3), 55:(3,4), 56:(3,5), 57:(3,6), 58:(3,7), 59:(3,8),
60:(2,0), 61:(2,1), 62:(2,2), 63:(2,3), 64:(2,4), 65:(2,5), 66:(2,6), 67:(2,7),
68:(1,0), 69:(1,1), 70:(1,2), 71:(1,3), 72:(1,4), 73:(1,5), 74:(1,6),
75:(0,0), 76:(0,1), 77:(0,2), 78:(0,3),
}
class IC86Grid(GridDefinition):
def __init__(
self,
pixel_feature_names: List[str],
string_label: str = "string",
dom_number_label: str = "sensor_id", # will be aliased to detector.sensor_id_column
include_main_array: bool = True,
include_lower_dc: bool = True,
include_upper_dc: bool = True,
) -> None:
super().__init__(
detector=IceCube86(),
pixel_feature_names=pixel_feature_names,
map_pixels_by=[string_label, dom_number_label],
)
if not any([include_main_array, include_lower_dc, include_upper_dc]):
raise ValueError("Include at least one array type.")
self._string_label = string_label
self._dom_number_label = dom_number_label
self._include_main_array = include_main_array
self._include_lower_dc = include_lower_dc
self._include_upper_dc = include_upper_dc
# channels = all features except the mapping keys
self._nb_channels = len(pixel_feature_names) - 2
# ---- GridDefinition interface ----
def _generate_mappings(
self,
geometry_table: pd.DataFrame,
map_pixels_by: List[str],
pixel_features: List[str],
) -> Union[List[pd.DataFrame], pd.DataFrame]:
"""
Build one mapping DataFrame per included grid using
detector.geometry_table.
"""
# Your logic goes here
# Ideally use the "sensor_id" which defines unique DOMs
# Or, if you prefer, we can add the non-unique "dom_number"
# to the geometry table
# Use global variable above as you wish
def _generate_shapes(
self,
geometry_table: pd.DataFrame,
pixel_features: List[str],
map_pixels_by: List[str],
) -> Union[Tuple[int], List[Tuple[int]]]:
""" Define the dimension(s) of the image(s) here"""
        # Make sure as little as possible is hardcoded
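For completeness, a hypothetical usage sketch tying the proposed pieces together (the pixel feature names and the pixel definition are placeholders):

grid = IC86Grid(
    pixel_feature_names=["string", "sensor_id", "charge", "time"],  # illustrative
)
representation = ImageRepresentation(
    pixel_definition=my_pixel_definition,  # any NodeDefinition producing pixel features
    grid_definition=grid,
)
print(representation.shape)  # shape(s) of the resulting image(s)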
This is the big PR adding CNN support to GraphNeT, enabling direct comparisons (see #771).
The CNN support consists of:
- An ImageDefinition, which consists of 2 parts: a PixelDefinition (how pulses become pixel features) and a PixelMapping (where each pixel sits in the image grid).
- 2 CNN architectures: the Lightning CNN signal classifier (LCSC) and TheosMuonEUpgoing.
Timing of the ImageDefinition in comparison to other data representations:
At a low number of pulses, the bottleneck of the ImageDefinition is the initialisation of zero tensors.
[Timing plots of the individual modules over 1-500, 1-5000, and 5000-200000 mock pulses (log scale).]
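As a rough illustration of that bottleneck, the allocation alone can be timed in isolation (a sketch; the (15, 8, 9, 22) shape combines the 15 input features with the (8, 9, 22) example image dimensions mentioned above):

import timeit

import torch

# Time 1000 allocations of an empty image tensor
t = timeit.timeit(lambda: torch.zeros((15, 8, 9, 22)), number=1000)
print(f"zero-tensor initialisation: {t / 1000 * 1e6:.1f} µs per call")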