import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments
from datasets import load_dataset
# Load the tokenizer and model
model_name = "gpt2-large"
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# GPT-2 has no padding token by default, so add one for fixed-length batch padding
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model = GPT2LMHeadModel.from_pretrained(model_name)
model.resize_token_embeddings(len(tokenizer)) # Adjust the model's embedding size
# Check if MPS (Metal Performance Shaders) is available
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model.to(device)
# Load the haiku dataset
dataset = load_dataset("davanstrien/haiku_kto")
# Inspect dataset keys
print(dataset['train'][0])
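# Note: each record is assumed to hold a 'messages' list of {'role', 'content'} dicts,
# with the haiku text in the assistant message (this is what the mapping function below relies on).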
# Split the training data into train and validation sets (90% train, 10% validation)
train_val_split = dataset['train'].train_test_split(test_size=0.1)
train_data = train_val_split['train']
val_data = train_val_split['test']
# Extract haiku texts and tokenize them
def extract_and_tokenize_function(batch):
    haikus = []
    for example in batch['messages']:
        # The haiku is the content of the assistant message, if present
        haiku = next((message['content'] for message in example if message['role'] == 'assistant'), None)
        if haiku:
            haikus.append(haiku)
    tokenized = tokenizer(haikus, truncation=True, padding='max_length', max_length=50)
    input_ids = torch.tensor(tokenized['input_ids'])
    attention_mask = torch.tensor(tokenized['attention_mask'])
    # Use the input ids as labels, but set pad positions to -100 so they are ignored by the loss
    labels = input_ids.clone()
    labels[labels == tokenizer.pad_token_id] = -100
    return {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': labels}
# Apply tokenization function with batching
train_data = train_data.map(extract_and_tokenize_function, batched=True, remove_columns=train_data.column_names)
val_data = val_data.map(extract_and_tokenize_function, batched=True, remove_columns=val_data.column_names)
train_data.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
val_data.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
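# Optional sanity check (assumes the tokenization above produced fixed-length rows):
# each example should now be a 50-token tensor with matching attention mask and labels.
print(train_data[0]['input_ids'].shape, train_data[0]['labels'][:10])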
# Training arguments
training_args = TrainingArguments(
    output_dir="./results",
    overwrite_output_dir=True,
    num_train_epochs=3,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    save_strategy="epoch",
    evaluation_strategy="epoch",  # must match save_strategy so load_best_model_at_end can pick the best checkpoint
    save_total_limit=2,
    load_best_model_at_end=True,
)
# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=val_data,
)
# Fine-tune the model
trainer.train()
# Save the fine-tuned model
model.save_pretrained("./fine-tuned-haiku-model")
tokenizer.save_pretrained("./fine-tuned-haiku-model")
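# A minimal generation sketch (assumes the fine-tuned weights saved above in
# ./fine-tuned-haiku-model); the prompt text and sampling settings are illustrative, not tuned.
generator_model = GPT2LMHeadModel.from_pretrained("./fine-tuned-haiku-model").to(device)
generator_tokenizer = GPT2Tokenizer.from_pretrained("./fine-tuned-haiku-model")
generator_model.eval()
prompt = "autumn moonlight"
inputs = generator_tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
    output_ids = generator_model.generate(
        **inputs,
        max_new_tokens=40,
        do_sample=True,
        top_p=0.95,
        temperature=0.9,
        pad_token_id=generator_tokenizer.pad_token_id,
    )
print(generator_tokenizer.decode(output_ids[0], skip_special_tokens=True))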