Skip to content

Commit 3a94af6

Browse files
committed
allow fine customization of condition functions
1 parent 00ab963 commit 3a94af6

File tree

2 files changed

+19
-8
lines changed

2 files changed

+19
-8
lines changed

classifier_free_guidance_pytorch/classifier_free_guidance_pytorch.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,10 @@ def __init__(
400400
self.conditioners = nn.ModuleList([])
401401

402402
self.hidden_dims = hidden_dims
403-
self.hiddens_channel_first = hiddens_channel_first # whether hiddens to be conditioned is channel first or last
403+
self.num_condition_fns = len(hidden_dims)
404+
self.hiddens_channel_first = cast_tuple(hiddens_channel_first, self.num_condition_fns) # whether hiddens to be conditioned is channel first or last
405+
406+
assert len(self.hiddens_channel_first) == self.num_condition_fns
404407

405408
self.cond_drop_prob = cond_drop_prob
406409

@@ -447,13 +450,11 @@ def forward(
447450
elif exists(text_embeds):
448451
batch = text_embeds.shape[0]
449452

450-
device = self.device
451-
452453
if not exists(text_embeds):
453454
text_embeds = self.embed_texts(texts)
454455

455456
if cond_drop_prob > 0.:
456-
prob_keep_mask = prob_mask_like((batch, 1), 1. - cond_drop_prob, device = device)
457+
prob_keep_mask = prob_mask_like((batch, 1), 1. - cond_drop_prob, device = self.device)
457458
null_text_embeds = rearrange(self.null_text_embed, 'd -> 1 d')
458459

459460
text_embeds = torch.where(
@@ -462,8 +463,18 @@ def forward(
462463
null_text_embeds
463464
)
464465

465-
text_embeds = repeat(text_embeds, 'b ... -> (b r) ...', r = repeat_batch)
466+
# prepare the conditioning functions
467+
468+
repeat_batch = cast_tuple(repeat_batch, self.num_condition_fns)
469+
470+
cond_fns = []
471+
472+
for cond, cond_hiddens_channel_first, cond_repeat_batch in zip(self.conditioners, self.hiddens_channel_first, repeat_batch):
473+
cond_text_embeds = repeat(text_embeds, 'b ... -> (b r) ...', r = cond_repeat_batch)
474+
cond_fn = partial(cond, cond_text_embeds)
475+
476+
wrapper_fn = rearrange_channel_first if cond_hiddens_channel_first else rearrange_channel_last
466477

467-
wrapper_fn = rearrange_channel_first if self.hiddens_channel_first else rearrange_channel_last
478+
cond_fns.append(wrapper_fn(cond_fn))
468479

469-
return tuple(wrapper_fn(partial(cond, text_embeds)) for cond in self.conditioners)
480+
return tuple(cond_fns)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
name = 'classifier-free-guidance-pytorch',
55
packages = find_packages(exclude=[]),
66
include_package_data = True,
7-
version = '0.0.22',
7+
version = '0.0.23',
88
license='MIT',
99
description = 'Classifier Free Guidance - Pytorch',
1010
author = 'Phil Wang',

0 commit comments

Comments
 (0)