# Source code for linear_relational.Lre

from typing import Any, Literal

import torch
from torch import nn


[docs] class InvertedLre(nn.Module): """Low-rank inverted LRE, used for calculating subject activations from object activations""" relation: str subject_layer: int object_layer: int # store u, v, s, and bias separately to avoid storing the full weight matrix u: nn.Parameter s: nn.Parameter v: nn.Parameter bias: nn.Parameter object_aggregation: Literal["mean", "first_token"] metadata: dict[str, Any] | None = None def __init__( self, relation: str, subject_layer: int, object_layer: int, object_aggregation: Literal["mean", "first_token"], u: torch.Tensor, s: torch.Tensor, v: torch.Tensor, bias: torch.Tensor, metadata: dict[str, Any] | None = None, ) -> None: super().__init__() self.relation = relation self.subject_layer = subject_layer self.object_layer = object_layer self.object_aggregation = object_aggregation self.u = nn.Parameter(u, requires_grad=False) self.s = nn.Parameter(s, requires_grad=False) self.v = nn.Parameter(v, requires_grad=False) self.bias = nn.Parameter(bias, requires_grad=False) self.metadata = metadata @property def rank(self) -> int: return self.s.shape[0]
[docs] def w_inv_times_vec(self, vec: torch.Tensor) -> torch.Tensor: # group u.T @ vec to avoid calculating larger matrices than needed return self.v @ torch.diag(1 / self.s) @ (self.u.T @ vec)
[docs] def forward( self, object_activations: torch.Tensor, # a tensor of shape (num_activations, hidden_activation_size) normalize: bool = False, ) -> torch.Tensor: return self.calculate_subject_activation( object_activations=object_activations, normalize=normalize, )
[docs] def calculate_subject_activation( self, object_activations: torch.Tensor, # a tensor of shape (num_activations, hidden_activation_size) normalize: bool = False, ) -> torch.Tensor: # match precision of weight_inverse and bias unbiased_acts = object_activations - self.bias.unsqueeze(0) vec = self.w_inv_times_vec(unbiased_acts.T).mean(dim=1) if normalize: vec = vec / vec.norm() return vec
def __repr__(self) -> str: return f"InvertedLre({self.relation}, rank {self.rank}, layers {self.subject_layer} -> {self.object_layer}, {self.object_aggregation})"
[docs] class LowRankLre(nn.Module): """Low-rank approximation of a LRE""" relation: str subject_layer: int object_layer: int # store u, v, s, and bias separately to avoid storing the full weight matrix u: nn.Parameter s: nn.Parameter v: nn.Parameter bias: nn.Parameter object_aggregation: Literal["mean", "first_token"] metadata: dict[str, Any] | None = None def __init__( self, relation: str, subject_layer: int, object_layer: int, object_aggregation: Literal["mean", "first_token"], u: torch.Tensor, s: torch.Tensor, v: torch.Tensor, bias: torch.Tensor, metadata: dict[str, Any] | None = None, ) -> None: super().__init__() self.relation = relation self.subject_layer = subject_layer self.object_layer = object_layer self.object_aggregation = object_aggregation self.u = nn.Parameter(u, requires_grad=False) self.s = nn.Parameter(s, requires_grad=False) self.v = nn.Parameter(v, requires_grad=False) self.bias = nn.Parameter(bias, requires_grad=False) self.metadata = metadata @property def rank(self) -> int: return self.s.shape[0]
[docs] def w_times_vec(self, vec: torch.Tensor) -> torch.Tensor: # group v.T @ vec to avoid calculating larger matrices than needed return self.u @ torch.diag(self.s) @ (self.v.T @ vec)
[docs] def forward( self, subject_activations: torch.Tensor, # a tensor of shape (num_activations, hidden_activation_size) normalize: bool = False, ) -> torch.Tensor: return self.calculate_object_activation( subject_activations=subject_activations, normalize=normalize, )
[docs] def calculate_object_activation( self, subject_activations: torch.Tensor, # a tensor of shape (num_activations, hidden_activation_size) normalize: bool = False, ) -> torch.Tensor: # match precision of weight_inverse and bias ws = self.w_times_vec(subject_activations.T) vec = (ws + self.bias.unsqueeze(-1)).mean(dim=1) if normalize: vec = vec / vec.norm() return vec
def __repr__(self) -> str: return f"LowRankLre({self.relation}, rank {self.rank}, layers {self.subject_layer} -> {self.object_layer}, {self.object_aggregation})"
[docs] class Lre(nn.Module): """Linear Relational Embedding""" relation: str subject_layer: int object_layer: int weight: nn.Parameter bias: nn.Parameter object_aggregation: Literal["mean", "first_token"] metadata: dict[str, Any] | None = None def __init__( self, relation: str, subject_layer: int, object_layer: int, object_aggregation: Literal["mean", "first_token"], weight: torch.Tensor, bias: torch.Tensor, metadata: dict[str, Any] | None = None, ) -> None: super().__init__() self.relation = relation self.subject_layer = subject_layer self.object_layer = object_layer self.object_aggregation = object_aggregation self.weight = nn.Parameter(weight, requires_grad=False) self.bias = nn.Parameter(bias, requires_grad=False) self.metadata = metadata
[docs] def invert(self, rank: int) -> InvertedLre: """Invert this LRE using a low-rank approximation""" u, s, v = self._low_rank_svd(rank) return InvertedLre( relation=self.relation, subject_layer=self.subject_layer, object_layer=self.object_layer, object_aggregation=self.object_aggregation, u=u.detach().clone(), s=s.detach().clone(), v=v.detach().clone(), bias=self.bias.detach().clone(), metadata=self.metadata, )
[docs] def forward( self, subject_activations: torch.Tensor, normalize: bool = False ) -> torch.Tensor: return self.calculate_object_activation( subject_activations=subject_activations, normalize=normalize )
[docs] def calculate_object_activation( self, subject_activations: torch.Tensor, # a tensor of shape (num_activations, hidden_activation_size) normalize: bool = False, ) -> torch.Tensor: # match precision of weight_inverse and bias vec = subject_activations @ self.weight.T + self.bias if len(vec.shape) == 2: vec = vec.mean(dim=0) if normalize: vec = vec / vec.norm() return vec
[docs] def to_low_rank(self, rank: int) -> LowRankLre: """Create a low-rank approximation of this LRE""" u, s, v = self._low_rank_svd(rank) return LowRankLre( relation=self.relation, subject_layer=self.subject_layer, object_layer=self.object_layer, object_aggregation=self.object_aggregation, u=u.detach().clone(), s=s.detach().clone(), v=v.detach().clone(), bias=self.bias.detach().clone(), metadata=self.metadata, )
@torch.no_grad() def _low_rank_svd( self, rank: int ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # use a float for the svd, then convert back to the original dtype u, s, v = torch.svd(self.weight.float()) low_rank_u: torch.Tensor = u[:, :rank].to(self.weight.dtype) low_rank_v: torch.Tensor = v[:, :rank].to(self.weight.dtype) low_rank_s: torch.Tensor = s[:rank].to(self.weight.dtype) return low_rank_u, low_rank_s, low_rank_v def __repr__(self) -> str: return f"Lre({self.relation}, layers {self.subject_layer} -> {self.object_layer}, {self.object_aggregation})"