suhara committed on
Commit
4a31ff1
·
verified ·
1 Parent(s): 0d1698d

Upload modeling_nemotron_h.py

Browse files
Files changed (1) hide show
  1. modeling_nemotron_h.py +4 -3
modeling_nemotron_h.py CHANGED
@@ -623,8 +623,8 @@ class NemotronHMamba2Mixer(nn.Module):
623
  hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
624
  B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
625
  C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
626
- B = B.repeat(1, 1, self.num_heads // self.n_groups, 1)
627
- C = C.repeat(1, 1, self.num_heads // self.n_groups, 1)
628
  pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size
629
 
630
  D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)
@@ -852,7 +852,8 @@ class NemotronHMOE(nn.Module):
852
  final_hidden_states.index_add_(0, token_indices, weighted_output)
853
  else:
854
  # Local empty expert: no-op compute that still marks params as used.
855
- dummy_out = expert(torch.zeros_like(hidden_states[0]).unsqueeze(0).to(final_hidden_states.dtype))
 
856
  final_hidden_states = final_hidden_states + dummy_out
857
 
858
  # in original deepseek, the output of the experts are gathered once we leave this module
 
623
  hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
624
  B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
625
  C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
626
+ B = B.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
627
+ C = C.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
628
  pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size
629
 
630
  D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)
 
852
  final_hidden_states.index_add_(0, token_indices, weighted_output)
853
  else:
854
  # Local empty expert: no-op compute that still marks params as used.
855
+ expert_dtype = expert.down_proj.weight.dtype
856
+ dummy_out = expert(torch.zeros_like(hidden_states[0]).unsqueeze(0).to(expert_dtype))
857
  final_hidden_states = final_hidden_states + dummy_out
858
 
859
  # in original deepseek, the output of the experts are gathered once we leave this module