233 lines
7.7 KiB
Python
233 lines
7.7 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# top-level folder for each specific model found within the models/ directory at
|
|
# the top-level of this source tree.
|
|
|
|
import math
|
|
from typing import Optional, Tuple, Union
|
|
|
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
|
import torch
|
|
|
|
|
|
def apply_scaling(
    freqs: torch.Tensor,
    scale_factor: float = 8,
    low_freq_factor: float = 1,
    high_freq_factor: float = 4,
    old_context_len: int = 8192,
) -> torch.Tensor:
    """Rescale RoPE inverse frequencies for long-context (Llama 3 style) use.

    Each element is classified by its wavelength ``2 * pi / freq``:

    * shorter than ``old_context_len / high_freq_factor``: kept unchanged;
    * longer than ``old_context_len / low_freq_factor``: divided by
      ``scale_factor``;
    * otherwise: linearly blended between the two, with weight
      ``smooth = (old_context_len / wavelen - low_freq_factor)
      / (high_freq_factor - low_freq_factor)``.

    The defaults are the values the original code hard-coded. They are
    described there as obtained from a grid search, and ``old_context_len``
    as the original Llama 3 context length.

    Args:
        freqs: Floating-point tensor of inverse frequencies. Any shape; the
            transform is element-wise.
        scale_factor: Divisor applied to the low-frequency band.
        low_freq_factor: Sets the low-frequency wavelength threshold
            ``old_context_len / low_freq_factor``.
        high_freq_factor: Sets the high-frequency wavelength threshold
            ``old_context_len / high_freq_factor``.
        old_context_len: Context length the frequencies were designed for.

    Returns:
        A new tensor with the same shape, dtype and device as ``freqs``.

    Raises:
        ValueError: If ``low_freq_factor == high_freq_factor``, which makes
            the blend weight undefined.
    """
    if high_freq_factor == low_freq_factor:
        raise ValueError("low_freq_factor and high_freq_factor must differ")

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    # Evaluate every branch element-wise, then select with torch.where.
    # This avoids a Python loop and a host sync per element on GPU tensors.
    wavelen = 2 * math.pi / freqs
    smooth = (old_context_len / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    blended = (1 - smooth) * freqs / scale_factor + smooth * freqs
    scaled = torch.where(wavelen > low_freq_wavelen, freqs / scale_factor, blended)
    # The outer where gives the high-frequency test priority, matching the
    # original if/elif order.
    return torch.where(wavelen < high_freq_wavelen, freqs, scaled)
|
|
|
|
|
|
def precompute_freqs_cis(
    dim: int,
    end: int,
    theta: float = 10000.0,
    use_scaled: bool = False,
    device: str = "cuda:0",
):
    """Build the table of complex RoPE rotation factors.

    Row ``p`` holds ``exp(i * p * inv_freq_j)`` for the ``dim // 2`` inverse
    frequencies ``inv_freq_j = 1 / theta ** (2j / dim)``.

    Args:
        dim: Rotary dimension; ``dim // 2`` frequencies are produced.
        end: Number of positions (rows), covering positions ``0 .. end - 1``.
        theta: RoPE base.
        use_scaled: If True, pass the frequencies through ``apply_scaling``.
        device: Device on which the table is created.

    Returns:
        Complex tensor of shape ``(end, dim // 2)`` (complex64, since all
        angles are float32).
    """
    even_channels = torch.arange(0, dim, 2, device=device)[: (dim // 2)]
    exponents = even_channels.float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    if use_scaled:
        inv_freq = apply_scaling(inv_freq)

    positions = torch.arange(end, device=inv_freq.device, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)
    # Unit-magnitude complex numbers whose phase is the rotation angle.
    return torch.polar(torch.ones_like(angles), angles)
|
|
|
|
|
|
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
|
ndim = x.ndim
|
|
assert 0 <= 1 < ndim
|
|
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
|
|
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
|
return freqs_cis.view(*shape)
|
|
|
|
|
|
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key tensors by the complex factors in ``freqs_cis``.

    Adjacent channel pairs ``(x[..., 2j], x[..., 2j + 1])`` are treated as one
    complex number and multiplied by the matching entry of ``freqs_cis``.
    The multiplication is done in float32, and the results are cast back to
    the input dtypes.

    Args:
        xq: Query tensor with an even-sized last axis; ``freqs_cis`` is
            indexed along axis 1. The original code only handled 4-D inputs,
            presumably ``(batch, seq, heads, head_dim)``; confirm with callers.
        xk: Key tensor with the same rank and the same sizes on axis 1 and the
            last axis as ``xq``. Other axes (e.g. head count) may differ,
            because ``freqs_cis`` is broadcast there.
        freqs_cis: Complex tensor of shape ``(xq.shape[1], xq.shape[-1] // 2)``.

    Returns:
        ``(xq_out, xk_out)`` with the same shapes and dtypes as ``xq``/``xk``.
    """
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    # Merge the trailing (pairs, 2) axes back into one channel axis.
    # flatten(-2) equals the former flatten(3) for 4-D inputs and is also
    # correct for other ranks.
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
    return xq_out.type_as(xq), xk_out.type_as(xk)
|
|
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply half-split rotary embedding to ``q`` and ``k``.

    ``cos`` and ``sin`` get a new axis at position 1 before broadcasting.
    Presumably they are ``(batch, seq, dim)`` and q/k are
    ``(batch, heads, seq, dim)``; confirm with callers. Each output is
    ``x * cos + rotate_half(x) * sin``, cast back to the dtype of its input.
    """
    cos_b = cos.unsqueeze(1)
    sin_b = sin.unsqueeze(1)

    def _rotate(t):
        return ((t * cos_b) + (rotate_half(t) * sin_b)).to(t.dtype)

    return _rotate(q), _rotate(k)
|
|
|
|
|
|
def rotate_half(x):
    """Return ``cat(-second_half, first_half)`` along the last axis of ``x``.

    The split point is ``x.shape[-1] // 2``, so for an odd size the second
    half is the longer one.
    """
    split = x.shape[-1] // 2
    front, back = x[..., :split], x[..., split:]
    return torch.cat((back.neg(), front), dim=-1)
|
|
|
|
|
|
def generate_cos_sin_f32_cache(
    max_seq_len, head_dim, theta=1e4, use_scaled: bool = False, device: str = "cuda:0"
):
    """Precompute float32 cos/sin tables for half-split rotary embedding.

    The inverse frequencies ``1 / theta ** (2j / head_dim)`` are duplicated,
    so column ``j`` and column ``j + len(freqs)`` share a frequency. This is
    the pairing ``rotate_half`` uses.

    Args:
        max_seq_len: Number of positions (rows), covering ``0 .. max_seq_len - 1``.
        head_dim: Head dimension.
        theta: RoPE base.
        use_scaled: If True, pass the (duplicated) frequencies through
            ``apply_scaling``.
        device: Device on which the tables are created.

    Returns:
        ``(cos_cache, sin_cache)``, each float32 with shape
        ``(max_seq_len, 2 * ceil(head_dim / 2))``, i.e. ``(max_seq_len, head_dim)``
        for even ``head_dim``.
    """
    inv_freq = 1.0 / (
        theta
        ** (torch.arange(0, head_dim, 2, device=device, dtype=torch.float32) / head_dim)
    )
    inv_freq = inv_freq.repeat(2)
    if use_scaled:
        inv_freq = apply_scaling(inv_freq)

    positions = torch.arange(max_seq_len, device=device, dtype=torch.float32)[:, None]
    angles = positions * inv_freq
    return torch.cos(angles), torch.sin(angles)
|
|
|
|
|
|
# The following code is adapted from vLLM's implementation of RoPE.
|
|
# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding.py
|
|
|
|
|
|
class RotaryEmbedding(torch.nn.Module):
    """Rotary positional embedding backed by a precomputed cos/sin cache.

    Adapted from vLLM's ``RotaryEmbedding``. Unlike vLLM, ``forward_native``
    performs the rotation in float32 and casts the results to ``dtype``.

    Args:
        head_size: Size of each attention head in the query/key tensors.
        rotary_dim: Number of leading channels of each head that are rotated.
            The remaining ``head_size - rotary_dim`` channels pass through
            unchanged.
        max_position_embeddings: Number of cached positions. Positions (plus
            offsets) must lie in ``[0, max_position_embeddings)``.
        base: Base used to compute the inverse frequencies.
        is_neox_style: If True, the two halves of the rotary channels are
            rotated against each other. Otherwise interleaved even/odd
            channels are paired.
        dtype: dtype of the tensors returned by ``forward_native``.
        device: Device on which the cos/sin cache is built.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
        device: str = "cuda:0",
    ) -> None:
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.dtype = dtype
        self.device = device

        cache = self._compute_cos_sin_cache()
        self.cos_sin_cache: torch.Tensor
        # Non-persistent: the cache is rebuilt in __init__ and is not stored
        # in the state dict.
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        """Return float32 inverse frequencies ``1 / base ** (2j / rotary_dim)``."""
        inv_freq = 1.0 / (
            base
            ** (
                torch.arange(
                    0, self.rotary_dim, 2, dtype=torch.float, device=self.device
                )
                / self.rotary_dim
            )
        )
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos and sin cache.

        Returns:
            Float32 tensor of shape ``(max_position_embeddings, 2 * len(inv_freq))``
            (``rotary_dim`` columns for even ``rotary_dim``). The first half of
            each row holds cos values and the second half sin values.
        """
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(
            self.max_position_embeddings, dtype=torch.float, device=self.device
        )

        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def _apply_rotary_emb(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        is_neox_style: bool,
    ) -> torch.Tensor:
        """Rotate channel pairs of ``x`` by the given angles.

        Args:
            x: [num_tokens, num_heads, rotary_dim]
            cos: [num_tokens, rotary_dim // 2]
            sin: [num_tokens, rotary_dim // 2]
            is_neox_style: Whether to use the Neox-style (split into halves) or
                GPT-J-style (interleaved even/odd channels) rotary
                positional embeddings.
        """
        # Add a head axis so cos/sin broadcast over all heads.
        cos = cos.unsqueeze(-2).to(x.dtype)
        sin = sin.unsqueeze(-2).to(x.dtype)
        if is_neox_style:
            x1, x2 = torch.chunk(x, 2, dim=-1)
        else:
            x1 = x[..., ::2]
            x2 = x[..., 1::2]
        o1 = x1 * cos - x2 * sin
        o2 = x2 * cos + x1 * sin
        if is_neox_style:
            return torch.cat((o1, o2), dim=-1)
        else:
            # Re-interleave so each output pair sits where its input pair was.
            return torch.stack((o1, o2), dim=-1).flatten(-2)

    def _rotate_heads(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        num_tokens: int,
    ) -> torch.Tensor:
        """Rotate the first ``rotary_dim`` channels of every head in ``x``.

        ``x`` is viewed as ``[num_tokens, num_heads, head_size]``. The rotated
        result is reshaped back to ``x``'s original shape.
        """
        original_shape = x.shape
        x = x.view(num_tokens, -1, self.head_size)
        x_rot = x[..., : self.rotary_dim]
        x_pass = x[..., self.rotary_dim :]
        x_rot = self._apply_rotary_emb(x_rot, cos, sin, self.is_neox_style)
        return torch.cat((x_rot, x_pass), dim=-1).reshape(original_shape)

    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """A PyTorch-native implementation of forward().

        Args:
            positions: Integer positions, flattened to one entry per token.
            query: Tensor viewable as ``[num_tokens, num_heads, head_size]``
                (e.g. ``[num_tokens, num_heads * head_size]``).
            key: Same layout as ``query``; the head count may differ.
            offsets: Optional tensor added to ``positions`` before lookup.

        Returns:
            ``(query, key)`` with their input shapes, cast to ``self.dtype``.
        """
        if offsets is not None:
            positions = positions + offsets

        positions = positions.flatten()
        num_tokens = positions.shape[0]
        cos_sin = self.cos_sin_cache.index_select(0, positions)

        # Note: this is different from vLLM's implementation.
        # The rotation is done in float32; the original authors found float32
        # necessary for rotary embedding to work correctly on long contexts.
        query = query.to(torch.float32)
        key = key.to(torch.float32)

        cos, sin = cos_sin.chunk(2, dim=-1)

        query = self._rotate_heads(query, cos, sin, num_tokens)
        key = self._rotate_heads(key, cos, sin, num_tokens)

        query = query.to(self.dtype)
        key = key.to(self.dtype)
        return query, key

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Delegate to ``forward_native`` so the module is callable as ``rope(...)``."""
        return self.forward_native(positions, query, key, offsets)
|