Upload modeling_nemotron_h.py
Browse files
- modeling_nemotron_h.py +4 -3
modeling_nemotron_h.py
CHANGED
|
@@ -623,8 +623,8 @@ class NemotronHMamba2Mixer(nn.Module):
|
|
| 623 |
hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
|
| 624 |
B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
|
| 625 |
C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
|
| 626 |
-
B = B.
|
| 627 |
-
C = C.
|
| 628 |
pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size
|
| 629 |
|
| 630 |
D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)
|
|
@@ -852,7 +852,8 @@ class NemotronHMOE(nn.Module):
|
|
| 852 |
final_hidden_states.index_add_(0, token_indices, weighted_output)
|
| 853 |
else:
|
| 854 |
# Local empty expert: no-op compute that still marks params as used.
|
| 855 |
-
|
|
|
|
| 856 |
final_hidden_states = final_hidden_states + dummy_out
|
| 857 |
|
| 858 |
# in original deepseek, the output of the experts are gathered once we leave this module
|
|
|
|
| 623 |
hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
|
| 624 |
B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
|
| 625 |
C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
|
| 626 |
+
B = B.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
|
| 627 |
+
C = C.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
|
| 628 |
pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size
|
| 629 |
|
| 630 |
D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)
|
|
|
|
| 852 |
final_hidden_states.index_add_(0, token_indices, weighted_output)
|
| 853 |
else:
|
| 854 |
# Local empty expert: no-op compute that still marks params as used.
|
| 855 |
+
expert_dtype = expert.down_proj.weight.dtype
|
| 856 |
+
dummy_out = expert(torch.zeros_like(hidden_states[0]).unsqueeze(0).to(expert_dtype))
|
| 857 |
final_hidden_states = final_hidden_states + dummy_out
|
| 858 |
|
| 859 |
# in original deepseek, the output of the experts are gathered once we leave this module
|