@@ -101,7 +101,8 @@ def __init__(self, root_ns: CircuitNodes, layer_sparsity_tol: float = 0.5,
                 force_gpu_compilation: bool = False,
                 max_tied_ns_per_parflow_block: int = 8,
                 device: Optional[Union[int, torch.device]] = None,
-                verbose: bool = True) -> None:
+                verbose: bool = True,
+                normalize: bool = True) -> None:

        super(TensorCircuit, self).__init__()

@@ -123,7 +124,8 @@ def __init__(self, root_ns: CircuitNodes, layer_sparsity_tol: float = 0.5,
            force_gpu_compilation = force_gpu_compilation,
            max_tied_ns_per_parflow_block = max_tied_ns_per_parflow_block,
            device = device,
-            verbose = verbose
+            verbose = verbose,
+            normalize = normalize
        )

        # Hyperparameters for backward pass
@@ -792,7 +794,7 @@ def _get_num_vars(self, ns: CircuitNodes):

    def _init_layers(self, layer_sparsity_tol: Optional[float] = None, max_num_partitions: Optional[int] = None,
                     disable_gpu_compilation: bool = False, force_gpu_compilation: bool = False,
-                     max_tied_ns_per_parflow_block: int = 8, verbose: bool = True, device: Optional[Union[str, torch.device]] = None):
+                     max_tied_ns_per_parflow_block: int = 8, verbose: bool = True, device: Optional[Union[str, torch.device]] = None, normalize: bool = True):

        if hasattr(self, "input_layer_group") or hasattr(self, "inner_layer_groups"):
            raise ValueError("Attempting to initialize a TensorCircuit for the second time. " + \
@@ -952,9 +954,9 @@ def _init_layers(self, layer_sparsity_tol: Optional[float] = None, max_num_parti
        self._root_node_range = (self.num_nodes - self.num_root_nodes, self.num_nodes)

        # Initialize parameters
-        self._init_parameters()
+        self._init_parameters(normalize = normalize)

-    def _init_parameters(self, perturbation: float = 4.0, pseudocount: float = 0.0):
+    def _init_parameters(self, perturbation: float = 4.0, pseudocount: float = 0.0, normalize: bool = True):
        for ns in self.root_ns:
            if not ns.is_tied() and (ns.is_sum() or ns.is_input()) and not ns.has_params():
                ns.init_parameters(perturbation = perturbation, recursive = False)
@@ -967,7 +969,8 @@ def _init_parameters(self, perturbation: float = 4.0, pseudocount: float = 0.0):
            if ns.is_sum() and not ns.is_tied() and ns.has_params():
                ns.gather_parameters(params)

-        self._normalize_parameters(params, pseudocount = pseudocount)
+        if normalize:
+            self._normalize_parameters(params, pseudocount = pseudocount)
        self.params = nn.Parameter(params)

        # Due to the custom inplace backward pass implementation, we do not track
@@ -982,6 +985,24 @@ def _normalize_parameters(self, params, pseudocount: float = 0.0):
        if params is not None:
            normalize_parameters(params, self.par_update_kwargs, pseudocount)

+    def normalize(self, perturbation: float = 0.0, pseudocount: float = 0.0):
+        params = torch.exp(torch.rand([self.num_sum_params]) * -perturbation)
+        params = params.to(self.device)
+        params[:self.num_dummy_params] = 0.0
+
+        # Copy initial parameters if provided
+        for ns in self.root_ns:
+            if ns.is_sum() and not ns.is_tied() and ns.has_params():
+                ns.gather_parameters(params)
+
+        self._normalize_parameters(params, pseudocount = pseudocount)
+        self.params = nn.Parameter(params)
+
+        # Due to the custom inplace backward pass implementation, we do not track
+        # gradient of PC parameters by PyTorch.
+        self.params.requires_grad = False
+
+
    def _create_node_layers(self):
        depth2nodes = dict()
        nodes2depth = dict()
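
Taken together, these changes let a caller skip parameter normalization when the circuit is compiled (via the new `normalize` flag) and run it later through the new public `normalize()` method, which re-gathers per-node parameters and renormalizes them in place. A minimal usage sketch, assuming `root_ns` is an already-built `CircuitNodes` root; the import path and the pseudocount value below are illustrative, not taken from this diff:

    # Hypothetical import path; adjust to the actual package layout.
    from pyjuice import TensorCircuit

    # Compile without normalizing, e.g. to load raw (unnormalized) parameters first.
    pc = TensorCircuit(root_ns, normalize = False)

    # ... load or assign unnormalized parameters to the source nodes ...

    # Re-gather the source nodes' parameters and normalize them in place,
    # optionally smoothing with a pseudocount.
    pc.normalize(perturbation = 0.0, pseudocount = 0.1)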