Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ tf_class {
}
member_method {
name: "call"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compile"
Expand Down
24 changes: 17 additions & 7 deletions tf_keras/models/sharpness_aware_minimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,17 +72,27 @@ def train_step(self, data):
if self.num_batch_splits is not None:
x_split = tf.split(x, self.num_batch_splits)
y_split = tf.split(y, self.num_batch_splits)
# Split the sample weight if it is provided.
if sample_weight is not None:
sample_weight_split = tf.split(
sample_weight, self.num_batch_splits
)
else:
sample_weight_split = [None] * self.num_batch_splits
else:
x_split = [x]
y_split = [y]
sample_weight_split = [sample_weight]

gradients_all_batches = []
pred_all_batches = []
for x_batch, y_batch in zip(x_split, y_split):
for x_batch, y_batch, sample_weight_batch in zip(
x_split, y_split, sample_weight_split
):
epsilon_w_cache = []
with tf.GradientTape() as tape:
pred = self.model(x_batch)
loss = self.compiled_loss(y_batch, pred)
pred = self(x_batch)
loss = self.compiled_loss(y_batch, pred, sample_weight_batch)
pred_all_batches.append(pred)
trainable_variables = self.model.trainable_variables
gradients = tape.gradient(loss, trainable_variables)
Expand All @@ -98,8 +108,8 @@ def train_step(self, data):
epsilon_w_cache.append(epsilon_w)

with tf.GradientTape() as tape:
pred = self(x_batch)
loss = self.compiled_loss(y_batch, pred)
pred = self(x_batch, training=True)
loss = self.compiled_loss(y_batch, pred, sample_weight_batch)
gradients = tape.gradient(loss, trainable_variables)
if len(gradients_all_batches) == 0:
for gradient in gradients:
Expand Down Expand Up @@ -127,7 +137,7 @@ def train_step(self, data):
self.compiled_metrics.update_state(y, pred, sample_weight)
return {m.name: m.result() for m in self.metrics}

def call(self, inputs):
def call(self, inputs, **kwargs):
"""Forward pass of SAM.

SAM delegates the forward pass call to the wrapped model.
Expand All @@ -138,7 +148,7 @@ def call(self, inputs):
Returns:
A Tensor, the outputs of the wrapped model for given `inputs`.
"""
return self.model(inputs)
return self.model(inputs, **kwargs)

def get_config(self):
config = super().get_config()
Expand Down
Loading