Created by @davidmezzetti on November 7, 2025 14:05

from datasets import load_dataset
from sklearn.metrics import accuracy_score
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from txtai.pipeline import HFTrainer


def metrics(pred):
    # Convert logits to predicted class ids, then compare them against the gold labels
    labels, preds = pred.label_ids, pred.predictions.argmax(-1)

    # Calculate accuracy
    return {"accuracy": accuracy_score(labels, preds)}


# IMDB movie review sentiment dataset (binary labels)
train = load_dataset("stanfordnlp/imdb", split="train")
test = load_dataset("stanfordnlp/imdb", split="test")

trainer = HFTrainer()

# trust_remote_code=True allows loading model code hosted in the model repository
path = "neuml/bert-hash-nano"
model = AutoModelForSequenceClassification.from_pretrained(path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)

# Fine-tune on the train split, use the test split for evaluation, pass output_dir through as a training argument
trainer((model, tokenizer), train, test, metrics=metrics, output_dir="imdb")
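To see what the `metrics` function receives during evaluation, here is a minimal self-contained sketch. It assumes a hypothetical evaluation output of 4 examples with 2-class logits. `types.SimpleNamespace` stands in for the `EvalPrediction` object (with `predictions` and `label_ids` attributes) that the Hugging Face trainer passes in. NumPy replaces `accuracy_score` here only to keep the example dependency-light; both compute the fraction of matching labels.

```python
from types import SimpleNamespace

import numpy as np


def metrics(pred):
    # Same logic as the gist's metrics function; NumPy stands in for sklearn's accuracy_score
    labels, preds = pred.label_ids, pred.predictions.argmax(-1)
    return {"accuracy": float((labels == preds).mean())}


# Hypothetical logits for 4 reviews, one column per class (IMDB: 0 = negative, 1 = positive)
logits = np.array([[2.0, -1.0], [0.1, 0.9], [-0.5, 0.3], [1.2, 0.4]])
labels = np.array([0, 1, 0, 0])

# argmax(-1) picks the highest-scoring class per row: [0, 1, 1, 0]
pred = SimpleNamespace(predictions=logits, label_ids=labels)
print(metrics(pred))  # {'accuracy': 0.75}
```

Three of the four predicted classes match the labels, so accuracy is 0.75. The same reduction runs over the full IMDB test split when the trainer evaluates.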