@@ -1120,24 +1120,61 @@ def partial_rowwise_adam() -> Dict[str, Any]:
1120
1120
"""
1121
1121
)
1122
1122
split_precomputation += """
1123
+
1124
+ // Define the optimizer state (for use with optimizer offloading)
1125
+ struct OptimizerState {
1126
+ // momentum2 is a single value so it will be accessed directly as a struct field
1127
+ momentum2_ph_t momentum2;
1128
+
1129
+ // momentum1 is an array of D values, so a method to return a pointer given the offset is defined instead
1130
+ DEVICE_INLINE momentum1_ph_t* momentum1_ptr(const int32_t d) const {
1131
+ // Re-cast the address to momentum1_ph_t* and increment by d to reach the destination address
1132
+ return reinterpret_cast<momentum1_ph_t *>(
1133
+ // Cast the address this to momentum2_t* and increment by 1 to skip over the momentum2 value
1134
+ reinterpret_cast<momentum2_ph_t *>(
1135
+ // Remove the const qualifier from this if needed
1136
+ const_cast<OptimizerState *>(this)
1137
+ ) + 1
1138
+ ) + d;
1139
+ }
1140
+ };
1141
+
1142
+ // Fetch the pointer to the optimizer state along the cache row
1143
+ [[maybe_unused]] auto* optimizer = weight_row_template.template optimizer_state_ptr<OptimizerState>();
1144
+
1145
+ // Fetch the pointer to the momentum1 value
1146
+ // Define the fetch here instead of in split_weight_update to avoid conditionals inside a loop
1147
+ auto* momentum1_ptr0 = enable_optimizer_offloading ?
1148
+ (optimizer->momentum1_ptr(0)) :
1149
+ (&momentum1[idx * D]);
1150
+
1123
1151
const at::acc_type<cache_t, true> g_avg_square =
1124
1152
GROUP_REDUCE_ALL_SUM(g_local_sum_square, at::acc_type<cache_t, true>) / D;
1125
1153
1126
1154
at::acc_type<cache_t, true> v_hat_t;
1127
1155
v_hat_t = 0.0;
1128
1156
if (threadIdx.x == 0) {
1129
- at::acc_type<cache_t, true> v_t = momentum2[idx] * beta2 + g_avg_square * (1.0 - beta2);
1130
- momentum2[idx] = v_t;
1157
+ auto v_t = g_avg_square * (1.0 - beta2);
1158
+
1159
+ if (enable_optimizer_offloading) {
1160
+ v_t += optimizer->momentum2 * beta2;
1161
+ optimizer->momentum2 = v_t;
1162
+ } else {
1163
+ v_t += momentum2[idx] * beta2;
1164
+ momentum2[idx] = v_t;
1165
+ }
1166
+
1131
1167
v_hat_t = v_t / (1.0 - powf(beta2, iter));
1132
1168
}
1133
1169
v_hat_t = SHFL_SYNC(v_hat_t, 0);
1134
1170
"""
1135
1171
1136
1172
split_weight_update = """
1137
- Vec4T<momentum1_ph_t> m_t(&momentum1[idx * D + d]);
1173
+ auto* momentum1_ptr = momentum1_ptr0 + d;
1174
+ Vec4T<momentum1_ph_t> m_t(momentum1_ptr);
1138
1175
m_t.mul_(beta1);
1139
1176
m_t.fma_(grad, 1.0 - beta1);
1140
- m_t.store(&momentum1[idx * D + d] );
1177
+ m_t.store(momentum1_ptr );
1141
1178
1142
1179
weight_new.acc.x -= learning_rate * (m_t.acc.x / (1.0 - powf(beta1, iter)) / (sqrtf(v_hat_t) + eps) + weight_decay * weight_new.acc.x);
1143
1180
weight_new.acc.y -= learning_rate * (m_t.acc.y / (1.0 - powf(beta1, iter)) / (sqrtf(v_hat_t) + eps) + weight_decay * weight_new.acc.y);
@@ -1179,7 +1216,7 @@ def partial_rowwise_adam() -> Dict[str, Any]:
1179
1216
"has_gpu_support" : True ,
1180
1217
"has_vbe_support" : False ,
1181
1218
"has_global_weight_decay_support" : False ,
1182
- "has_ssd_support" : False ,
1219
+ "has_ssd_support" : True ,
1183
1220
}
1184
1221
1185
1222
0 commit comments