# Demo Notebook: *On the representations of entities in Auto-Regressive Large Language Models*

This notebook demonstrates the experiments described in our paper *On the representations of entities in Auto-Regressive Large Language Models*.



## Imports and Utils

In [6]:
%cd ~/code/entityrepresentations

/home/morand/code/entityrepresentations


In [1]:
%load_ext autoreload
%autoreload 2

In [7]:
import torch, os, gc, sys
from pathlib import Path
from tqdm import tqdm
import numpy as np
from processResults import loadResults, get_taskVec

# read weights files

repo_path = Path(os.getcwd())
# load results
res_path = repo_path / "results"
if not res_path.exists():
    raise ValueError(
        "No results found, please make sure your current directory is the repository root"
    )

# get jobs
print(f"Loading results from {res_path}")
results = loadResults(res_path)

import transformer_lens as tl
from transformer_lens import HookedTransformer, patching

# our own code
import utils, plotting
from LabelExtractor import eval_model, infer_entities
import circuitsvis as cv
import plotly.io as pio

Loading results from /home/morand/code/entityrepresentations/results


100%|██████████| 66/66 [00:00<00:00, 387.72it/s]
Using the latest cached version of the module from /data/morand/.cache/hf_home/modules/evaluate_modules/metrics/evaluate-metric--chrf/d244bab9383988714085a8dacc4871986d9f025398581c33d6b2ee22836b4069 (last modified on Thu Jun 27 14:17:21 2024) since it couldn't be found locally at evaluate-metric--chrf, or remotely on the Hugging Face Hub.


In [3]:
### Testing the library for plotting

# Plotly needs a different renderer in VSCode/Jupyter notebooks than in Colab
pio.renderers.default = "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")
# Testing that the library works
cv.examples.hello("Fellow AI researcher")

Using renderer: notebook_connected


## Provided Results
This repository includes the results of some of our experiments, stored in the `results` folder. For each experiment we provide:
- the Task Vector (TV) checkpoint `*.pth`
- the parameters used in the experiment, `params.json`
- the TV training history
- the evaluation of the TV on the test set

In this notebook, we have loaded them into a dataframe that lets us quickly find and load task vectors for inference.
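A minimal sketch of such a lookup, assuming (hypothetically) that each run sits in its own sub-folder of `results` with a `params.json` whose keys include `model_name`, `dataset_name`, `layer` and `with_context`, next to a `TaskVec_*.pth` checkpoint. `collect_runs` and `find_task_vectors` are illustrative helpers only; the notebook itself uses `loadResults` and `get_taskVec` from `processResults`, whose internals are not shown here.

```python
import json
import tempfile
from pathlib import Path


def collect_runs(res_path: Path) -> list:
    """Read `params.json` from every run folder under `res_path` (assumed layout)."""
    runs = []
    for params_file in sorted(res_path.glob("*/params.json")):
        params = json.loads(params_file.read_text())
        params["path"] = params_file.parent
        # assumed checkpoint naming, e.g. TaskVec_phi-2_l20_e10.0.pth
        params["checkpoints"] = sorted(params_file.parent.glob("TaskVec_*.pth"))
        runs.append(params)
    return runs


def find_task_vectors(runs, model_name, layer, dataset_name, with_context):
    """Return the checkpoint paths of every run matching the requested setting."""
    matches = []
    for run in runs:
        if (
            run["model_name"] == model_name
            and run["layer"] == layer
            and run["dataset_name"] == dataset_name
            and run["with_context"] == with_context
        ):
            matches.extend(run["checkpoints"])
    return matches


# Toy usage on a temporary folder holding two fake runs (layers 4 and 20)
with tempfile.TemporaryDirectory() as tmp:
    toy_root = Path(tmp)
    for run_id, toy_layer in [("run_a", 4), ("run_b", 20)]:
        run_dir = toy_root / run_id
        run_dir.mkdir()
        toy_params = {"model_name": "phi-2", "dataset_name": "CoNLL2003",
                      "layer": toy_layer, "with_context": False}
        (run_dir / "params.json").write_text(json.dumps(toy_params))
        (run_dir / f"TaskVec_phi-2_l{toy_layer}_e10.0.pth").touch()
    toy_runs = collect_runs(toy_root)
    toy_matches = find_task_vectors(toy_runs, "phi-2", 20, "CoNLL2003", with_context=False)
    print([p.name for p in toy_matches])  # ['TaskVec_phi-2_l20_e10.0.pth']
```

Filtering on all four keys mirrors the way the notebook calls `get_taskVec(results, model_name, layer=..., dataset_name=..., with_context=...)`; returning a list rather than a single path leaves it to the caller to decide what to do when several runs match.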

In [4]:
results.head()

| | path | model_name | dataset_name | layer | with_context | extraction_method | max_ent_length | max_length | epochs | logs_per_epoch | lr | batch_size | run | version | history | Eval | date | inference |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | /home/morand/code/entityrepresentations/result... | phi-2 | CoNLL2003 | 4 | False | in_context | 20 | 200 | 10 | 4 | 0.04 | 30 | 4.0 | 1.0 | [{'epoch': 0.2490118577075099, 'loss': 1.36271... | {'Partial Match': 0.6901094822787159, 'Exact M... | 1739180000.0 | /home/morand/code/entityrepresentations/result... |
| 1 | /home/morand/code/entityrepresentations/result... | phi-2 | CoNLL2003 | 28 | False | in_context | 20 | 200 | 10 | 4 | 0.04 | 30 | 1.0 | 1.0 | [{'epoch': 0.2490118577075099, 'loss': 1.35745... | {'Partial Match': 0.6607904991649657, 'Exact M... | 1739180000.0 | /home/morand/code/entityrepresentations/result... |
| 2 | /home/morand/code/entityrepresentations/result... | phi-2 | CoNLL2003 | 6 | False | in_context | 20 | 200 | 10 | 4 | 0.04 | 30 | 0.0 | 1.0 | [{'epoch': 0.2490118577075099, 'loss': 1.41423... | {'Partial Match': 0.6838003340137316, 'Exact M... | 1739180000.0 | /home/morand/code/entityrepresentations/result... |
| 3 | /home/morand/code/entityrepresentations/result... | phi-2 | CoNLL2003 | 23 | False | in_context | 20 | 200 | 10 | 4 | 0.04 | 30 | 1.0 | 1.0 | [{'epoch': 0.2490118577075099, 'loss': 1.17705... | {'Partial Match': 0.6850992763035814, 'Exact M... | 1739180000.0 | /home/morand/code/entityrepresentations/result... |
| 4 | /home/morand/code/entityrepresentations/result... | phi-2 | CoNLL2003 | 26 | False | in_context | 20 | 200 | 10 | 4 | 0.04 | 30 | 0.0 | 1.0 | [{'epoch': 0.2490118577075099, 'loss': 1.22254... | {'Partial Match': 0.677120059380219, 'Exact Ma... | 1739180000.0 | /home/morand/code/entityrepresentations/result... |


In [5]:
# print unique model names
print(f" Results available for Models : {results['model_name'].unique()}")

 Results available for Models : ['phi-2']


## Params
Thanks to the [`transformer_lens`](https://github.com/TransformerLensOrg/TransformerLens/tree/main) library, we can load and use many different LLMs seamlessly.

In [6]:
# Select the model to load by uncommenting its line (tested status in comments)
# model_name = "meta-llama/Meta-Llama-3-8B"  # status unknown
# model_name = "gpt2-small"  # 117M, ok
# model_name = "pythia-2.8b"  # ok
# model_name = "gpt2-xl"  # 1.5B, ok
# model_name = "mistralai/Mistral-7B-v0.1"  # ok on JZ A100 with 8 CPUs
# model_name = "gpt2-large"  # 774M, ok
# model_name = "gpt2-medium"  # 302M, ok
# model_name = "phi-1_5"  # 1.3B, ok
model_name = "phi-2"  # 2.7B, fails with 12 CPUs, ok on GPU with 24 CPUs

# with_context = True
with_context = False

## Load Model
By default, we try to load the model from the local Hugging Face cache; if this fails, we try to download it from the Hugging Face Hub, which may require an access token (e.g. for gated models such as Llama).
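The cache-first strategy can be sketched generically as below. `load_cache_first` and the `loader(name, local_only=..., token=...)` callable are hypothetical stand-ins: the notebook's actual logic lives in `utils.load_llm`, whose internals are not shown here (for Hugging Face `from_pretrained`, the analogous switch would be `local_files_only=True`).

```python
from typing import Callable, Optional


def load_cache_first(name: str, loader: Callable[..., object], token: Optional[str] = None):
    """Load `name` from the local cache if possible, otherwise download it.

    `loader(name, local_only=..., token=...)` is a hypothetical callable that raises OSError
    when the model cannot be obtained (not cached, or gated without a valid token).
    """
    try:
        return loader(name, local_only=True, token=token)
    except OSError:
        print(f"{name} not found in the local cache, downloading it ...")
        return loader(name, local_only=False, token=token)


# Toy loader: one cached model and one gated model that requires a token
CACHED = {"phi-2"}
GATED = {"meta-llama/Meta-Llama-3-8B"}


def fake_loader(name, local_only, token):
    if local_only and name not in CACHED:
        raise OSError(f"{name} is not cached")
    if name in GATED and token is None:
        raise OSError(f"{name} is gated: an access token is required")
    return f"<model {name}>"


print(load_cache_first("phi-2", fake_loader))       # served from the cache
print(load_cache_first("gpt2-small", fake_loader))  # falls back to a download
```

A gated model requested without a token still raises after the fallback, which is why `hf_token` has to be set before loading such models.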

In [7]:
# your Hugging Face token, if you have one; needed for gated models, e.g. Llama
hf_token = None

# load the model only if it is not already in memory
if "model" not in locals():
    model = utils.load_llm(
        model_name,
        token=hf_token,
        dtype=torch.bfloat16,
    )
    dim = model.QK.shape[-1]  # residual stream dimension (d_model)

model.eval()
model = model.cuda()
print(model)
print(f"Model {model_name} loaded as {model.W_U.dtype}")

Model phi-2 is in cache


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model phi-2 into HookedTransformer
Moving model to device:  cuda
HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0-31): 32 x TransformerBlock(
      (ln1): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_rot_k): HookPoint()
        (hook_rot_q): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): Ho

## Dataset

In [8]:
dataset_name = "CoNLL2003"
train_dataset, test_dataset, val_dataset = utils.load_datasets(
    dataset_name, max_ent_length=200
)

In [9]:
print("train length:", len(train_dataset))
print("dev length:", len(val_dataset))
print("test length:", len(test_dataset))
print("ex sample:")
item = train_dataset[np.random.randint(len(train_dataset))]
for key in item.keys():
    print(f" - {key}: {item[key]}")

train length: 22749
dev length: 5695
test length: 5389
ex sample:
 - entity: Kashmir
 - text: Gupta said there might be an increase in the number of people infiltrating the Kashmir valley to create disturbance in the region.
 - id: 2860


## Test Task Vectors

In [10]:
layer = 20  # Layer at which we want to extract the trained task vector

fileName = get_taskVec(
    results,
    model_name,
    layer=layer,
    dataset_name=dataset_name,
    with_context=with_context,
)

TaskVec = torch.load(fileName, weights_only=True)
print("TaskVec loaded from ", fileName)

found 1 jobs for layer 20 of phi-2 without context on CoNLL2003 with method in_context.
found ['TaskVec_phi-2_l20_e10.0.pth'] 
TaskVec loaded from  /home/morand/code/entityrepresentations/results/3320fdb2ffd15372a5a0b5a011871e86a696604b8335bfe874f2e363e1e91987/TaskVec_phi-2_l20_e10.0.pth


In [11]:
ind = np.random.randint(len(test_dataset))
# ind = 7616
# Pick a prompt by uncommenting one of the alternatives
# prompt = test_dataset[ind]["text"]  # random sentence from the test set
# prompt = "Somunkonwncitea is the capital of France, the capital of France is"
# prompt = "When Albert Einstein and Mandelbrot went to the store, the General Relativity father gave the bottle to"
# prompt = "Netherlands's capital is the city of"
# prompt = "Gaston Julia and Mandelbrot meet, the latter tells"
prompt = "The City of Lights iconic landmark"
print("prompt:", prompt)


words = model.to_str_tokens(prompt)
print(len(words))
cv.tokens.colored_tokens(words, words)

prompt: The City of Lights iconic landmark
7


In [12]:
# compute whole cache
print("computing cache ...")
# get whole hidden states
_, cache = model.run_with_cache(prompt)
repr = cache[tl.utils.get_act_name("resid_post", layer, "")][
    0, :, :
]  # n_tokens x dim (batch dimension removed)
repr = repr.detach().cuda()
TaskVec = TaskVec.detach().to(repr.dtype).cuda()
print(repr.shape, repr.dtype)
data = [
    {
        "id": i,
        "text": prompt,
        "representation": repr[i],
        "tok": words[i],
    }
    for i in range(len(words))
]
infer_entities(model, TaskVec, data, with_context=with_context)

computing cache ...
torch.Size([7, 2560]) torch.bfloat16


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  1.70it/s]


### Inference at given layer

In [13]:
print(f"Prompt : '{data[0]['text']}' \n")
print("Token".center(20), f"| Inferred at layer {layer}")
print("-" * 60)
for it in data:
    print(f"{it['tok'].center(20)} | {it['inferred']}")

# smoother display in notebook
html = cv.tokens.colored_tokens(words, [it["inferred"] for it in data])
html.cdn_src = html.cdn_src.replace("margin: 15px", "margin: 50px")
html

Prompt : 'The City of Lights iconic landmark' 

       Token         | Inferred at layer 20
------------------------------------------------------------
   <|endoftext|>     | C.
        The          | B.
        City         | City
         of          | City of
       Lights        | City of Lights
       iconic        | iconic
      landmark       | landmark


# Entity Lens
Now that we have an LLM and a trained task vector $\theta_\ell$ loaded, we can infer entities from any representation.
Here we showcase the *Entity Lens*, which generates an entity mention for each token at the specified layers, allowing us to visualize which *entity* the model is *thinking* of in its internal representations.
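Structurally, the Entity Lens is a double loop: for every selected layer $\ell$, the hidden state of every token is decoded with that layer's task vector $\theta_\ell$, and the results fill a (layers × tokens) table. The sketch below shows only this bookkeeping; `decode(hidden_state, layer)` is a hypothetical stand-in for `infer_entities`, the hook names follow the TransformerLens convention (`hook_embed` for the embeddings, `blocks.{ℓ}.hook_resid_post` otherwise), and the row labels mimic those of `EntityLens` (`Emb` for layer -1, ℓ+1 otherwise).

```python
from typing import Callable, Dict, List, Sequence


def hook_name(layer: int) -> str:
    """TransformerLens activation name: -1 = token embeddings, else residual stream after block `layer`."""
    return "hook_embed" if layer == -1 else f"blocks.{layer}.hook_resid_post"


def entity_lens_table(
    cache: Dict[str, Sequence],
    layers: List[int],
    tokens: List[str],
    decode: Callable[[object, int], str],
) -> List[List[str]]:
    """Build the (layers x tokens) grid; each row starts with a layer label.

    `cache` maps activation names to per-token hidden states and `decode(hidden_state, layer)`
    is a hypothetical stand-in for decoding with the task vector of `layer`.
    """
    table = []
    for layer in layers:
        states = cache[hook_name(layer)]
        label = "Emb" if layer == -1 else str(layer + 1)  # 1-indexed block labels
        row = [decode(states[t], layer) for t in range(len(tokens))]
        table.append([label] + row)
    return table


# Toy usage: the "hidden states" are plain strings and the fake decoder
# upper-cases them for layers deeper than 10
toy_tokens = ["City", " of", " Lights"]
toy_layers = [-1, 20]
toy_cache = {hook_name(l): toy_tokens for l in toy_layers}


def fake_decode(hidden_state, layer):
    return hidden_state.upper() if layer > 10 else hidden_state


for row in entity_lens_table(toy_cache, toy_layers, toy_tokens, fake_decode):
    print(" | ".join(row))
```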

In [16]:
# define the Entity Lens function
import plotly.graph_objects as go  # needed for output="fancy"


def EntityLens(
    model: HookedTransformer,
    layers: list,
    prompt: str = "The City of Lights iconic landmark",
    with_context: bool = False,
    use_TV_layer: int = None,
    use_filter: bool = False,
    verbose: bool = False,
    compute_logit_lens: bool = False,
    output: str = "fancy",
    plot_size: tuple = (600, 5000),
):
    """
    Compute the Entity Lens for a given model and prompt.

    Args:
        model (HookedTransformer): The model to use.
        layers (list): The layers to decode; -1 stands for the embedding layer.
        prompt (str): The prompt to use.
        with_context (bool): Whether to use context or not.
        use_TV_layer (int, optional): Which Task Vector to use.
            - None (default): use the TV trained at each layer,
            - int: use the TV trained at that layer for all layers.
        use_filter (bool): Whether to use the linear filter or not
            (currently not used by this function).
        verbose (bool): Whether to print verbose output or not.
        compute_logit_lens (bool): Whether to also compute the logit lens or not.
        output (str): The output format. Can be
            - "fancy": table rendered with plotly
            - "markdown": text table printed to stdout
        plot_size (tuple): (height, width) of the plotly table, used when output="fancy".

    Returns:
        list[dict]: one dict per (layer, token), with the inferred mention under "inferred".
    """

    if verbose:
        print(
            f"Computing {'contextual' if with_context else 'uncontextual'} Entity Lens at layers {layers} of {model.cfg.model_name}"
        )

    str_tokens = model.to_str_tokens(prompt)
    tok_ids = range(1, len(str_tokens))  # skip the first token
    str_tokens = [str_tokens[t] for t in tok_ids]
    dtype = model.W_E.dtype


    # get whole cache
    with torch.no_grad():
        _, cache = model.run_with_cache(prompt)
        reprs = []
        for layer in layers:
            if layer == -1:
                hook = tl.utils.get_act_name("embed")
            else:
                hook = tl.utils.get_act_name("resid_post", layer, "")
            reprs.append(cache[hook].detach().cpu())  # 1 x n_tokens x dim
        del cache
        gc.collect()
        torch.cuda.empty_cache()

    def get_TV(layer, dtype=torch.float32):
        fileName = get_taskVec(
            results,
            model_name,
            layer=layer,
            dataset_name=dataset_name,
            with_context=with_context,
            verbose=False,
        )
        if verbose:
            print(f" layer {layer}, TaskVec loaded from {fileName}")
        TaskVec = torch.load(fileName, weights_only=True).to(dtype)
        return TaskVec

    if use_TV_layer is not None:
        TaskVec = get_TV(use_TV_layer, dtype=dtype)
        if verbose:
            print(f"Using TaskVec from layer {use_TV_layer}")
    else:
        if verbose:
            print("will load TaskVec from every layer")

    if verbose:
        print(
            f"infering entities from all tokens in prompt with{'' if with_context else 'out'} context..."
        )
    
    
    data = []
    for i, l in enumerate(layers):
        # load the task vector
        if use_TV_layer is None:
            TaskVec = get_TV(l, dtype=dtype)

        data_l = [
            {"representation": reprs[i][0, t, :], "layer": l, "text": prompt}
            for t in tok_ids
        ]

        # add the id (index of the item within data_l)
        for j, d in enumerate(data_l):
            d["id"] = j

        # infer entities
        infer_entities(
            model,
            TaskVec,
            data_l,
            with_context=with_context,
            max_tokens=8,
            b_size=5,
            verbose=False,
            # return_ppx = True,
            # return_likelihood = True,
        )

        if compute_logit_lens:
            for row in data_l:
                row["logitLens"] = utils.project_on_vocab(model, row["representation"], k = 1)
        
        data += data_l

    # generate the table
    strings = []

    # one row of strings per layer, with the logit lens prediction appended if computed
    for l in layers:
        strings.append(
            [
                d["inferred"]
                + (f" ({d['logitLens'].replace(' ', '_')})" if compute_logit_lens else "")
                for d in data
                if d["layer"] == l
            ]
        )
        # ppx.append([d["perplexity"] for d in data if d["layer"] == l])
        # log_likelihoods.append([d["likelihood"] for d in data if d["layer"] == l])

    row_headers = [f"{l+1}" for l in layers]
    if layers[0] == -1:
        row_headers[0] = "Emb"
    # print as a text table (columns separated by ";")
    if output == "markdown":
        print()
        print(f"Entity Lens for {model_name} on {dataset_name}")
        width = 30
        sep = ";"
        print()
        print("-" * (width + 1) * (len(str_tokens) + 1))
        print(sep.join([tok.center(width) for tok in ["token"] + str_tokens]))
        print("-" * (width + 1) * (len(str_tokens) + 1))
        for i, row in enumerate(strings):
            print(sep.join([it.center(width) for it in row_headers[i : i + 1] + row]))
        print("-" * (width + 1) * (len(str_tokens) + 1))

    elif output == "fancy":
        header = ["l"] + str_tokens  # "l" labels the layer column
        cells = [[row_header] + row for row_header, row in zip(row_headers, strings)]
        columns = list(map(list, zip(*cells)))

        size = 50

        # Create minimal figure
        fig = go.Figure(
            data=[
                go.Table(
                    columnwidth=[10] + [50]*len(str_tokens),
                    header=dict(
                        # fill_color="white",
                        # line_color="black",
                        values=header,
                        align="center",
                        font=dict(size=size+5,family="Times New Roman Black"),
                        height=1.6*size,
                        
                    ),
                    cells=dict(
                        values=columns,
                        fill_color="white",  # no color
                        line_color="gray",
                        align="center",
                        font=dict(size=size,family="Times New Roman"),
                        height= 1.5*size,
                    ),
                )
            ]
        )
        
        h, w = plot_size  # (height, width)
        # Remove padding around figure
        fig.update_layout(
            margin=dict(t=9, b=9, l=9, r=9),
            height=h,  # compact height
            width=w,   # compact width
        )
        fig.show()
    
    return data

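When `compute_logit_lens=True`, `EntityLens` appends to each cell the top token returned by `utils.project_on_vocab(model, representation, k=1)`. That helper is not shown in this notebook; as a hedged point of comparison, here is the standard *logit lens* (final LayerNorm, then the unembedding matrix $W_U$ of shape d_model × d_vocab, then top-k), which we assume is roughly what it computes. The function names, toy matrices and vocabulary below are illustrative only.

```python
import numpy as np


def layer_norm(x, gamma, beta, eps=1e-5):
    """LayerNorm over the last (d_model) axis."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta


def logit_lens(h, gamma, beta, W_U, b_U, vocab, k=1):
    """Top-k vocabulary tokens for a hidden state h of shape [d_model].

    Assumes the standard logit lens: final LayerNorm, then unembedding W_U [d_model, d_vocab] + b_U.
    """
    logits = layer_norm(h, gamma, beta) @ W_U + b_U  # shape [d_vocab]
    top_k = np.argsort(logits)[::-1][:k]  # indices of the k largest logits, best first
    return [vocab[i] for i in top_k]


# Toy example: d_model = d_vocab = 4, identity unembedding, LayerNorm without learned scaling
toy_vocab = ["Paris", "Eiffel", "the", "city"]
d_model = 4
toy_gamma, toy_beta = np.ones(d_model), np.zeros(d_model)
toy_W_U, toy_b_U = np.eye(d_model), np.zeros(d_model)
toy_h = np.array([0.1, 3.0, -1.0, 0.5])
print(logit_lens(toy_h, toy_gamma, toy_beta, toy_W_U, toy_b_U, toy_vocab, k=2))  # ['Eiffel', 'city']
```

With a `HookedTransformer`, the corresponding real objects would be `model.ln_final`, `model.W_U` and `model.b_U`.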


### Inference

In [20]:
import plotly.graph_objects as go


# every_N = 5
# layers = [-1] + list(range(0, model.cfg.n_layers, every_N))  # every N-th layer
layers = [-1, 5, 10, 20, 25, 31]  # -1 is the embedding layer

data = EntityLens(
    model,
    layers=layers,
    prompt=prompt,
    with_context=with_context,
    use_TV_layer=None,  # None: TV of each layer; e.g. 20: layer-20 TV for all layers
    use_filter=False,
    verbose=True,
    compute_logit_lens=True,
    output="markdown",
    plot_size=(740, 3500),  # (600, 5000) for the notebook
)

Computing uncontextual Entity Lens at layers [-1, 5, 10, 20, 25, 31] of phi-2
will load TaskVec from every layer
infering entities from all tokens in prompt without context...
 layer -1, TaskVec loaded from /home/morand/code/entityrepresentations/results/51ba5755b0a50e37cf68a3fcc60185fda38f2196b0fe73304e1510e21bdf4bb1/TaskVec_phi-2_l-1_e9.993406593406593.pth
 layer 5, TaskVec loaded from /home/morand/code/entityrepresentations/results/2222d55bab9b2962526dbe211c46f05b7fdb1e454b76571ccfc2cb8ed63e3bd9/TaskVec_phi-2_l5_e10.0.pth
 layer 10, TaskVec loaded from /home/morand/code/entityrepresentations/results/67284b69fb65b648fa447911c12daa0885c1b7a0f79d1088b0544bf401f65e43/TaskVec_phi-2_l10_e10.0.pth
 layer 20, TaskVec loaded from /home/morand/code/entityrepresentations/results/3320fdb2ffd15372a5a0b5a011871e86a696604b8335bfe874f2e363e1e91987/TaskVec_phi-2_l20_e10.0.pth
 layer 25, TaskVec loaded from /home/morand/code/entityrepresentations/results/8883d6a4d505993f5fcb478a7aca448a97824b6883a7f6

# Interactive visualization

In [18]:
from IPython.display import display
import ipywidgets as widgets

def interactive_plot():
    # relies on the globals model, TaskVec, prompt, words, layer and with_context defined above
    # Callback function to update selected word
    def on_button_click(b):
        layer = dropdown.value
        repr = utils.get_representation(
            model,
            layer=layer,
            tokens=model.to_tokens(prompt),
            token_inds=torch.tensor([b.id]),
            verbose=True,
        )
        # change button color
        b.style.button_color = "lightgreen"
        # change all others to default
        for button in buttons:
            if button.id != b.id:
                button.style.button_color = "white"
        data = [{"id": 0, "representation": repr, "text": prompt}]
        infer_entities(model, TaskVec, data, with_context=with_context)
        selected_word_label.value = (
            f"Token: '{b.description}' \n Inferred: '{data[0]['inferred']}'"
        )
        # print("generation:", data[0]["inferred"])


    # Callback function to update prompt
    def on_text_submit(change):
        global prompt, words  # shared with the rest of the notebook
        nonlocal buttons, button_box  # widgets created inside interactive_plot
        prompt = change["new"]
        words = model.to_str_tokens(prompt)
        print(words)
        print(len(words))

        # Update buttons
        buttons = []
        for ind, token in enumerate(words):
            buttons.append(
                clickableToken(
                    id=ind,
                    description=token,
                    layout=widgets.Layout(width="auto", margin="2px", padding="0 5px"),
                )
            )
        button_box.children = buttons


    # Textbox widget for entering prompt
    textbox = widgets.Text(
        value=prompt,
        description="Enter prompt:",
        disabled=False,
        layout=widgets.Layout(width="100%"),
    )
    textbox.observe(on_text_submit, names="value")

    # Buttons for each word
    buttons = []


    class clickableToken(widgets.Button):
        def __init__(self, id: int, description: str, **kwargs):
            super().__init__(description=description, **kwargs)
            self.description = description
            self.id = id
            self.on_click(on_button_click)


    for ind, token in enumerate(words):
        buttons.append(
            clickableToken(
                id=ind,
                description=token,
                layout=widgets.Layout(width="auto", margin="2px", padding="0 5px"),
            )
        )

    # Box widget with flexible wrapping
    button_box = widgets.Box(
        children=buttons,
        layout=widgets.Layout(display="flex", flex_flow="row wrap", align_items="center"),
    )
    # Dropdown widget for selecting an integer
    dropdown = widgets.Dropdown(
        options=[(f"layer {i}", i) for i in range(len(model.blocks) - 1, -1, -1)],
        description="Select a layer:",
        disabled=False,
    )
    dropdown.value = layer


    def on_checkbox_change(change):
        global with_context
        with_context = change["new"]


    # Create the checkbox widget
    with_context_checkbox = widgets.Checkbox(
        value=with_context, description="Generate with context", disabled=False
    )
    # Link the function to the checkbox change event
    with_context_checkbox.observe(on_checkbox_change, names="value")

    # Label to display the selected word
    selected_word_label = widgets.Label()

    # Create a title using HTML widget
    title = widgets.HTML(value="<h3>Select a Token to generate from:</h3>")

    # Display widgets
    display(textbox)
    display(title)
    display(button_box)
    # display checkbox and dropdown side by side
    display(widgets.HBox([with_context_checkbox, dropdown]))
    display(selected_word_label)


In [19]:
interactive_plot()

Text(value='The City of Lights iconic landmark', description='Enter prompt:', layout=Layout(width='100%'))

HTML(value='<h3>Select a Token to generate from:</h3>')

Box(children=(clickableToken(description='<|endoftext|>', layout=Layout(margin='2px', padding='0 5px', width='…

HBox(children=(Checkbox(value=False, description='Generate with context'), Dropdown(description='Select a laye…

Label(value='')