Skip to content

Commit 1ad06c0

Browse files
committed
Normalizing a PC
1 parent de87697 commit 1ad06c0

File tree

2 files changed

+33
-6
lines changed

2 files changed

+33
-6
lines changed

src/pyjuice/model/tensorcircuit.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,8 @@ def __init__(self, root_ns: CircuitNodes, layer_sparsity_tol: float = 0.5,
101101
force_gpu_compilation: bool = False,
102102
max_tied_ns_per_parflow_block: int = 8,
103103
device: Optional[Union[int,torch.device]] = None,
104-
verbose: bool = True) -> None:
104+
verbose: bool = True,
105+
normalize: bool = True) -> None:
105106

106107
super(TensorCircuit, self).__init__()
107108

@@ -123,7 +124,8 @@ def __init__(self, root_ns: CircuitNodes, layer_sparsity_tol: float = 0.5,
123124
force_gpu_compilation = force_gpu_compilation,
124125
max_tied_ns_per_parflow_block = max_tied_ns_per_parflow_block,
125126
device = device,
126-
verbose = verbose
127+
verbose = verbose,
128+
normalize = normalize
127129
)
128130

129131
# Hyperparameters for backward pass
@@ -792,7 +794,7 @@ def _get_num_vars(self, ns: CircuitNodes):
792794

793795
def _init_layers(self, layer_sparsity_tol: Optional[float] = None, max_num_partitions: Optional[int] = None,
794796
disable_gpu_compilation: bool = False, force_gpu_compilation: bool = False,
795-
max_tied_ns_per_parflow_block: int = 8, verbose: bool = True, device: Optional[Union[str,torch.device]] = None):
797+
max_tied_ns_per_parflow_block: int = 8, verbose: bool = True, device: Optional[Union[str,torch.device]] = None, normalize: bool = True):
796798

797799
if hasattr(self, "input_layer_group") or hasattr(self, "inner_layer_groups"):
798800
raise ValueError("Attempting to initialize a TensorCircuit for the second time. " + \
@@ -952,9 +954,9 @@ def _init_layers(self, layer_sparsity_tol: Optional[float] = None, max_num_parti
952954
self._root_node_range = (self.num_nodes - self.num_root_nodes, self.num_nodes)
953955

954956
# Initialize parameters
955-
self._init_parameters()
957+
self._init_parameters(normalize = normalize)
956958

957-
def _init_parameters(self, perturbation: float = 4.0, pseudocount: float = 0.0):
959+
def _init_parameters(self, perturbation: float = 4.0, pseudocount: float = 0.0, normalize: bool = True):
958960
for ns in self.root_ns:
959961
if not ns.is_tied() and (ns.is_sum() or ns.is_input()) and not ns.has_params():
960962
ns.init_parameters(perturbation = perturbation, recursive = False)
@@ -967,7 +969,8 @@ def _init_parameters(self, perturbation: float = 4.0, pseudocount: float = 0.0):
967969
if ns.is_sum() and not ns.is_tied() and ns.has_params():
968970
ns.gather_parameters(params)
969971

970-
self._normalize_parameters(params, pseudocount = pseudocount)
972+
if normalize:
973+
self._normalize_parameters(params, pseudocount = pseudocount)
971974
self.params = nn.Parameter(params)
972975

973976
# Due to the custom inplace backward pass implementation, we do not track
@@ -982,6 +985,24 @@ def _normalize_parameters(self, params, pseudocount: float = 0.0):
982985
if params is not None:
983986
normalize_parameters(params, self.par_update_kwargs, pseudocount)
984987

988+
def normalize(self, perturbation: float = 0.0, pseudocount: float = 0.0):
    """Re-normalize the circuit's sum-node parameters in place.

    A fresh parameter vector is drawn first (all ones when ``perturbation``
    is 0.0, since ``exp(rand * -0.0) == 1``; randomly perturbed otherwise),
    then any parameters already held by untied sum nodes are gathered into it
    before the whole vector is normalized and re-registered as ``self.params``.

    Args:
        perturbation: Scale of the random multiplicative perturbation applied
            to the freshly drawn parameter vector.
        pseudocount: Smoothing value forwarded to ``_normalize_parameters``.
    """
    new_params = torch.exp(torch.rand([self.num_sum_params]) * -perturbation)
    new_params = new_params.to(self.device)
    # Dummy parameter slots carry no probability mass.
    new_params[:self.num_dummy_params] = 0.0

    # Parameters already stored on untied sum nodes take precedence over the
    # freshly drawn values (gathered in via the nodes' own helper).
    for node in self.root_ns:
        if node.is_sum() and not node.is_tied() and node.has_params():
            node.gather_parameters(new_params)

    self._normalize_parameters(new_params, pseudocount = pseudocount)
    self.params = nn.Parameter(new_params)

    # The custom in-place backward pass manages parameter gradients itself,
    # so PyTorch autograd must not track them.
    self.params.requires_grad = False
1004+
1005+
9851006
def _create_node_layers(self):
9861007
depth2nodes = dict()
9871008
nodes2depth = dict()

src/pyjuice/nodes/distributions/literal.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,5 +74,11 @@ def em_fn(local_offsets, params_ptr, param_flows_ptr, s_pids, s_pfids, metadata_
7474
step_size, pseudocount, BLOCK_SIZE):
7575
pass
7676

77+
@staticmethod
def partition_fn(local_offsets, params_ptr, s_pids, metadata_ptr, s_mids_ptr, mask, BLOCK_SIZE, TILE_SIZE_K):
    """Return an all-zero float32 tile of length ``BLOCK_SIZE``.

    The pointer/metadata arguments are unused for Literal distributions; they
    are kept only so this function matches the shared kernel signature.
    """
    return tl.zeros([BLOCK_SIZE], dtype = tl.float32)
81+
82+
7783
def _get_constructor(self):
7884
return Literal, {"lit": self.lit}

0 commit comments

Comments
 (0)