Source code for linear_relational.CausalEditor

from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Literal, Optional, Sequence, TypeVar, Union, cast

import torch
from tokenizers import Tokenizer
from torch import nn

from linear_relational.Concept import Concept
from linear_relational.lib.layer_matching import (
    LayerMatcher,
    collect_matching_layers,
    get_layer_name,
    guess_hidden_layer_matcher,
)
from linear_relational.lib.token_utils import (
    ensure_tokenizer_has_pad_token,
    find_final_word_token_index,
    make_inputs,
    predict_all_token_probs_from_input,
    predict_next_tokens_greedy,
)
from linear_relational.lib.torch_utils import get_device, untuple_tensor
from linear_relational.lib.TraceLayerDict import TraceLayerDict
from linear_relational.lib.util import batchify

# How callers may identify the subject token to edit: either the literal
# subject string (resolved to its final token), an explicit token index,
# or a callable mapping (text, input token ids) -> token index.
EditorSubject = Union[str, int, Callable[[str, list[int]], int]]


@dataclass
class ConceptSwapRequest:
    """A request to swap one concept for another on a subject token in `text`."""

    # the prompt to run through the model
    text: str
    # identifies the subject token within `text` (see EditorSubject)
    subject: EditorSubject
    # name of the concept whose direction is removed from the subject activation
    remove_concept: str
    # name of the concept whose direction is added in its place
    add_concept: str
@dataclass
class ConceptSwapAndPredictGreedyRequest(ConceptSwapRequest):
    """A ConceptSwapRequest that also greedily predicts tokens after the swap."""

    # number of tokens to greedily predict after applying the swap
    predict_num_tokens: int = 1
# Generic return type for _swap_subject_concepts_and_run_batch's run_fn.
T = TypeVar("T")
class CausalEditor:
    """Modify model activations during inference to swap concepts"""

    concepts: list[Concept]
    model: nn.Module
    tokenizer: Tokenizer
    layer_matcher: LayerMatcher
    layer_name_to_num: dict[str, int]

    def __init__(
        self,
        model: nn.Module,
        tokenizer: Tokenizer,
        concepts: list[Concept],
        layer_matcher: Optional[LayerMatcher] = None,
    ) -> None:
        """
        Args:
            model: the model whose activations will be edited during inference
            tokenizer: tokenizer matching the model
            concepts: the concepts available for swapping, looked up by name
            layer_matcher: which layers to trace; guessed from the model if omitted
        """
        self.concepts = concepts
        self.model = model
        self.tokenizer = tokenizer
        # fall back to auto-detecting the hidden layers when no matcher is given
        self.layer_matcher = layer_matcher or guess_hidden_layer_matcher(model)
        ensure_tokenizer_has_pad_token(tokenizer)
        num_layers = len(collect_matching_layers(self.model, self.layer_matcher))
        # map each matched layer's name to its index for quick lookup in hooks
        self.layer_name_to_num = {
            get_layer_name(model, self.layer_matcher, layer_num): layer_num
            for layer_num in range(num_layers)
        }

    @property
    def device(self) -> torch.device:
        """The device the model currently lives on."""
        return get_device(self.model)
[docs] def swap_subject_concepts_and_predict_greedy( self, text: str, subject: EditorSubject, remove_concept: str, add_concept: str, # if False, edit the subject token at every layer edit_single_layer: int | Literal[False] = False, predict_num_tokens: int = 1, magnitude_multiplier: float = 1.0, # if True, use the magnitude of the projection of the remove_concept against the subject's original activation # if False, use the magnitude of the subject's original activation use_remove_concept_projection_magnitude: bool = False, ) -> str: results = self.swap_subject_concepts_and_predict_greedy_bulk( [ ConceptSwapAndPredictGreedyRequest( text, subject, remove_concept, add_concept, predict_num_tokens ) ], magnitude_multiplier=magnitude_multiplier, edit_single_layer=edit_single_layer, use_remove_concept_projection_magnitude=use_remove_concept_projection_magnitude, ) return results[0]
[docs] def swap_subject_concepts_and_predict_greedy_bulk( self, requests: Sequence[ConceptSwapAndPredictGreedyRequest], # if False, edit the subject token at every layer edit_single_layer: int | Literal[False] = False, magnitude_multiplier: float = 1.0, # if True, use the magnitude of the projection of the remove_concept against the subject's original activation # if False, use the magnitude of the subject's original activation use_remove_concept_projection_magnitude: bool = False, ) -> list[str]: next_tokens = self.swap_subject_concepts_and_predict_tokens_greedy_bulk( requests, edit_single_layer=edit_single_layer, magnitude_multiplier=magnitude_multiplier, use_remove_concept_projection_magnitude=use_remove_concept_projection_magnitude, ) return [self.tokenizer.decode(tokens) for tokens in next_tokens]
[docs] def swap_subject_concepts_and_predict_tokens_greedy_bulk( self, requests: Sequence[ConceptSwapAndPredictGreedyRequest], # if False, edit the subject token at every layer edit_single_layer: int | Literal[False], magnitude_multiplier: float = 1.0, # if True, use the magnitude of the projection of the remove_concept against the subject's original activation # if False, use the magnitude of the subject's original activation use_remove_concept_projection_magnitude: bool = False, batch_size: int = 12, show_progress: bool = False, ) -> list[list[int]]: results: list[list[int]] = [] for batch in batchify(requests, batch_size, show_progress=show_progress): def run_batch_fn() -> list[list[int]]: max_num_tokens = max(req.predict_num_tokens for req in batch) next_tokens = predict_next_tokens_greedy( self.model, self.tokenizer, [req.text for req in batch], num_tokens=max_num_tokens, device=self.device, ) return [ tokens[: req.predict_num_tokens] for tokens, req in zip(next_tokens, batch) ] results.extend( self._swap_subject_concepts_and_run_batch( batch, run_fn=run_batch_fn, edit_single_layer=edit_single_layer, magnitude_multiplier=magnitude_multiplier, use_remove_concept_projection_magnitude=use_remove_concept_projection_magnitude, ) ) return results
[docs] def swap_subject_concepts_and_predict_all_token_probs_bulk( self, requests: Sequence[ConceptSwapRequest], magnitude_multiplier: float = 1.0, # if False, edit the subject token at every layer edit_single_layer: int | Literal[False] = False, # if True, use the magnitude of the projection of the remove_concept against the subject's original activation # if False, use the magnitude of the subject's original activation use_remove_concept_projection_magnitude: bool = False, batch_size: int = 12, show_progress: bool = False, ) -> list[torch.Tensor]: results: list[torch.Tensor] = [] for batch in batchify(requests, batch_size, show_progress=show_progress): def run_batch_fn() -> list[torch.Tensor]: inputs = make_inputs( self.tokenizer, [req.text for req in batch], device=self.device ) return predict_all_token_probs_from_input(self.model, inputs) results.extend( self._swap_subject_concepts_and_run_batch( batch, run_fn=run_batch_fn, magnitude_multiplier=magnitude_multiplier, edit_single_layer=edit_single_layer, use_remove_concept_projection_magnitude=use_remove_concept_projection_magnitude, ) ) return results
    def _swap_subject_concepts_and_run_batch(
        self,
        requests: Sequence[ConceptSwapRequest],
        run_fn: Callable[[], T],
        edit_single_layer: int | Literal[False],
        magnitude_multiplier: float = 1.0,
        # if True, use the magnitude of the projection of the remove_concept against the subject's original activation
        # if False, use the magnitude of the subject's original activation
        use_remove_concept_projection_magnitude: bool = False,
    ) -> T:
        """
        Helper to run the given run_fn while swapping the subject concept for each request.
        The run_fn should run the model with the same batch of inputs as specified in the requests
        """
        # resolve each request's subject to a token index up front
        subj_tokens = [self._find_subject_token(req) for req in requests]
        with torch.no_grad():
            # detach/clone each concept vector and match the model's dtype/device
            # so the in-hook arithmetic below is valid
            remove_concept_vectors = [
                (
                    self._find_concept(req.remove_concept)
                    .vector.detach()
                    .clone()
                    .type(cast(torch.dtype, self.model.dtype))
                    .to(self.device)
                )
                for req in requests
            ]
            add_concept_vectors = [
                (
                    self._find_concept(req.add_concept)
                    .vector.detach()
                    .clone()
                    .type(cast(torch.dtype, self.model.dtype))
                    .to(self.device)
                )
                for req in requests
            ]

        def edit_model_output(output: torch.Tensor, layer_name: str) -> torch.Tensor:
            # Hook run on every traced layer's output. When edit_single_layer is
            # an int, only that layer is edited; when False, every layer is.
            if (
                edit_single_layer is not False
                and self.layer_name_to_num[layer_name] != edit_single_layer
            ):
                return output
            # untuple_tensor unwraps layer outputs that come as tuples;
            # mutating fixed_output in place edits the real layer output.
            # NOTE(review): indexing fixed_output[i][subj_token] assumes a
            # (batch, seq, hidden) activation — confirm against TraceLayerDict.
            fixed_output = untuple_tensor(output)
            for i, subj_token in enumerate(subj_tokens):
                remove_concept_vector = remove_concept_vectors[i]
                add_concept_vector = add_concept_vectors[i]
                original_subj_act = fixed_output[i][subj_token]
                # scale the swap either by the projection onto the removed
                # concept or by the activation's own norm
                if use_remove_concept_projection_magnitude:
                    base_magnitude = original_subj_act.dot(remove_concept_vector)
                else:
                    base_magnitude = original_subj_act.norm()
                magnitude = base_magnitude * magnitude_multiplier
                # shift the subject activation away from remove_concept and
                # toward add_concept
                fixed_output[i][subj_token] = original_subj_act + magnitude * (
                    add_concept_vector - remove_concept_vector
                )
            return output

        # trace every matched layer; the hook itself decides which to edit
        with torch.no_grad(), TraceLayerDict(
            self.model,
            layers=self.layer_name_to_num.keys(),
            edit_output=edit_model_output,
        ):
            return run_fn()

    def _find_subject_token(self, query: ConceptSwapRequest) -> int:
        """Resolve a request's EditorSubject to a concrete token index (see EditorSubject)."""
        text = query.text
        subject = query.subject
        if isinstance(subject, int):
            # already an explicit token index
            return subject
        if isinstance(subject, str):
            # string subject: locate the final token of that word in the text
            return find_final_word_token_index(self.tokenizer, text, subject)
        if callable(subject):
            # custom resolver: given text and its token ids, returns an index
            return subject(text, self.tokenizer.encode(text))
        raise ValueError(f"Unknown subject type: {type(subject)}")

    def _find_concept(self, concept_name: str) -> Concept:
        """Look up a concept by name among self.concepts; raises ValueError if absent."""
        for concept in self.concepts:
            if concept.name == concept_name:
                return concept
        raise ValueError(f"Unknown concept: {concept_name}")