import copy
import sys
from typing import Optional, List, Union, Tuple

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import LlamaForCausalLM, LlamaPreTrainedModel, LlamaConfig, AutoModel
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast
from transformers.models.idefics.modeling_idefics import LLAMA_INPUTS_DOCSTRING, _CONFIG_FOR_DOC
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm, LlamaModel
from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings, logging

logger = logging.get_logger(__name__)

# Token ids of the two prompt suffixes that are expected at the END of every
# input sequence, in this order: <passage> + summarize suffix + predict suffix.
#
#   summarize suffix: `", summarize the above passage within eight words: ` + 8 trailing ids
#   predict suffix:   `", predict the following passage within eight words: ` + 8 trailing ids
#
# The trailing 8 ids of each suffix (>= 32000) are the positions whose logits
# feed the bag-of-words losses in `PreLlamaModel.forward`.
# NOTE(review): these look like tokenizer-specific added special tokens - confirm
# against the tokenizer used for pretraining.
SUMMARIZE_SUFFIX_IDS = [
    9162, 19138, 675, 278, 2038, 13382, 2629, 9475, 3838, 29901, 29871,
    32008, 32011, 32004, 32013, 32007, 32005, 32002, 32014,
]
PREDICT_SUFFIX_IDS = [
    9162, 8500, 278, 1494, 13382, 2629, 9475, 3838, 29901, 29871,
    32000, 32009, 32012, 32001, 32010, 32003, 32006, 32015,
]
# Number of trailing positions of each suffix whose logits are pooled for the
# bag-of-words losses.
NUM_SPECIAL_TOKENS = 8


class NewLlamaModel(LlamaModel):
    """Llama decoder stack with suffix-aware position ids and attention mask.

    Compared with the stock ``LlamaModel.forward`` this variant:

    * gives the trailing predict suffix the same position ids as the leading
      part of the summarize suffix that precedes it, and
    * prevents predict-suffix queries from attending to summarize-suffix keys.

    Both edits index from the end of the sequence, so the input must end with
    the summarize suffix followed by the predict suffix.
    """

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder stack with the suffix-specific position/mask edits.

        Returns:
            A ``BaseModelOutputWithPast`` when ``return_dict`` resolves to True,
            otherwise a tuple of the non-``None`` entries among
            ``(last_hidden_state, past_key_values, hidden_states, attentions)``.

        Raises:
            ValueError: if the model is configured with ``flash_attention_2``
                (an explicit 4D mask is required here), or if not exactly one
                of ``input_ids`` / ``inputs_embeds`` is given.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.config._attn_implementation == "flash_attention_2":
            raise ValueError(
                "You can not use flash attention to pretrain"
            )

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
            return_legacy_cache = True
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        num_summarize = len(SUMMARIZE_SUFFIX_IDS)
        num_predict = len(PREDICT_SUFFIX_IDS)

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)
        # `unsqueeze` returns a view, and a caller-supplied tensor is shared with
        # the caller: clone before the in-place edit so that neither
        # `cache_position` nor the caller's tensor is silently modified.
        position_ids = position_ids.clone()
        # The predict suffix re-uses the positions of the first `num_predict`
        # tokens of the summarize suffix. Source and destination slices do not
        # overlap, so a direct slice assignment is safe.
        # NOTE(review): applied to caller-supplied `position_ids` as well.
        position_ids[:, -num_predict:] = position_ids[:, -num_summarize - num_predict:-num_summarize]

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )
        # Block predict-suffix queries (last `num_predict` rows) from attending to
        # summarize-suffix keys.
        # NOTE(review): the negative column indices assume the mask's last
        # dimension equals the key length, which holds when a full-length 2D
        # `attention_mask` is supplied; with `attention_mask=None` the mask
        # built below has one extra column and is an expanded view - confirm
        # that this path is never used.
        causal_mask[
            :, :, -num_predict:, -num_predict - num_summarize:-num_predict
        ] = torch.finfo(inputs_embeds.dtype).min

        # embed positions
        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        """Build the additive 4D causal mask (0 = attend, dtype-min = masked).

        Unlike a path that may return ``None`` to let SDPA infer causality, the
        non-flash-attention path here always materialises a mask tensor,
        because ``forward`` writes into the returned mask unconditionally.

        Raises:
            ValueError: if a 4D ``attention_mask`` is passed whose maximum is
                not 0 (i.e. it is not in inverted/additive form).
        """
        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_static_cache = isinstance(past_key_values, StaticCache)

        # The SDPA "skip the mask and rely on `is_causal`" shortcut
        # (`AttentionMaskConverter._ignore_causal_mask_sdpa`) is intentionally
        # not used: `forward` edits the returned mask in place, so a `None`
        # result is not acceptable.

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_length()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        if attention_mask is not None and attention_mask.dim() == 4:
            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
            if attention_mask.max() != 0:
                raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
            causal_mask = attention_mask
        else:
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                # Positions that are causally visible (0) AND padded (0 in the 2D mask) sum to 0.
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )
        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask


class PreLlamaModel(LlamaForCausalLM):
    """Causal LM head on top of ``NewLlamaModel`` with extra bag-of-words losses.

    In addition to the usual next-token loss, two optional bag-of-words losses
    are computed from the logits at the trailing positions of the summarize
    and predict suffixes.
    """

    def __init__(self, config):
        super().__init__(config)
        self.model = NewLlamaModel(config)
        # Applied to `(batch, vocab)` logits, so dim=1 is the vocabulary axis.
        self.log_softmax = torch.nn.LogSoftmax(dim=1)

        # prompt type1: "{}", summarize the above passage within eight words:
        #   token ids: [376, ..., 9162, 19138, 675, 278, 2038, 13382, 2629, 9475, 3838, 29901, 29871,
        #               32008, 32011, 32004, 32013, 32007, 32005, 32002, 32014]
        # prompt type2: "{}", predict the following passage within eight words:
        #   token ids: [376, ..., 9162, 8500, 278, 1494, 13382, 2629, 9475, 3838, 29901, 29871,
        #               32000, 32009, 32012, 32001, 32010, 32003, 32006, 32015]
        # Maybe only one of them will appear, or both may appear. We consider all possibilities here.
        self.summarize_prompt_ids = list(SUMMARIZE_SUFFIX_IDS)
        self.predict_prompt_ids = list(PREDICT_SUFFIX_IDS)

    def _bag_of_words_loss(self, window_logits, target_ids):
        """Return the bag-of-words loss for one suffix window.

        The window logits are max-pooled over the sequence axis, turned into
        log-probabilities over the vocabulary, and for every sample the
        negative mean log-probability of its unique target ids (ids > 2 only)
        is accumulated. The sum is averaged over the samples that had at least
        one usable target id and then scaled by 1/10.

        Returns the int/float ``0`` (not a tensor) when no sample contributes.
        """
        pooled_logits, _ = torch.max(window_logits, dim=1)
        log_probs = self.log_softmax(pooled_logits)

        total = 0
        contributing = 0
        for sample_log_probs, sample_ids in zip(log_probs, target_ids):
            # ids <= 2 are dropped - NOTE(review): presumably pad/bos/eos; confirm with the tokenizer.
            wanted_ids = torch.unique(sample_ids[sample_ids > 2])
            if wanted_ids.numel() > 0:
                total = total - torch.mean(sample_log_probs[wanted_ids])
                contributing += 1
        if contributing > 0:
            total = total / contributing
        return total / 10

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_summarize_ids: Optional[torch.LongTensor] = None,
        output_predict_ids: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
            output_summarize_ids (`torch.LongTensor`, *optional*):
                Per-sample target token ids for the summarize bag-of-words loss; ids `<= 2` are ignored.
            output_predict_ids (`torch.LongTensor`, *optional*):
                Per-sample target token ids for the predict bag-of-words loss; ids `<= 2` are ignored.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        ar_loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            ar_loss = loss_fct(shift_logits, shift_labels)

        # The sequence ends with: summarize suffix (19 ids) + predict suffix (18 ids):
        #   [9162, 19138, 675, 278, 2038, 13382, 2629, 9475, 3838, 29901, 29871,
        #    32008, 32011, 32004, 32013, 32007, 32005, 32002, 32014,
        #    9162, 8500, 278, 1494, 13382, 2629, 9475, 3838, 29901, 29871,
        #    32000, 32009, 32012, 32001, 32010, 32003, 32006, 32015]
        # so the last 8 summarize positions are [-26:-18] and the last 8
        # predict positions are [-8:].
        num_predict = len(self.predict_prompt_ids)

        bow_summarize_loss = 0
        if output_summarize_ids is not None:
            summarize_window = logits[:, -num_predict - NUM_SPECIAL_TOKENS:-num_predict, :]
            bow_summarize_loss = self._bag_of_words_loss(summarize_window, output_summarize_ids)

        bow_predict_loss = 0
        if output_predict_ids is not None:
            predict_window = logits[:, -NUM_SPECIAL_TOKENS:, :]
            bow_predict_loss = self._bag_of_words_loss(predict_window, output_predict_ids)

        # A bag-of-words term only counts when it is strictly positive; 0 means
        # "not requested" or "no sample had usable target ids".
        if bow_summarize_loss > 0 and bow_predict_loss > 0:
            bow_loss = (bow_summarize_loss + bow_predict_loss) / 2
        elif bow_summarize_loss > 0:
            bow_loss = bow_summarize_loss
        elif bow_predict_loss > 0:
            bow_loss = bow_predict_loss
        else:
            bow_loss = None

        if ar_loss is not None and bow_loss is not None:
            loss = ar_loss + bow_loss
        elif ar_loss is None:
            loss = bow_loss
        else:
            loss = ar_loss

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


class PreModel(nn.Module):
    """Thin ``nn.Module`` wrapper that forwards everything to the wrapped model."""

    def __init__(self,
                 model: AutoModel = None,
                 ):
        super().__init__()
        self.model = model

    def gradient_checkpointing_enable(self, **kwargs):
        """Delegate to the wrapped model's ``gradient_checkpointing_enable``."""
        self.model.gradient_checkpointing_enable(**kwargs)

    def enable_input_require_grads(self, **kwargs):
        """Delegate to the wrapped model's ``enable_input_require_grads``."""
        self.model.enable_input_require_grads(**kwargs)

    def forward(self, *args, **kwargs):
        """Call the wrapped model with the given arguments and return its output."""
        return self.model(*args, **kwargs)

    def save(self, output_dir: str):
        """Save the wrapped model to ``output_dir`` via ``save_pretrained``.

        Every tensor of the state dict is cloned and moved to CPU first; the
        copy is stored in a mapping of the same type as the original state dict.
        """
        state_dict = self.model.state_dict()
        state_dict = type(state_dict)(
            {k: v.clone().cpu() for k, v in state_dict.items()})
        self.model.save_pretrained(output_dir, state_dict=state_dict)