from dataclasses import dataclass
from typing import Callable, Optional, Sequence, Union
import torch
from tokenizers import Tokenizer
from torch import nn
from linear_relational.Concept import Concept
from linear_relational.lib.extract_token_activations import (
TokenLayerActivationsList,
extract_token_activations,
)
from linear_relational.lib.layer_matching import (
LayerMatcher,
collect_matching_layers,
get_layer_name,
guess_hidden_layer_matcher,
)
from linear_relational.lib.token_utils import (
ensure_tokenizer_has_pad_token,
find_final_word_token_index,
)
from linear_relational.lib.torch_utils import get_device
from linear_relational.lib.util import batchify
QuerySubject = Union[str, int, Callable[[str, list[int]], int]]
@dataclass
class ConceptMatchQuery:
    """A single query: a text plus the subject to locate inside it.

    Attributes:
        text: The text that is passed to the model.
        subject: How to find the subject token in ``text``: a token index,
            the subject string, or a callable returning the token index.
    """

    text: str
    subject: QuerySubject
@dataclass
class ConceptMatchResult:
    """The score of one concept for one query.

    Attributes:
        concept: Name of the concept that was evaluated.
        score: Scalar score obtained by applying the concept to the
            subject token's activation.
    """

    concept: str
    score: float
@dataclass
class QueryResult:
    """All concept scores computed for a single query.

    Attributes:
        concept_results: Mapping from concept name to that concept's result.
    """

    concept_results: dict[str, ConceptMatchResult]

    @property
    def best_match(self) -> ConceptMatchResult:
        """Return the result with the highest score.

        Raises:
            ValueError: If ``concept_results`` is empty.
        """
        # Fail with an explicit message instead of the opaque
        # "max() arg is an empty sequence" raised by max() itself.
        if not self.concept_results:
            raise ValueError(
                "Cannot determine best match: no concept results are available"
            )
        return max(self.concept_results.values(), key=lambda result: result.score)
class ConceptMatcher:
    """Match concepts against subject activations in a model"""

    concepts: list[Concept]
    model: nn.Module
    tokenizer: Tokenizer
    layer_matcher: LayerMatcher
    # Maps the name of each matched layer to its index among the matched layers.
    layer_name_to_num: dict[str, int]
    # Optional hook applied to the extracted activations before scoring.
    map_activations_fn: (
        Callable[[TokenLayerActivationsList], TokenLayerActivationsList] | None
    )

    def __init__(
        self,
        model: nn.Module,
        tokenizer: Tokenizer,
        concepts: list[Concept],
        layer_matcher: Optional[LayerMatcher] = None,
        map_activations_fn: (
            Callable[[TokenLayerActivationsList], TokenLayerActivationsList] | None
        ) = None,
    ) -> None:
        """Set up the matcher and index the model's matching layers.

        Args:
            model: The model whose activations are inspected.
            tokenizer: Tokenizer used with ``model``; it is passed to
                ``ensure_tokenizer_has_pad_token``.
            concepts: The concepts that every query is scored against.
            layer_matcher: Selects the layers to read activations from. When
                falsy, ``guess_hidden_layer_matcher(model)`` is used.
            map_activations_fn: Optional function applied to each batch of
                extracted activations before the concepts are evaluated.
        """
        self.concepts = concepts
        self.model = model
        self.tokenizer = tokenizer
        self.layer_matcher = layer_matcher or guess_hidden_layer_matcher(model)
        self.map_activations_fn = map_activations_fn
        ensure_tokenizer_has_pad_token(tokenizer)
        num_layers = len(collect_matching_layers(self.model, self.layer_matcher))
        self.layer_name_to_num = {
            get_layer_name(model, self.layer_matcher, layer_num): layer_num
            for layer_num in range(num_layers)
        }

    def query(self, query: str, subject: QuerySubject) -> QueryResult:
        """Score every concept for a single text.

        Args:
            query: The text to run through the model.
            subject: Identifies the subject token in ``query``.

        Returns:
            The ``QueryResult`` for this one query.
        """
        return self.query_bulk([ConceptMatchQuery(query, subject)])[0]

    def query_bulk(
        self,
        queries: Sequence[ConceptMatchQuery],
        batch_size: int = 4,
        verbose: bool = False,
    ) -> list[QueryResult]:
        """Score every concept for each query, processing queries in batches.

        Args:
            queries: The queries to evaluate.
            batch_size: Number of queries handed to each batch.
            verbose: Forwarded to ``batchify`` as ``show_progress``.

        Returns:
            One ``QueryResult`` per query, in the order batches are produced.
        """
        results: list[QueryResult] = []
        for batch in batchify(queries, batch_size, show_progress=verbose):
            results.extend(self._query_batch(batch))
        return results

    def _query_batch(self, queries: Sequence[ConceptMatchQuery]) -> list[QueryResult]:
        """Extract subject-token activations for one batch and score all concepts."""
        # Nothing to do; also avoids calling the extractor with batch_size=0.
        if not queries:
            return []
        subj_tokens = [self._find_subject_token(query) for query in queries]
        with torch.no_grad():
            batch_subj_token_activations = extract_token_activations(
                self.model,
                self.tokenizer,
                layers=self.layer_name_to_num.keys(),
                texts=[q.text for q in queries],
                token_indices=subj_tokens,
                device=get_device(self.model),
                # batching is handled already, so no need to batch here too
                batch_size=len(queries),
                show_progress=False,
            )
        if self.map_activations_fn is not None:
            batch_subj_token_activations = self.map_activations_fn(
                batch_subj_token_activations
            )
        results: list[QueryResult] = []
        for raw_subj_token_activations in batch_subj_token_activations:
            # need to replace the layer name with the layer number
            # NOTE(review): index 0 presumably selects the single requested
            # token's activation -- confirm against extract_token_activations.
            subj_token_activations = {
                self.layer_name_to_num[layer_name]: layer_activations[0]
                for layer_name, layer_activations in raw_subj_token_activations.items()
            }
            concept_results: dict[str, ConceptMatchResult] = {
                concept.name: _apply_concept_to_activations(
                    concept, subj_token_activations
                )
                for concept in self.concepts
            }
            results.append(QueryResult(concept_results))
        return results

    def _find_subject_token(self, query: ConceptMatchQuery) -> int:
        """Resolve the query's subject to a token index.

        An ``int`` subject is returned unchanged, a ``str`` subject is resolved
        with ``find_final_word_token_index``, and a callable subject is called
        with the text and ``self.tokenizer.encode(text)``.

        Raises:
            ValueError: If the subject is none of the supported kinds.
        """
        text = query.text
        subject = query.subject
        if isinstance(subject, int):
            return subject
        if isinstance(subject, str):
            return find_final_word_token_index(self.tokenizer, text, subject)
        if callable(subject):
            return subject(text, self.tokenizer.encode(text))
        raise ValueError(f"Unknown subject type: {type(subject)}")
@torch.no_grad()
def _apply_concept_to_activations(
    concept: Concept, activations: dict[int, torch.Tensor]
) -> ConceptMatchResult:
    """Score ``concept`` against the activation stored for its layer.

    Args:
        concept: The concept to evaluate; ``concept.layer`` selects which
            entry of ``activations`` is used.
        activations: Activations keyed by layer number.

    Returns:
        A ``ConceptMatchResult`` holding the concept's name and its score.

    Raises:
        KeyError: If ``activations`` has no entry for ``concept.layer``.
    """
    try:
        layer_activation = activations[concept.layer]
    except KeyError:
        # Same exception type as the plain lookup, but say which concept and
        # layer were involved instead of surfacing a bare layer number.
        raise KeyError(
            f"No activations for layer {concept.layer} "
            f"(needed by concept {concept.name!r}); "
            f"available layers: {sorted(activations)}"
        ) from None
    # .item() turns the (presumably single-element tensor) output into a
    # plain Python number.
    score = concept.forward(layer_activation).item()
    return ConceptMatchResult(concept=concept.name, score=score)