-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmwe.py
More file actions
21 lines (21 loc) · 796 Bytes
/
mwe.py
File metadata and controls
21 lines (21 loc) · 796 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import numpy as np
from scipy.special import softmax
import csv
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
import json
import glob
import os
import pickle
import shutil
import time
# Let TensorFlow grow GPU memory on demand instead of reserving the whole
# device at start-up, so other processes can share the GPU.
gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

# Which cardiffnlp Twitter-RoBERTa classifier to load; the hub id is built
# from this task name.
task = 'offensive'
MODEL = f"cardiffnlp/twitter-roberta-base-{task}"

tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = TFAutoModelForSequenceClassification.from_pretrained(MODEL)

# padding=True pads every sequence to the longest one in the batch, so the
# short "Hello there" receives pad tokens. The returned BatchEncoding carries
# both input_ids and attention_mask.
tokenized = tokenizer(
    ["Hello there", "Howdy hhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhhh"],
    padding=True,
    return_tensors='tf',
)

# Feed the full encoding, not just input_ids. Without attention_mask the model
# attends to the pad tokens, and the padded example's logits no longer match
# what it would get on its own.
# use_multiprocessing is dropped: Keras only uses it for generator/Sequence
# inputs, so it did nothing here (and newer Keras predict() rejects it).
res = model.predict(dict(tokenized), batch_size=100)
print(res)