Source code for linear_relational.training.Trainer

from collections import defaultdict
from time import time
from typing import Callable, Literal, Optional

import torch
from tokenizers import Tokenizer
from torch import nn

from linear_relational.Concept import Concept
from linear_relational.lib.balance_grouped_items import balance_grouped_items
from linear_relational.lib.extract_token_activations import extract_token_activations
from linear_relational.lib.layer_matching import (
    LayerMatcher,
    get_layer_name,
    guess_hidden_layer_matcher,
)
from linear_relational.lib.logger import log_or_print, logger
from linear_relational.lib.token_utils import PromptAnswerData, find_prompt_answer_data
from linear_relational.lib.torch_utils import get_device
from linear_relational.lib.util import group_items
from linear_relational.Lre import InvertedLre, Lre
from linear_relational.Prompt import Prompt
from linear_relational.PromptValidator import PromptValidator
from linear_relational.training.train_lre import ObjectAggregation, train_lre

VectorAggregation = Literal["pre_mean", "post_mean"]


[docs] class Trainer: """Train LREs and concepts from prompts""" model: nn.Module tokenizer: Tokenizer layer_matcher: LayerMatcher prompt_validator: PromptValidator def __init__( self, model: nn.Module, tokenizer: Tokenizer, layer_matcher: Optional[LayerMatcher] = None, prompt_validator: Optional[PromptValidator] = None, ) -> None: self.model = model self.tokenizer = tokenizer self.layer_matcher = layer_matcher or guess_hidden_layer_matcher(model) self.prompt_validator = prompt_validator or PromptValidator(model, tokenizer)
[docs] def train_lre( self, relation: str, subject_layer: int, object_layer: int, prompts: list[Prompt], max_lre_training_samples: int | None = None, object_aggregation: ObjectAggregation = "mean", validate_prompts: bool = True, validate_prompts_batch_size: int = 4, move_to_cpu: bool = False, verbose: bool = True, seed: int | str | float = 42, ) -> Lre: processed_prompts = self._process_relation_prompts( relation=relation, prompts=prompts, validate_prompts=validate_prompts, validate_prompts_batch_size=validate_prompts_batch_size, verbose=verbose, ) prompts_by_object = group_items(processed_prompts, lambda p: p.object_name) lre_train_prompts = balance_grouped_items( items_by_group=prompts_by_object, max_total=max_lre_training_samples, seed=seed, ) return train_lre( model=self.model, tokenizer=self.tokenizer, layer_matcher=self.layer_matcher, relation=relation, subject_layer=subject_layer, object_layer=object_layer, prompts=lre_train_prompts, object_aggregation=object_aggregation, move_to_cpu=move_to_cpu, )
[docs] def train_relation_concepts( self, relation: str, subject_layer: int, object_layer: int, prompts: list[Prompt], max_lre_training_samples: int | None = 20, object_aggregation: ObjectAggregation = "mean", vector_aggregation: VectorAggregation = "post_mean", inv_lre_rank: int = 200, validate_prompts_batch_size: int = 4, validate_prompts: bool = True, verbose: bool = True, name_concept_fn: Optional[Callable[[str, str], str]] = None, seed: int | str | float = 42, ) -> list[Concept]: processed_prompts = self._process_relation_prompts( relation=relation, prompts=prompts, validate_prompts=validate_prompts, validate_prompts_batch_size=validate_prompts_batch_size, verbose=verbose, ) prompts_by_object = group_items(processed_prompts, lambda p: p.object_name) if len(prompts_by_object) == 1: logger.warning( f"Only one valid object found for {relation}. Results may be poor." ) lre_train_prompts = balance_grouped_items( items_by_group=prompts_by_object, max_total=max_lre_training_samples, seed=seed, ) inv_lre = train_lre( model=self.model, tokenizer=self.tokenizer, layer_matcher=self.layer_matcher, relation=relation, subject_layer=subject_layer, object_layer=object_layer, prompts=lre_train_prompts, object_aggregation=object_aggregation, ).invert(inv_lre_rank) return self.train_relation_concepts_from_inv_lre( relation=relation, inv_lre=inv_lre, prompts=processed_prompts, vector_aggregation=vector_aggregation, object_aggregation=object_aggregation, object_layer=object_layer, validate_prompts_batch_size=validate_prompts_batch_size, validate_prompts=False, # we already validated the prompts above name_concept_fn=name_concept_fn, verbose=verbose, )
[docs] def train_relation_concepts_from_inv_lre( self, inv_lre: InvertedLre | Callable[[str], InvertedLre], prompts: list[Prompt], vector_aggregation: VectorAggregation = "post_mean", object_aggregation: ObjectAggregation | None = None, relation: str | None = None, object_layer: int | None = None, validate_prompts_batch_size: int = 4, extract_objects_batch_size: int = 4, validate_prompts: bool = True, name_concept_fn: Optional[Callable[[str, str], str]] = None, verbose: bool = True, ) -> list[Concept]: if isinstance(inv_lre, InvertedLre): if object_aggregation is None: object_aggregation = inv_lre.object_aggregation if object_layer is None: object_layer = inv_lre.object_layer if relation is None: relation = inv_lre.relation if object_aggregation is None: raise ValueError( "object_aggregation must be specified if inv_lre is a function" ) if object_layer is None: raise ValueError("object_layer must be specified if inv_lre is a function") if relation is None: raise ValueError("relation must be specified if inv_lre is a function") processed_prompts = self._process_relation_prompts( relation=relation, prompts=prompts, validate_prompts=validate_prompts, validate_prompts_batch_size=validate_prompts_batch_size, verbose=verbose, ) start_time = time() object_activations = self._extract_target_object_activations_for_inv_lre( prompts=processed_prompts, batch_size=extract_objects_batch_size, object_aggregation=object_aggregation, object_layer=object_layer, show_progress=verbose, move_to_cpu=True, ) logger.info( f"Extracted {len(object_activations)} object activations in {time() - start_time:.2f}s" ) concepts: list[Concept] = [] with torch.no_grad(): for ( object_name, activations, ) in object_activations.items(): resolved_inv_lre = ( inv_lre if isinstance(inv_lre, InvertedLre) else inv_lre(object_name) ) name = None if name_concept_fn is not None: name = name_concept_fn(relation, object_name) concept = self._build_concept( relation_name=relation, layer=resolved_inv_lre.subject_layer, inv_lre=resolved_inv_lre, object_name=object_name, activations=activations, vector_aggregation=vector_aggregation, name=name, ) concepts.append(concept) return concepts
def _process_relation_prompts( self, relation: str, prompts: list[Prompt], validate_prompts: bool, validate_prompts_batch_size: int, verbose: bool, ) -> list[Prompt]: valid_prompts = prompts if validate_prompts: log_or_print(f"validating {len(prompts)} prompts", verbose=verbose) valid_prompts = self.prompt_validator.filter_prompts( prompts, validate_prompts_batch_size, verbose ) if len(valid_prompts) == 0: raise ValueError(f"No valid prompts found for {relation}.") return valid_prompts def _build_concept( self, layer: int, relation_name: str, object_name: str, activations: list[torch.Tensor], inv_lre: InvertedLre, vector_aggregation: VectorAggregation, name: str | None, ) -> Concept: device = inv_lre.bias.device dtype = inv_lre.bias.dtype if vector_aggregation == "pre_mean": acts = [torch.stack(activations).to(device=device, dtype=dtype).mean(dim=0)] elif vector_aggregation == "post_mean": acts = [act.to(device=device, dtype=dtype) for act in activations] else: raise ValueError(f"Unknown vector aggregation method {vector_aggregation}") vecs = [ inv_lre.calculate_subject_activation(act, normalize=False) for act in acts ] vec = torch.stack(vecs).mean(dim=0) vec = vec / vec.norm() return Concept( name=name, object=object_name, relation=relation_name, layer=layer, vector=vec.detach().clone().cpu(), ) @torch.no_grad() def _extract_target_object_activations_for_inv_lre( self, object_layer: int, object_aggregation: Literal["mean", "first_token"], prompts: list[Prompt], batch_size: int, show_progress: bool = False, move_to_cpu: bool = True, ) -> dict[str, list[torch.Tensor]]: activations_by_object: dict[str, list[torch.Tensor]] = defaultdict(list) prompt_answer_data: list[PromptAnswerData] = [] for prompt in prompts: prompt_answer_data.append( find_prompt_answer_data(self.tokenizer, prompt.text, prompt.answer) ) layer_name = get_layer_name(self.model, self.layer_matcher, object_layer) raw_activations = extract_token_activations( self.model, self.tokenizer, layers=[layer_name], texts=[prompt_answer.full_prompt for prompt_answer in prompt_answer_data], token_indices=[ prompt_answer.output_answer_token_indices for prompt_answer in prompt_answer_data ], device=get_device(self.model), batch_size=batch_size, show_progress=show_progress, move_results_to_cpu=move_to_cpu, ) for prompt, raw_activation in zip(prompts, raw_activations): if object_aggregation == "mean": activation = torch.stack(raw_activation[layer_name]).mean(dim=0) elif object_aggregation == "first_token": activation = raw_activation[layer_name][0] else: raise ValueError( f"Unknown inv_lre.object_aggregation: {object_aggregation}" ) activations_by_object[prompt.object_name].append(activation) return activations_by_object