Trainer

class linear_relational.Trainer(model, tokenizer, layer_matcher=None, prompt_validator=None)

Train linear relational embeddings (LREs) and linear relational concepts (LRCs) from prompts

layer_matcher
model
prompt_validator
tokenizer
train_lre(relation, subject_layer, object_layer, prompts, max_lre_training_samples=None, object_aggregation='mean', validate_prompts=True, validate_prompts_batch_size=4, move_to_cpu=False, verbose=True, seed=42)
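An LRE is an affine map o ≈ W s + b from a subject activation s (read at subject_layer) to an object activation o (read at object_layer). In the LRE formulation of Hernandez et al. (2023), W and b are estimated by averaging per-example first-order (Jacobian) approximations of the model. The sketch below illustrates only that averaging idea on a hand-written 2-D function standing in for a transformer, using finite differences instead of autograd. toy_forward, jacobian, train_toy_lre, and the use of a seeded random subset to honour a max_lre_training_samples-style cap are all hypothetical and not the library's implementation.

```python
import random


def toy_forward(s):
    """Hypothetical stand-in for the model: maps a 2-D "subject activation"
    to a 2-D "object activation" with a mildly nonlinear function."""
    x, y = s
    return [2.0 * x + 0.1 * x * y, -y + 0.05 * x * x]


def jacobian(f, s, eps=1e-5):
    """Central finite-difference Jacobian of f at s (rows = outputs, cols = inputs)."""
    n_out = len(f(s))
    jac = [[0.0] * len(s) for _ in range(n_out)]
    for j in range(len(s)):
        plus = list(s)
        minus = list(s)
        plus[j] += eps
        minus[j] -= eps
        f_plus, f_minus = f(plus), f(minus)
        for i in range(n_out):
            jac[i][j] = (f_plus[i] - f_minus[i]) / (2 * eps)
    return jac


def matvec(m, v):
    """Multiply matrix m (list of rows) by vector v."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]


def train_toy_lre(subject_acts, max_samples=None, seed=42):
    """Estimate W and b by averaging per-sample first-order approximations."""
    samples = list(subject_acts)
    if max_samples is not None and len(samples) > max_samples:
        # Assumption: the cap is applied by drawing a seeded random subset.
        samples = random.Random(seed).sample(samples, max_samples)
    n = len(samples)
    jacs = [jacobian(toy_forward, s) for s in samples]
    rows, cols = len(jacs[0]), len(jacs[0][0])
    # W = element-wise mean of the per-sample Jacobians
    weight = [[sum(j[r][c] for j in jacs) / n for c in range(cols)] for r in range(rows)]
    # b = mean of the per-sample residuals F(s) - J(s) s
    residuals = [
        [o - js for o, js in zip(toy_forward(s), matvec(j, s))]
        for s, j in zip(samples, jacs)
    ]
    bias = [sum(res[r] for res in residuals) / n for r in range(rows)]
    return weight, bias


subjects = [[1.0, 0.5], [0.8, -0.2], [1.2, 0.1], [0.9, 0.3]]
W, b = train_toy_lre(subjects, max_samples=3, seed=42)
for s in subjects:
    approx = [ws + bi for ws, bi in zip(matvec(W, s), b)]
    print(s, [round(v, 3) for v in toy_forward(s)], [round(v, 3) for v in approx])
```

Because the toy function is only mildly nonlinear, the single averaged affine map tracks it closely at every subject, including the one left out of the sample. That is the property an LRE relies on inside a real model.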
train_relation_concepts(relation, subject_layer, object_layer, prompts, max_lre_training_samples=20, object_aggregation='mean', vector_aggregation='post_mean', inv_lre_rank=200, validate_prompts_batch_size=4, validate_prompts=True, verbose=True, name_concept_fn=None, seed=42)
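In the LRC setup, each concept pairs a relation with one of its objects (for example "located in country" with "France"), so the training prompts have to be bucketed by object before any concept vector is built. The sketch below shows only that bucketing step and a naming callable in the spirit of name_concept_fn. PromptStub, group_prompts_by_object, and example_name_concept_fn are hypothetical. The assumption that name_concept_fn receives the relation and the object name and returns a string is unverified and should be checked against the library source.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Callable


@dataclass(frozen=True)
class PromptStub:
    """Hypothetical stand-in mirroring the fields of linear_relational.Prompt."""
    text: str
    answer: str
    subject: str
    subject_name: str = ""
    object_name: str = ""


def example_name_concept_fn(relation: str, object_name: str) -> str:
    # Hypothetical naming scheme; the real name_concept_fn signature may differ.
    return f"{relation}: {object_name}"


def group_prompts_by_object(
    prompts: list[PromptStub],
    relation: str,
    name_concept_fn: Callable[[str, str], str] = example_name_concept_fn,
) -> dict[str, list[PromptStub]]:
    """Bucket prompts so that each distinct object yields one concept name."""
    groups: dict[str, list[PromptStub]] = defaultdict(list)
    for prompt in prompts:
        groups[name_concept_fn(relation, prompt.object_name)].append(prompt)
    return dict(groups)


prompts = [
    PromptStub("Paris is located in the country of", "France", "Paris", object_name="France"),
    PromptStub("Lyon is located in the country of", "France", "Lyon", object_name="France"),
    PromptStub("Shanghai is located in the country of", "China", "Shanghai", object_name="China"),
]
groups = group_prompts_by_object(prompts, "located in country")
for name, members in groups.items():
    print(name, [p.subject for p in members])
```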
train_relation_concepts_from_inv_lre(inv_lre, prompts, vector_aggregation='post_mean', object_aggregation=None, relation=None, object_layer=None, validate_prompts_batch_size=4, extract_objects_batch_size=4, validate_prompts=True, name_concept_fn=None, verbose=True)
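All three training methods accept validate_prompts and validate_prompts_batch_size. Assuming validation means keeping only the prompts the model already completes with the expected answer, checked one batch at a time, the filtering step might look like the sketch below. PromptStub, filter_valid_prompts, and fake_predict_batch are hypothetical; the lookup table stands in for a real model's greedy completions.

```python
from dataclasses import dataclass
from typing import Callable, Iterator, Sequence


@dataclass(frozen=True)
class PromptStub:
    """Hypothetical stand-in with the text/answer/subject fields of linear_relational.Prompt."""
    text: str
    answer: str
    subject: str


def batched(items: Sequence[PromptStub], batch_size: int) -> Iterator[Sequence[PromptStub]]:
    """Yield consecutive slices of at most batch_size items."""
    for start in range(0, len(items), batch_size):
        yield items[start : start + batch_size]


def filter_valid_prompts(
    prompts: Sequence[PromptStub],
    predict_batch: Callable[[list[str]], list[str]],
    batch_size: int = 4,
) -> list[PromptStub]:
    """Keep prompts whose predicted completion starts with the expected answer."""
    valid = []
    for batch in batched(prompts, batch_size):
        predictions = predict_batch([p.text for p in batch])
        for prompt, prediction in zip(batch, predictions):
            if prediction.strip().startswith(prompt.answer.strip()):
                valid.append(prompt)
    return valid


# Hypothetical lookup table standing in for a model's greedy completions.
FAKE_COMPLETIONS = {
    "Paris is located in the country of": " France",
    "Rome is located in the country of": " Spain",  # deliberately wrong
    "Tokyo is located in the country of": " Japan",
}
batch_sizes_seen = []


def fake_predict_batch(texts: list[str]) -> list[str]:
    """Record the batch size, then return the canned completions."""
    batch_sizes_seen.append(len(texts))
    return [FAKE_COMPLETIONS[t] for t in texts]


prompts = [
    PromptStub("Paris is located in the country of", "France", "Paris"),
    PromptStub("Rome is located in the country of", "Italy", "Rome"),
    PromptStub("Tokyo is located in the country of", "Japan", "Tokyo"),
]
valid = filter_valid_prompts(prompts, fake_predict_batch, batch_size=2)
print([p.subject for p in valid])
```

Batching only changes how many prompts go through the model per call. The set of prompts that survive is the same for any batch size.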
class linear_relational.Prompt(text, answer, subject, subject_name='', object_name='')

A prompt for training LREs and LRCs

answer
object_name
subject
subject_name
text
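LREs read the subject activation at the subject's position in the prompt, so the subject string is expected to occur inside text. The sketch below mirrors the Prompt fields in a hypothetical PromptStub dataclass and uses a hypothetical helper, subject_char_span, to locate the subject before any tokenization. It takes the first occurrence of the subject, which is an arbitrary choice for illustration and not necessarily what the library does.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PromptStub:
    """Hypothetical mirror of linear_relational.Prompt (same field names and defaults)."""
    text: str
    answer: str
    subject: str
    subject_name: str = ""
    object_name: str = ""


def subject_char_span(prompt: PromptStub) -> tuple[int, int]:
    """Return the [start, end) character span of the first occurrence of the subject."""
    start = prompt.text.find(prompt.subject)
    if start == -1:
        raise ValueError(f"subject {prompt.subject!r} not found in {prompt.text!r}")
    return start, start + len(prompt.subject)


prompt = PromptStub(
    text="Paris is located in the country of",
    answer="France",
    subject="Paris",
    object_name="France",
)
span = subject_char_span(prompt)
print(span, prompt.text[span[0] : span[1]])
```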