Source code for linear_relational.Concept
from __future__ import annotations
from typing import Any, Optional
import torch
from torch import nn
[docs]
class Concept(nn.Module):
    """Linear Relation Concept (LRC).

    Bundles a concept direction (``vector``) with the descriptive fields
    passed to the constructor. Calling the module takes the dot product of
    that vector with the supplied activations.

    Attributes:
        layer: Layer number stored as given (presumably the model layer the
            vector applies to -- confirm against callers).
        vector: The concept direction. ``forward`` treats it as 1-D with the
            same length as the last dimension of the activations.
        object: Object label for the concept.
        relation: Relation label for the concept.
        name: Display name; defaults to ``"{relation}: {object}"``.
        metadata: Free-form extra data; defaults to an empty dict.
    """

    layer: int
    vector: torch.Tensor
    object: str
    relation: str
    name: str
    metadata: dict[str, Any]

    def __init__(
        self,
        layer: int,
        vector: torch.Tensor,
        object: str,
        relation: str,
        metadata: Optional[dict[str, Any]] = None,
        name: Optional[str] = None,
    ) -> None:
        """Store the concept fields.

        Args:
            layer: Layer number associated with the concept.
            vector: Concept direction tensor.
            object: Object label.
            relation: Relation label.
            metadata: Optional extra data. ``None`` (or any falsy value) is
                replaced by a fresh empty dict.
            name: Optional display name. ``None`` (or an empty string) is
                replaced by ``"{relation}: {object}"``.
        """
        super().__init__()
        self.layer = layer
        # Assigned as an ordinary attribute (no register_buffer call), so
        # forward() moves it to the activations' device/dtype on every call.
        self.vector = vector
        self.object = object
        self.relation = relation
        self.metadata = metadata or {}
        self.name = name or f"{self.relation}: {self.object}"

    def forward(self, activations: torch.Tensor) -> torch.Tensor:
        """Project activations onto the concept vector.

        Args:
            activations: Either a single 1-D activation, or a tensor with any
                number of leading dimensions whose last dimension matches the
                length of ``vector``.

        Returns:
            A 0-D tensor for 1-D input; otherwise a tensor shaped like
            ``activations`` with the last dimension removed. The result has
            the device and dtype of ``activations``.
        """
        vector = self.vector.to(activations.device, dtype=activations.dtype)
        if activations.dim() == 1:
            return vector @ activations
        # Swap only the last two dims. For 2-D input this equals ``.T``; for
        # higher-rank input ``.T`` would reverse every dim (deprecated in
        # PyTorch) and contract the wrong axis.
        return vector @ activations.transpose(-2, -1)