diff --git a/tf_keras/api/golden/v2/tensorflow.keras.models.experimental.-sharpness-aware-minimization.pbtxt b/tf_keras/api/golden/v2/tensorflow.keras.models.experimental.-sharpness-aware-minimization.pbtxt
index 9c9473c9f..3a7c3ef6d 100644
--- a/tf_keras/api/golden/v2/tensorflow.keras.models.experimental.-sharpness-aware-minimization.pbtxt
+++ b/tf_keras/api/golden/v2/tensorflow.keras.models.experimental.-sharpness-aware-minimization.pbtxt
@@ -199,7 +199,7 @@ tf_class {
   }
   member_method {
     name: "call"
-    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
+    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
   }
   member_method {
     name: "compile"
diff --git a/tf_keras/models/sharpness_aware_minimization.py b/tf_keras/models/sharpness_aware_minimization.py
index 70c4c2583..589d62cdf 100644
--- a/tf_keras/models/sharpness_aware_minimization.py
+++ b/tf_keras/models/sharpness_aware_minimization.py
@@ -72,17 +72,27 @@ def train_step(self, data):
         if self.num_batch_splits is not None:
             x_split = tf.split(x, self.num_batch_splits)
             y_split = tf.split(y, self.num_batch_splits)
+            # Split the sample weight if it is provided.
+            if sample_weight is not None:
+                sample_weight_split = tf.split(
+                    sample_weight, self.num_batch_splits
+                )
+            else:
+                sample_weight_split = [None] * self.num_batch_splits
         else:
             x_split = [x]
             y_split = [y]
+            sample_weight_split = [sample_weight]
 
         gradients_all_batches = []
         pred_all_batches = []
-        for x_batch, y_batch in zip(x_split, y_split):
+        for x_batch, y_batch, sample_weight_batch in zip(
+            x_split, y_split, sample_weight_split
+        ):
             epsilon_w_cache = []
             with tf.GradientTape() as tape:
-                pred = self.model(x_batch)
-                loss = self.compiled_loss(y_batch, pred)
+                pred = self(x_batch)
+                loss = self.compiled_loss(y_batch, pred, sample_weight_batch)
             pred_all_batches.append(pred)
             trainable_variables = self.model.trainable_variables
             gradients = tape.gradient(loss, trainable_variables)
@@ -98,8 +108,8 @@ def train_step(self, data):
                 epsilon_w_cache.append(epsilon_w)
 
             with tf.GradientTape() as tape:
-                pred = self(x_batch)
-                loss = self.compiled_loss(y_batch, pred)
+                pred = self(x_batch, training=True)
+                loss = self.compiled_loss(y_batch, pred, sample_weight_batch)
             gradients = tape.gradient(loss, trainable_variables)
             if len(gradients_all_batches) == 0:
                 for gradient in gradients:
@@ -127,7 +137,7 @@ def train_step(self, data):
         self.compiled_metrics.update_state(y, pred, sample_weight)
         return {m.name: m.result() for m in self.metrics}
 
-    def call(self, inputs):
+    def call(self, inputs, **kwargs):
         """Forward pass of SAM.
 
         SAM delegates the forward pass call to the wrapped model.
@@ -138,7 +148,7 @@ def call(self, inputs):
         Returns:
           A Tensor, the outputs of the wrapped model for given `inputs`.
         """
-        return self.model(inputs)
+        return self.model(inputs, **kwargs)
 
     def get_config(self):
         config = super().get_config()
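
Usage note (not part of the patch): a minimal sketch of what the train_step change enables. With this change, `sample_weight` passed to `fit()` is split alongside `x` and `y` when `num_batch_splits` is set, and reaches `compiled_loss` in both SAM gradient passes. The toy model, data, and the `num_batch_splits=2` setting below are illustrative assumptions, not taken from the diff.

import numpy as np
import tensorflow as tf

# The wrapped model and data are placeholders for illustration only.
inner = tf.keras.Sequential(
    [tf.keras.layers.Dense(4, activation="relu"), tf.keras.layers.Dense(1)]
)
# Public API path per the golden file above. Each batch is split into
# two sub-batches, so the per-sample weights must be split the same
# way -- which is what the train_step fix adds.
sam_model = tf.keras.models.experimental.SharpnessAwareMinimization(
    inner, num_batch_splits=2
)
sam_model.compile(optimizer="adam", loss="mse")

x = np.random.random((32, 8)).astype("float32")
y = np.random.random((32, 1)).astype("float32")
w = np.random.random((32,)).astype("float32")  # per-sample weights

# Before the patch, `w` was ignored inside SAM's train_step; with it,
# both gradient passes weight the loss per sample.
sam_model.fit(x, y, sample_weight=w, batch_size=8, epochs=1)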
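
And a small sketch of the `call` signature change: keyword arguments such as `training` now forward to the wrapped model instead of being rejected, which is what the updated golden `argspec` (keywords=kwargs) records. `sam_model` and `x` are the illustrative objects from the sketch above.

# `training` is forwarded via **kwargs to the wrapped model's call.
pred_train = sam_model(x[:4], training=True)
pred_infer = sam_model(x[:4], training=False)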