Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions models/causal_bert_pytorch/CausalBert.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from torch.utils.data import Dataset, TensorDataset, DataLoader, RandomSampler, SequentialSampler

from transformers import BertTokenizer
from transformers import BertModel, BertPreTrainedModel, AdamW, BertConfig
from transformers import BertModel, BertPreTrainedModel, BertConfig
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup

from transformers import DistilBertTokenizer
Expand Down Expand Up @@ -305,15 +306,15 @@ def collate_CandT(data):
return dataloader


if __name__ == '__main__':
import pandas as pd
# if __name__ == '__main__':
# import pandas as pd

df = pd.read_csv('testdata.csv')
cb = CausalBertWrapper(batch_size=2,
g_weight=0.1, Q_weight=0.1, mlm_weight=1)
print(df.T)
cb.train(df['text'], df['C'], df['T'], df['Y'], epochs=1)
print(cb.ATE(df['C'], df.text, platt_scaling=True))
# df = pd.read_csv('testdata.csv')
# cb = CausalBertWrapper(batch_size=2,
# g_weight=0.1, Q_weight=0.1, mlm_weight=1)
# print(df.T)
# cb.train(df['text'], df['C'], df['T'], df['Y'], epochs=1)
# print(cb.ATE(df['C'], df.text, platt_scaling=True))



Expand Down