Skip to content

Commit 637da60

Browse files
q10 authored and facebook-github-bot committed
Support optimizer state offloading for partial rowwise adam optimizer
Summary: - Support optimizer state offloading for partial rowwise adam optimizer Differential Revision: D76491848
1 parent e9ce63d commit 637da60

File tree

1 file changed

+42
-5
lines changed

1 file changed

+42
-5
lines changed

fbgemm_gpu/codegen/genscript/optimizers.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,24 +1120,61 @@ def partial_rowwise_adam() -> Dict[str, Any]:
11201120
"""
11211121
)
11221122
split_precomputation += """
1123+
1124+
// Define the optimizer state (for use with optimizer offloading)
1125+
struct OptimizerState {
1126+
// momentum2 is a single value so it will be accessed directly as a struct field
1127+
momentum2_ph_t momentum2;
1128+
1129+
// momentum1 is an array of D values, so a method to return a pointer given the offset is defined instead
1130+
DEVICE_INLINE momentum1_ph_t* momentum1_ptr(const int32_t d) const {
1131+
// Re-cast the address to momentum1_ph_t* and increment by d to reach the destination address
1132+
return reinterpret_cast<momentum1_ph_t *>(
1133+
// Cast the address `this` to momentum2_ph_t* and increment by 1 to skip over the momentum2 value
1134+
reinterpret_cast<momentum2_ph_t *>(
1135+
// Remove the const qualifier from `this` so the returned pointer is mutable
1136+
const_cast<OptimizerState *>(this)
1137+
) + 1
1138+
) + d;
1139+
}
1140+
};
1141+
1142+
// Fetch the pointer to the optimizer state along the cache row
1143+
[[maybe_unused]] auto* optimizer = weight_row_template.template optimizer_state_ptr<OptimizerState>();
1144+
1145+
// Fetch the pointer to the momentum1 value
1146+
// Define the fetch here instead of in split_weight_update to avoid conditionals inside a loop
1147+
auto* momentum1_ptr0 = enable_optimizer_offloading ?
1148+
(optimizer->momentum1_ptr(0)) :
1149+
(&momentum1[idx * D]);
1150+
11231151
const at::acc_type<cache_t, true> g_avg_square =
11241152
GROUP_REDUCE_ALL_SUM(g_local_sum_square, at::acc_type<cache_t, true>) / D;
11251153
11261154
at::acc_type<cache_t, true> v_hat_t;
11271155
v_hat_t = 0.0;
11281156
if (threadIdx.x == 0) {
1129-
at::acc_type<cache_t, true> v_t = momentum2[idx] * beta2 + g_avg_square * (1.0 - beta2);
1130-
momentum2[idx] = v_t;
1157+
auto v_t = g_avg_square * (1.0 - beta2);
1158+
1159+
if (enable_optimizer_offloading) {
1160+
v_t += optimizer->momentum2 * beta2;
1161+
optimizer->momentum2 = v_t;
1162+
} else {
1163+
v_t += momentum2[idx] * beta2;
1164+
momentum2[idx] = v_t;
1165+
}
1166+
11311167
v_hat_t = v_t / (1.0 - powf(beta2, iter));
11321168
}
11331169
v_hat_t = SHFL_SYNC(v_hat_t, 0);
11341170
"""
11351171

11361172
split_weight_update = """
1137-
Vec4T<momentum1_ph_t> m_t(&momentum1[idx * D + d]);
1173+
auto* momentum1_ptr = momentum1_ptr0 + d;
1174+
Vec4T<momentum1_ph_t> m_t(momentum1_ptr);
11381175
m_t.mul_(beta1);
11391176
m_t.fma_(grad, 1.0 - beta1);
1140-
m_t.store(&momentum1[idx * D + d]);
1177+
m_t.store(momentum1_ptr);
11411178
11421179
weight_new.acc.x -= learning_rate * (m_t.acc.x / (1.0 - powf(beta1, iter)) / (sqrtf(v_hat_t) + eps) + weight_decay * weight_new.acc.x);
11431180
weight_new.acc.y -= learning_rate * (m_t.acc.y / (1.0 - powf(beta1, iter)) / (sqrtf(v_hat_t) + eps) + weight_decay * weight_new.acc.y);
@@ -1179,7 +1216,7 @@ def partial_rowwise_adam() -> Dict[str, Any]:
11791216
"has_gpu_support": True,
11801217
"has_vbe_support": False,
11811218
"has_global_weight_decay_support": False,
1182-
"has_ssd_support": False,
1219+
"has_ssd_support": True,
11831220
}
11841221

11851222

0 commit comments

Comments
 (0)