# Install additional packages if you haven't already done so
# !pip install matplotlib nltk spacy textacy
# !python -m spacy download en_core_web_sm
import itertools
import random
import matplotlib.pyplot as plt
import nltk
import numpy as np
import spacy
import textacy
import torch
from matplotlib.gridspec import GridSpec
from nltk.corpus import framenet as fn
from nltk.tokenize import word_tokenize
from spacy.symbols import nsubj, VERB
from tqdm import tqdm
from transformers import (
GPT2Config,
GPT2LMHeadModel,
GPT2Tokenizer
)
nlp = spacy.load("en_core_web_sm")
%matplotlib inline
plt.style.use('seaborn')
def visualize_single(att_map, tokens, n_layer, n_head):
"""
Attention map for a given layer and head
"""
plt.figure(figsize=(16, 12))
crop_len = len(tokens)
plt.imshow(att_map[n_layer, n_head, :crop_len, :crop_len], cmap='Reds')
plt.xticks(range(crop_len), tokens, rotation=60, fontsize=12)
plt.yticks(range(crop_len), tokens, fontsize=12)
plt.grid(False)
def visualize_all(attn, crop_len, n_layers=12, n_heads=12, title=""):
"""
Full grid of attention maps [12x12]
"""
fig, axes = plt.subplots(n_layers, n_heads, figsize=(15, 12), sharex=True, sharey=True)
for i in range(n_layers):
for j in range(n_heads):
im = axes[i, j].imshow(attn[i, j, :crop_len, :crop_len], cmap='Oranges')
axes[i, j].axis('off')
fig.colorbar(im, ax=axes.ravel().tolist())
fig.suptitle(title, fontsize=20)
def visualize_before_and_after(before, after, title='', cmap="Greens"):
"""
Visualize the difference between the base and fine-tuned models
"""
fig, axes = plt.subplots(1, 2, figsize=(20, 10))
ax1, ax2 = axes[0], axes[1]
vmax = max(np.max(before), np.max(after))
im = ax1.imshow(before, cmap=cmap, vmax=vmax)
ax1.set_title('Base model')
ax1.grid(False)
im = ax2.imshow(after, cmap=cmap, vmax=vmax)
ax2.set_title('Fine-tuned model')
ax2.grid(False)
fig.colorbar(im, ax=axes.ravel().tolist())
fig.suptitle(title, fontsize=20)
# See the spaCy docs for the mapping between fine-grained tags and coarse POS labels
def detect_all_pos(sentence, pos='PRON'):
"""
Detect all tokens with a given POS tag
"""
if pos not in ['PRON', 'VERB', 'NOUN']:
raise ValueError("POS not recognized")
pos2tag = {'PRON': ['PRP', 'PRP$'],
'NOUN': ['NN', 'NNP', 'NNPS', 'NNS'],
'VERB': ['VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ']}
doc = nlp(sentence, disable=['ner', 'parser'])
targets = []
for token in doc:
if token.tag_ in pos2tag[pos]:
targets.append(token.text)
return set(targets)
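# Example usage (a rough sketch; the exact tags depend on the spaCy model):
# detect_all_pos("Kramer slides into the apartment", pos='NOUN')  # would typically return {'Kramer', 'apartment'}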
def detect_all_negations(sentence):
"""
Detect negation words in a sentence via substring matching against a manually curated list (substring matching also catches contractions such as "don't")
"""
negation_words = ['neither', 'nor', 'not', 'never', 'none', "don't", "won't", "didn't",
"hadn't", "haven't", "can't", "isn't", "wasn't", "shouldn't", "couldn't", "nothing", "nowhere"]
targets = [word for word in negation_words if word in sentence]
return set(targets)
def detect_all_dep(sentence, label):
"""
Get subject-object dependencies
"""
doc = nlp(sentence, disable=['ner', 'tagger'])  # keep the parser; 'tagger' is the component name (there is no 'pos' component)
label2dep = {'SUBJ': ['nsubj', "nsubjpass", "csubj", "csubjpass", "agent", "expl"],
"OBJ": ['dobj', 'iobj', "dative", "attr", "oprd"]}
targets = []
for token in doc:
if token.dep_ in label2dep[label]:
targets.append(token.text)
return set(targets)
def get_max_target_weight(attn, target_indices):
"""
Return the maximum attention weight received by any target token (given by index), after averaging each token's incoming attention over all query positions
"""
if not target_indices:
return 0
avg_attn = np.mean(attn, axis=0)
target_weights = avg_attn[target_indices]
max_target_weight = np.max(target_weights)
return max_target_weight
def encode_input_text(text, tokenizer):  # tokenize a string and convert it to a tensor of token ids
tokenized_text = tokenizer.tokenize(text)
ids = torch.LongTensor(tokenizer.convert_tokens_to_ids(tokenized_text))
return tokenizer.build_inputs_with_special_tokens(ids)
def analyze_target_attention(sentence, max_len, model, feature='NOUN', n_layers=12, n_heads=12):
"""
Analyze the attention weights for a sentence and a given syntactic feature (relies on the module-level tokenizer)
"""
weights = np.zeros((n_layers, n_heads))
tokens = tokenizer.tokenize(sentence)
if feature in ["NOUN", "PRON", "VERB"]:
target_feat = detect_all_pos(sentence, feature)
elif feature == "NEG":
target_feat = detect_all_negations(sentence)
elif feature in ["SUBJ", "OBJ"]:
target_feat = detect_all_dep(sentence, feature)
tokens_feat = list(itertools.chain.from_iterable([tokenizer.tokenize(feat) + tokenizer.tokenize(' ' + feat) for feat in target_feat]))  # include the leading-space BPE variant so tokens inside the sentence also match
feat_indices = [i for i, token in enumerate(tokens) if token in tokens_feat]
input_ids = encode_input_text(sentence, tokenizer)
_, _, output = model(input_ids)
output = torch.stack(output).squeeze(1).detach().numpy()  # drop the batch dim -> (n_layers, n_heads, seq, seq)
for l in range(n_layers):
for h in range(n_heads):
weights[l, h] = get_max_target_weight(output[l, h, :, :], feat_indices)
return weights
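# Usage sketch (assumes `tokenizer`, `model_base` and `test_sent` are defined further below):
# weights = analyze_target_attention(test_sent, max_len, model_base, feature='NOUN')
# weights.shape  # (n_layers, n_heads): per-head max attention paid to the target tokens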
def extract_subj_verb(sentence):
"""
Get subject-verb dependencies
"""
doc = nlp(sentence)
subj_verb = []
for possible_subject in doc:
if possible_subject.dep == nsubj and possible_subject.head.pos == VERB:
subj_verb.append((possible_subject.text, possible_subject.head.text))
return subj_verb
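# e.g. extract_subj_verb("Jerry never returned the book") should yield [('Jerry', 'returned')]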
def read_dataset(file_path, tokenizer, block_size):
"""
Read a text file, convert it to token ids with the tokenizer, and split the ids into blocks of block_size (a trailing partial block is dropped)
"""
data = []
with open(file_path, encoding="utf-8") as f:
text = f.read()
tokenized_text = tokenizer.tokenize(text)
ids = torch.LongTensor(tokenizer.convert_tokens_to_ids(tokenized_text))
for i in range(0, len(ids) - block_size + 1, block_size): # Truncate in block of block_size
data.append(tokenizer.build_inputs_with_special_tokens(ids[i : i + block_size]))
return data
def read_dataset_text(file_path, block_size):
"""
Read a text file, tokenize it into words, split the words into blocks of block_size, and join each block back into a string
"""
data = []
with open(file_path, encoding="utf-8") as f:
text = f.read()
tokens = word_tokenize(text)
for i in range(0, len(tokens) - block_size + 1, block_size): # Truncate in block of block_size
data.append(' '.join(tokens[i : i + block_size]))
return data
max_len = 60
path_to_data = './examples/seinfeld/all_scripts.txt'
task = 'Seinfeld'
# Model property: https://huggingface.co/transformers/v2.2.0/pretrained_models.html
n_layers = 12
n_heads = 12
device = 'cpu'
Load the base GPT-2 model and the fine-tuned GPT-2 model. Note that the output_attentions param of the config object is set to True so that the models output attention probs.
# Adapted from examples/run_lm_finetuning.py: load the pretrained model and tokenizer (the pretrained tokenizer is used throughout)
config_class = GPT2Config
model_class = GPT2LMHeadModel
tokenizer_class = GPT2Tokenizer
config = config_class.from_pretrained('gpt2')
# Change default config to make model output attention probs as well
config.output_attentions = True
tokenizer = tokenizer_class.from_pretrained('gpt2')
model_base = model_class.from_pretrained(
'gpt2',
config=config
)
model_base.eval()
model_base.to(device)
# Load the fine-tuned model trained on the Seinfeld dataset
finetuned_model_path = './examples/output'
model_finetuned = model_class.from_pretrained(
finetuned_model_path,
config=config
)
model_finetuned.eval()
model_finetuned.to(device)
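# Sanity check (a minimal sketch): with output_attentions=True the last element of the model output
# is a tuple of n_layers attention tensors, each of shape (batch, n_heads, seq_len, seq_len).
# _, _, attn = model_base(encode_input_text("Hello, Newman.", tokenizer))
# len(attn), attn[0].shape  # expected: 12 layers, torch.Size([1, 12, seq_len, seq_len])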
data = read_dataset(path_to_data, tokenizer, max_len)
data[1]
# For big datasets, subsample to keep the comparison fast
if len(data) > 1000:
data = random.sample(data, 1000)
test_sent = 'JERRY: Went out to dinner the other night. Check came at the end of the meal, as it always does. Never liked the check at the end of the meal system'
rand_example = encode_input_text(test_sent, tokenizer)
_, _, rand_attn_finetuned = model_finetuned(rand_example)
_, _, rand_attn_base = model_base(rand_example)
rand_attn_finetuned = torch.stack(rand_attn_finetuned).squeeze(1).detach().numpy()  # drop the batch dim -> (n_layers, n_heads, seq, seq)
rand_attn_base = torch.stack(rand_attn_base).squeeze(1).detach().numpy()
crop_len = len(rand_example)
visualize_all(rand_attn_finetuned, crop_len, title="Random {} example attention map: fine-tuned model".format(task))
visualize_all(rand_attn_base, crop_len, title="Random {} example attention map: pre-trained model".format(task))
rand_example
np.sum(np.isclose(rand_attn_base, rand_attn_finetuned))
# About half of fine-tuned weights 'close' to base weights
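# The same comparison expressed as a fraction of all attention entries
np.isclose(rand_attn_base, rand_attn_finetuned).mean()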
#rand_example = random.choice(data)
tokens = tokenizer.convert_ids_to_tokens(rand_example.tolist(), skip_special_tokens=False)
print(tokens)
_, _, rand_attn_finetuned = model_finetuned(rand_example)
rand_attn_finetuned = torch.stack(rand_attn_finetuned).squeeze(1).detach().numpy()
visualize_single(rand_attn_finetuned, tokens, 11, 8)
all_similarities = []
for example in tqdm(data):
_, _, rand_attn_finetuned = model_finetuned(example)
_, _, rand_attn_base = model_base(example)
rand_attn_finetuned = torch.stack(rand_attn_finetuned)
rand_attn_base = torch.stack(rand_attn_base)
finetuned_vec = rand_attn_finetuned.squeeze(1).view(n_layers, n_heads, -1)  # drop the batch dim and flatten each head's attention map
base_vec = rand_attn_base.squeeze(1).view(n_layers, n_heads, -1)
sim_map = torch.nn.functional.cosine_similarity(base_vec, finetuned_vec, dim=-1).detach().numpy()
all_similarities.append(sim_map)
all_similarities = np.stack(all_similarities, axis=-1)
avg_sim = np.mean(all_similarities, axis=-1)
plt.figure(figsize=(12, 9))
plt.imshow(avg_sim, cmap='Blues_r', vmin=0, vmax=1)
plt.colorbar()
plt.grid(False)
plt.title('Cosine similarity map for {}. Averaged over examples'.format(task))
plt.xlabel('Head id')
plt.ylabel('Layer id')
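# A short follow-up sketch: list the (layer, head) pairs with the lowest average cosine similarity,
# i.e. the attention heads that changed most during fine-tuning.
flat_order = np.argsort(avg_sim, axis=None)[:5]
most_changed = [tuple(int(i) for i in np.unravel_index(idx, avg_sim.shape)) for idx in flat_order]
print("Most-changed heads (layer, head):", most_changed)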
text_data = read_dataset_text(path_to_data, max_len)
if len(text_data) > 1000:
text_data = random.sample(text_data, 1000)
pos = 'NOUN'
all_weights_finetuned = []
all_weights_base = []
for example in tqdm(text_data):
weights_finetuned = analyze_target_attention(example, max_len, model_finetuned, feature=pos)
weights_base = analyze_target_attention(example, max_len, model_base, feature=pos)
all_weights_finetuned.append(weights_finetuned)
all_weights_base.append(weights_base)
all_weights_finetuned = np.stack(all_weights_finetuned, axis=-1)
all_weights_base = np.stack(all_weights_base, axis=-1)
avg_weights_finetuned = np.mean(all_weights_finetuned, axis=-1)
avg_weights_base = np.mean(all_weights_base, axis=-1)
visualize_before_and_after(avg_weights_base, avg_weights_finetuned, 'Attention to NOUN tags')
pos = 'VERB'
all_weights_finetuned = []
all_weights_base = []
for example in tqdm(text_data):
weights_finetuned = analyze_target_attention(example, max_len, model_finetuned, feature=pos)
weights_base = analyze_target_attention(example, max_len, model_base, feature=pos)
all_weights_finetuned.append(weights_finetuned)
all_weights_base.append(weights_base)
all_weights_finetuned = np.stack(all_weights_finetuned, axis=-1)
all_weights_base = np.stack(all_weights_base, axis=-1)
avg_weights_finetuned = np.mean(all_weights_finetuned, axis=-1)
avg_weights_base = np.mean(all_weights_base, axis=-1)
visualize_before_and_after(avg_weights_base, avg_weights_finetuned, 'Attention to VERB tags')
pos = 'PRON'
all_weights_finetuned = []
all_weights_base = []
for example in tqdm(text_data):
weights_finetuned = analyze_target_attention(example, max_len, model_finetuned, feature=pos)
weights_base = analyze_target_attention(example, max_len, model_base, feature=pos)
all_weights_finetuned.append(weights_finetuned)
all_weights_base.append(weights_base)
all_weights_finetuned = np.stack(all_weights_finetuned, axis=-1)
all_weights_base = np.stack(all_weights_base, axis=-1)
avg_weights_finetuned = np.mean(all_weights_finetuned, axis=-1)
avg_weights_base = np.mean(all_weights_base, axis=-1)
visualize_before_and_after(avg_weights_base, avg_weights_finetuned, 'Attention to PRON tags')
feature = 'NEG'
all_weights_finetuned = []
all_weights_base = []
for example in tqdm(text_data):
weights_finetuned = analyze_target_attention(example, max_len, model_finetuned, feature=feature)
weights_base = analyze_target_attention(example, max_len, model_base, feature=feature)
all_weights_finetuned.append(weights_finetuned)
all_weights_base.append(weights_base)
all_weights_finetuned = np.stack(all_weights_finetuned, axis=-1)
all_weights_base = np.stack(all_weights_base, axis=-1)
avg_weights_finetuned = np.mean(all_weights_finetuned, axis=-1)
avg_weights_base = np.mean(all_weights_base, axis=-1)
visualize_before_and_after(avg_weights_base, avg_weights_finetuned, 'Attention to NEG tokens')
feature = 'SUBJ'
all_weights_finetuned = []
all_weights_base = []
for example in tqdm(text_data):
weights_finetuned = analyze_target_attention(example, max_len, model_finetuned, feature=feature)
weights_base = analyze_target_attention(example, max_len, model_base, feature=feature)
all_weights_finetuned.append(weights_finetuned)
all_weights_base.append(weights_base)
all_weights_finetuned = np.stack(all_weights_finetuned, axis=-1)
all_weights_base = np.stack(all_weights_base, axis=-1)
avg_weights_finetuned = np.mean(all_weights_finetuned, axis=-1)
avg_weights_base = np.mean(all_weights_base, axis=-1)
visualize_before_and_after(avg_weights_base, avg_weights_finetuned, 'Attention to SUBJ tokens')
feature = 'OBJ'
all_weights_finetuned = []
all_weights_base = []
for example in tqdm(text_data):
weights_finetuned = analyze_target_attention(example, max_len, model_finetuned, feature=feature)
weights_base = analyze_target_attention(example, max_len, model_base, feature=feature)
all_weights_finetuned.append(weights_finetuned)
all_weights_base.append(weights_base)
all_weights_finetuned = np.stack(all_weights_finetuned, axis=-1)
all_weights_base = np.stack(all_weights_base, axis=-1)
avg_weights_finetuned = np.mean(all_weights_finetuned, axis=-1)
avg_weights_base = np.mean(all_weights_base, axis=-1)
visualize_before_and_after(avg_weights_base, avg_weights_finetuned, 'Attention to OBJ tokens')