-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathencoding.py
More file actions
30 lines (24 loc) · 809 Bytes
/
encoding.py
File metadata and controls
30 lines (24 loc) · 809 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
# coding: utf-8
import numpy as np
def encoding(data, encode_type, time_length, vocab_size):
    """Encode a batch of token-id sequences according to ``encode_type``.

    Args:
        data: Sequence of sentences, each a sequence of integer token ids.
        encode_type: ``'1hot'`` builds a one-hot tensor via
            ``onehot_encoding``. ``'embedding'`` returns ``data`` unchanged,
            presumably so a downstream embedding layer can consume the raw
            ids (confirm with callers).
        time_length: Size of the time axis of the one-hot tensor. Only used
            for ``'1hot'``.
        vocab_size: Size of the vocabulary axis of the one-hot tensor. Only
            used for ``'1hot'``.

    Returns:
        For ``'1hot'``, the array produced by
        ``onehot_encoding(data, time_length, vocab_size)``. For
        ``'embedding'``, the ``data`` object itself (not a copy).

    Raises:
        ValueError: If ``encode_type`` is not ``'1hot'`` or ``'embedding'``.
    """
    if encode_type == '1hot':
        return onehot_encoding(data, time_length, vocab_size)
    if encode_type == 'embedding':
        return data
    # Fail loudly here. Falling through would return None and surface as a
    # confusing error far away from the misspelled option.
    raise ValueError(
        "unknown encode_type {!r}; expected '1hot' or 'embedding'".format(
            encode_type))
def onehot_encoding(data, time_length, vocab_size):
    """One-hot encode a batch of token-id sequences into a boolean tensor.

    Args:
        data: Sequence of sentences. Each sentence is an iterable of integer
            token ids with at most ``time_length`` elements.
        time_length: Size of the time (position) axis.
        vocab_size: Size of the vocabulary axis.

    Returns:
        A NumPy array of dtype ``bool`` and shape
        ``(len(data), time_length, vocab_size)``. ``X[i, j, k]`` is True when
        token ``j`` of sentence ``i`` has id ``k``. Positions past the end of
        a shorter sentence stay all-False.

    Raises:
        IndexError: If a sentence is longer than ``time_length`` or a token
            id is ``>= vocab_size``. Negative ids are not rejected; they
            index from the end of the vocabulary axis (NumPy semantics).
    """
    # ``np.bool`` was only an alias of the builtin ``bool``. It was
    # deprecated in NumPy 1.20 and removed in 1.24, where it raises
    # AttributeError. The builtin yields the identical dtype on every
    # NumPy version.
    X = np.zeros((len(data), time_length, vocab_size), dtype=bool)
    for i, sent in enumerate(data):
        for j, k in enumerate(sent):
            X[i, j, k] = True
    return X
def history_build(indata, pad_X):
    """Collect a flattened context history for each entry of ``startid``.

    For position ``i`` paired with ``pos = indata.dataset['startid'][i]``:

    * when ``i == pos`` the history is an empty list;
    * otherwise it is the elements of ``pad_X[pos]`` followed by the
      elements of ``pad_X[pos + 1]`` through ``pad_X[i - 1]``, in order.
      ``pad_X[pos]`` is included even if ``pos > i``; in that case it is
      the only row.

    Args:
        indata: Object whose ``dataset['startid']`` is an iterable of integer
            indices into ``pad_X``. Presumably each is the first row of the
            segment/dialogue that row ``i`` belongs to; confirm with caller.
        pad_X: Indexable collection whose rows are iterable. The name
            suggests padded token-id sequences; not verified here.

    Returns:
        A list containing one freshly built list per ``startid`` entry.
    """
    histories = []
    for idx, start in enumerate(indata.dataset['startid']):
        if idx == start:
            histories.append([])
            continue
        # The start row is always taken, then every row strictly between
        # start and idx (none when start + 1 >= idx).
        row_ids = [start] + list(range(start + 1, idx))
        histories.append([tok for r in row_ids for tok in pad_X[r]])
    return histories