diff --git a/models/causal_bert_pytorch/CausalBert.py b/models/causal_bert_pytorch/CausalBert.py index d265ce1..71cea22 100644 --- a/models/causal_bert_pytorch/CausalBert.py +++ b/models/causal_bert_pytorch/CausalBert.py @@ -19,7 +19,8 @@ from torch.utils.data import Dataset, TensorDataset, DataLoader, RandomSampler, SequentialSampler from transformers import BertTokenizer -from transformers import BertModel, BertPreTrainedModel, AdamW, BertConfig +from transformers import BertModel, BertPreTrainedModel, BertConfig +from torch.optim import AdamW from transformers import get_linear_schedule_with_warmup from transformers import DistilBertTokenizer @@ -305,15 +306,15 @@ def collate_CandT(data): return dataloader -if __name__ == '__main__': - import pandas as pd +# if __name__ == '__main__': +# import pandas as pd - df = pd.read_csv('testdata.csv') - cb = CausalBertWrapper(batch_size=2, - g_weight=0.1, Q_weight=0.1, mlm_weight=1) - print(df.T) - cb.train(df['text'], df['C'], df['T'], df['Y'], epochs=1) - print(cb.ATE(df['C'], df.text, platt_scaling=True)) +# df = pd.read_csv('testdata.csv') +# cb = CausalBertWrapper(batch_size=2, +# g_weight=0.1, Q_weight=0.1, mlm_weight=1) +# print(df.T) +# cb.train(df['text'], df['C'], df['T'], df['Y'], epochs=1) +# print(cb.ATE(df['C'], df.text, platt_scaling=True))