-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcomputational_graph_from_strach.py
More file actions
110 lines (78 loc) · 2.49 KB
/
computational_graph_from_strach.py
File metadata and controls
110 lines (78 loc) · 2.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import numpy as np
class Operation:
    """Base class for graph nodes that compute a value from other nodes.

    Creating an operation registers it with the current default graph (the
    one installed by ``Graph.set_as_default``) and records it as a consumer
    in the ``output_nodes`` list of each of its input nodes.

    Subclasses must override :meth:`compute`.
    """

    def __init__(self, input_nodes):
        """Wire this operation into the default graph.

        Args:
            input_nodes: Nodes whose outputs feed this operation, in the
                order they will be passed to :meth:`compute`. Any iterable
                is accepted; it is copied into a list so it can be walked
                more than once (here and later during graph traversal).
        """
        self.input_nodes = list(input_nodes)
        # Filled in by later operations that take this one as an input.
        self.output_nodes = []
        _default_graph.operations.append(self)
        for node in self.input_nodes:
            node.output_nodes.append(self)

    def compute(self, *args):
        """Return this operation's output given the outputs of its inputs.

        Raises:
            NotImplementedError: Always, on the base class; a silent
                ``None`` result would hide a missing override.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement compute()"
        )
class Add(Operation):
    """Graph operation producing the sum ``x + y`` of its two inputs."""

    def __init__(self, x, y):
        super().__init__(input_nodes=[x, y])

    def compute(self, x_var, y_var):
        """Record the operands on ``self.inputs`` and return their sum."""
        self.inputs = [x_var, y_var]
        total = x_var + y_var
        return total
class Multiply(Operation):
    """Graph operation producing the product ``x * y`` of its two inputs."""

    def __init__(self, x, y):
        super().__init__(input_nodes=[x, y])

    def compute(self, x_var, y_var):
        """Record the operands on ``self.inputs`` and return ``x_var * y_var``."""
        self.inputs = [x_var, y_var]
        product = x_var * y_var
        return product
class MatrixMultiply(Operation):
    """Graph operation producing the dot product of its two inputs."""

    def __init__(self, x, y):
        super().__init__([x, y])

    def compute(self, x_var, y_var):
        """Record the operands on ``self.inputs`` and return their dot product.

        Operands that provide their own ``.dot`` method (NumPy arrays and
        similar array types) use it, exactly as before. Other array-likes
        (tuples, Python scalars) fall back to ``np.dot`` instead of failing
        with ``AttributeError``; ``Session.run`` only converts *list*
        outputs to arrays, so such values can reach this method.
        """
        self.inputs = [x_var, y_var]
        if hasattr(x_var, "dot"):
            return x_var.dot(y_var)
        return np.dot(x_var, y_var)
class PlaceHolder:
    """Graph input whose value is supplied at run time via ``feed_dict``.

    Creating one registers it in the ``placeholders`` list of the current
    default graph.
    """

    def __init__(self):
        # Operations that consume this placeholder's value.
        self.output_nodes = []
        registry = _default_graph.placeholders
        registry.append(self)
class Variable:
    """Graph node holding a stored value that is read during evaluation.

    Creating one registers it in the ``variables`` list of the current
    default graph.
    """

    def __init__(self, initial_value):
        # Stored as given; Session.run reads it back as this node's output.
        self.value = initial_value
        # Operations that consume this variable's value.
        self.output_nodes = []
        registry = _default_graph.variables
        registry.append(self)
class Graph:
    """Registry of the operations, placeholders and variables of one graph.

    Nodes append themselves to these lists when they are created while this
    graph is the module-level default.
    """

    def __init__(self):
        self.operations, self.placeholders, self.variables = [], [], []

    def set_as_default(self):
        """Install this graph as the module-level ``_default_graph``."""
        global _default_graph
        _default_graph = self
def traverse_postorder(operation):
    """Return the nodes ``operation`` depends on, in evaluation order.

    The result is a post-order (dependencies-first) listing of every node
    reachable from ``operation`` through ``input_nodes``, ending with
    ``operation`` itself. Inputs are explored left to right.

    Each node appears exactly once, even when several downstream operations
    share it. Without this, a shared node would be re-listed, and so
    re-evaluated, once per path, which is exponential for diamond-shaped
    graphs.

    Args:
        operation: The node to evaluate. Only ``Operation`` instances are
            expanded; any other node (``Variable``, ``PlaceHolder``) is a leaf.

    Returns:
        list: Nodes ordered so that every node follows all of its inputs.
    """
    nodes_postorder = []
    # Track identity, not equality, so sharing is decided by object identity.
    visited = set()
    # Explicit stack instead of recursion so long chains of operations do not
    # hit the interpreter's recursion limit. Each entry is (node, expanded):
    # expanded is True once the node's inputs have been pushed above it.
    stack = [(operation, False)]
    while stack:
        node, expanded = stack.pop()
        if expanded:
            nodes_postorder.append(node)
            continue
        if id(node) in visited:
            continue
        visited.add(id(node))
        stack.append((node, True))
        if isinstance(node, Operation):
            # Push in reverse so the first input is popped (visited) first.
            for input_node in reversed(node.input_nodes):
                if id(input_node) not in visited:
                    stack.append((input_node, False))
    return nodes_postorder
class Session:
    """Evaluates nodes of a computational graph."""

    def run(self, operation, feed_dict=None):
        """Evaluate ``operation`` and return its output.

        Every node ``operation`` depends on is evaluated in dependency order.
        Each node's result is stored on the node as ``node.output``, and
        operations also get ``node.inputs``.

        Args:
            operation: Node whose value is requested.
            feed_dict: Mapping from ``PlaceHolder`` nodes to the values to use
                for them. May be omitted when the graph has no placeholders.

        Returns:
            The output of ``operation``. A plain ``list`` output from any
            node is converted to a NumPy array.

        Raises:
            KeyError: If a placeholder the result depends on has no entry in
                ``feed_dict``.
        """
        if feed_dict is None:
            feed_dict = {}
        for node in traverse_postorder(operation):
            if isinstance(node, PlaceHolder):
                try:
                    node.output = feed_dict[node]
                except KeyError:
                    raise KeyError(
                        f"feed_dict has no value for placeholder {node!r}"
                    ) from None
            elif isinstance(node, Variable):
                node.output = node.value
            else:
                node.inputs = [input_node.output for input_node in node.input_nodes]
                node.output = node.compute(*node.inputs)
            # Convert after every node (variables included), so operations
            # such as Multiply see arrays rather than Python lists, where
            # `*` and `+` would mean repetition and concatenation.
            if isinstance(node.output, list):
                node.output = np.array(node.output)
        return operation.output
# Demo: build z = A * x + b on a fresh default graph and evaluate it with x = 10.
# NOTE(review): these statements run at import time (they install the default
# graph and print); consider an `if __name__ == "__main__":` guard.
g = Graph()
g.set_as_default()  # nodes created below register themselves with g
A = Variable([[10, 20], [30, 40]])
b = Variable([1, 1])
x = PlaceHolder()
y = Multiply(A, x)  # Multiply uses `*`: element-wise scaling by x, not a matrix product
z = Add(y, b)  # with NumPy arrays, b is broadcast across the rows of y
sess = Session()
result = sess.run(operation=z, feed_dict={x: 10})
# Expected [[101 201] [301 401]], assuming Session.run converts the list held
# by A to an ndarray before Multiply runs (otherwise list * 10 repeats the list).
print(result)