-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathchapter_seven.py
More file actions
54 lines (44 loc) · 1.8 KB
/
chapter_seven.py
File metadata and controls
54 lines (44 loc) · 1.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
'''
- In the previous chapter we saw how to create heatmaps.
- But a heatmap is a 2D figure, so how can we show results for a third parameter?
- plot_heatmaps() draws one 2D heatmap for every pair of optimized parameters, aggregating the values over the remaining parameter(s).
- The plots are saved to an HTML file.
- As the example code below shows, rsiWindow is now a range rather than a single number.
'''
import pandas_ta as ta
import pandas as pd
import os
import matplotlib.pyplot as plt
from backtesting import Backtest
from backtesting import Strategy
from backtesting.lib import crossover, plot_heatmaps
from backtesting.test import GOOG
class RsiOscillator(Strategy):
    """Long-only RSI strategy.

    Opens a long position when the RSI crosses below ``lowerBound`` and
    closes it when the RSI crosses above ``upperBound``. The three class
    attributes are the parameters swept by ``Backtest.optimize`` below.
    """

    # RSI level whose upward crossing closes the position.
    upperBound = 70
    # RSI level whose downward crossing triggers a buy.
    lowerBound = 30
    # Look-back length passed to ta.rsi.
    rsiWindow = 14

    def init(self):
        # Wrap the pandas_ta RSI with self.I so the framework treats it as an indicator.
        close_prices = pd.Series(self.data.Close)
        self.rsi = self.I(ta.rsi, close_prices, self.rsiWindow)

    def next(self):
        # RSI just moved up through the upper bound: exit and do nothing else this bar.
        if crossover(self.rsi, self.upperBound):
            self.position.close()
            return
        # lowerBound "crossing over" the RSI means the RSI just dropped below it: enter long.
        if crossover(self.lowerBound, self.rsi):
            self.buy()
# Backtest the strategy on the bundled GOOG sample data with 10,000 starting cash.
bt = Backtest(GOOG, RsiOscillator, cash=10000)

# Candidate values for every parameter. rsiWindow is a range too, so the search
# (and the returned heatmap) now spans three parameters instead of two.
param_grid = dict(
    upperBound=range(55, 85, 5),
    lowerBound=range(10, 45, 5),
    rsiWindow=range(10, 45, 5),
)


def _bounds_are_ordered(param):
    """Accept only combinations whose upper RSI bound is above the lower one."""
    # param.rsiWindow is also available here if it should be constrained as well.
    return param.upperBound > param.lowerBound


stats, heatmap = bt.optimize(
    **param_grid,
    maximize='Sharpe Ratio',
    constraint=_bounds_are_ordered,
    # Also return the optimized metric for every evaluated combination (used for the heatmaps).
    return_heatmap=True,
    # Uncomment to run a randomized grid search over at most 100 combinations
    # instead of the full grid (faster; results depend on the random sample).
    # max_tries=100,
)
# Save the heatmaps into plots/: one 2-D heatmap per pair of optimized parameters.
# agg='mean' averages the optimized metric over the parameter(s) not shown in each pair.
os.makedirs('plots', exist_ok=True)  # creates the directory only if it is missing; no check-then-create race
fileName = "plot.html"
plot_heatmaps(heatmap, agg='mean', filename=os.path.join('plots', fileName))