Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 159 additions & 16 deletions finplot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
from functools import partial, partialmethod
from finplot.live import Live
from math import ceil, floor, fmod
from wsgiref.headers import tspecials
import numpy as np
import os.path
import pandas as pd
import pyqtgraph as pg
from pyqtgraph import QtCore, QtGui
from pyqtgraph.dockarea.DockArea import DockArea



Expand Down Expand Up @@ -55,6 +57,10 @@
cross_hair_color = '#0007'
draw_line_color = '#000'
draw_done_color = '#555'
arrow_bull_color = '#20FF20'
arrow_bull_outline_color = '#222222'
arrow_bear_color = '#f7a9a7'
arrow_bear_outline_color = '#222222'
significant_decimals = 8
significant_eps = 1e-8
max_decimals = 10
Expand Down Expand Up @@ -780,6 +786,31 @@ def addScaleHandle(self, *args, **kwargs):
if self.resizable:
super().addScaleHandle(*args, **kwargs)

class FinArrow(pg.ArrowItem):
def __init__(self, ax, angle, brushColor='', penColor='', brushWidth=1, penWidth=1, *args, **kwargs):

kwargs['angle']=angle
kwargs['tipAngle']=30
kwargs['baseAngle']=20

kwargs['headLen']=8
kwargs['headWidth']=8
kwargs['tailLen']=6
kwargs['tailWidth']=5

if brushColor=='':
brushColor='#3030ff';
if penColor=='':
penColor='#000';

kwargs['pen']={'color': penColor, 'width': penWidth}
brush = pg.mkBrush(brushColor)
brush.width = brushWidth
kwargs['brush'] = brush

self.ax = ax

super().__init__(*args, **kwargs)

class FinViewBox(pg.ViewBox):
def __init__(self, win, init_steps=300, yscale=YScale('linear', 1), v_zoom_scale=1, *args, **kwargs):
Expand Down Expand Up @@ -1359,8 +1390,6 @@ def generate_picture(self, boundingRect):
p.setPen(pg.mkPen(poc_color))
p.drawLine(QtCore.QPointF(t, y), QtCore.QPointF(t+f*self.draw_poc, y))



class ScatterLabelItem(FinPlotItem):
def __init__(self, ax, datasrc, color, anchor):
self.color = color
Expand Down Expand Up @@ -1534,6 +1563,7 @@ def candlestick_ochl(datasrc, draw_body=True, draw_shadow=True, candle_width=0.6
_update_significants(ax, datasrc, force=True)
item.update_data = partial(_update_data, None, None, item)
item.update_gfx = partial(_update_gfx, item)
item.setZValue(40) # Skinok : candle should always be on top of any indicators
ax.addItem(item)
return item

Expand Down Expand Up @@ -1738,6 +1768,58 @@ def live(plots=1):
return [Live() for _ in range(plots)]


def add_order(datetime, price, isLong, ax=None):

# Open trade arrow
if isLong:
brushColor = arrow_bull_color
penColor = arrow_bull_outline_color
angle = 90

else:
brushColor = arrow_bear_color
penColor = arrow_bear_outline_color
angle = -90

add_arrow((datetime,price), angle, brushColor, penColor, ax=ax)

return

def add_trade(posOpen, posClose, isLong, isWinningTrade, ax=None):

# Open trade arrow
if isLong:
brushColor = arrow_bull_color
penColor = arrow_bull_outline_color
add_arrow(posOpen, 90, brushColor, penColor, ax=ax)

brushColor = arrow_bear_color
penColor = arrow_bear_outline_color
add_arrow(posClose, -90, brushColor, penColor, ax=ax)

else:
brushColor = arrow_bear_color
penColor = arrow_bear_outline_color
add_arrow(posOpen, -90, brushColor, penColor, ax=ax)

# Close trade arrow
brushColor = arrow_bull_color
penColor = arrow_bull_outline_color
add_arrow(posClose, 90, brushColor, penColor, ax=ax)

# Add dashed line
if isWinningTrade:
add_line(posOpen, posClose, "#30FF30", 2, style="--" )
else:
add_line(posOpen, posClose, "#FF3030", 2, style="..")

# Add label
mid = (posClose[0]-posOpen[0],posClose[1]-posOpen[1])
textPos = (posOpen[0] + mid[0], posOpen[1] + mid[1])
add_text(textPos,"+500")

return

def add_legend(text, ax=None):
ax = _create_plot(ax=ax, maximize=False)
_create_legend(ax)
Expand Down Expand Up @@ -1788,6 +1870,18 @@ def add_band(y0, y1, color=band_color, ax=None):
ax.addItem(lr)
return lr

def add_arrow(pos, angle, arrow_color='', arrow_outline_color='', arrow_brush_width=1, arrow_outline_width=1, interactive=False, ax=None):
ax = _create_plot(ax=ax, maximize=False)
arrow = FinArrow(ax, angle, arrow_color, arrow_outline_color, arrow_brush_width, arrow_outline_width)
x = pos[0]
if ax.vb.datasrc is not None:
x = _pdtime2index(ax, pd.Series([pos[0]]))[0]
y = ax.vb.yscale.invxform(pos[1])
arrow.setPos(x, y)
arrow.setZValue(50)
arrow.ax = ax
ax.addItem(arrow, ignoreBounds=True)
return arrow

def add_rect(p0, p1, color=band_color, interactive=False, ax=None):
ax = _create_plot(ax=ax, maximize=False)
Expand All @@ -1804,7 +1898,6 @@ def add_rect(p0, p1, color=band_color, interactive=False, ax=None):
ax.addItem(rect)
return rect


def add_line(p0, p1, color=draw_line_color, width=1, style=None, interactive=False, ax=None):
ax = _create_plot(ax=ax, maximize=False)
used_color = _get_color(ax, style, color)
Expand Down Expand Up @@ -1857,17 +1950,17 @@ def remove_primitive(primitive):
ax.vb.removeItem(txt)


def set_time_inspector(inspector, ax=None, when='click'):
def set_time_inspector(inspector, ax=None, when='click', data=None):

'''Callback when clicked like so: inspector(x, y).'''
ax = ax if ax else last_ax
master = ax.ax_widget if hasattr(ax, 'ax_widget') else ax.vb.win
if when == 'hover':
master.proxy_hover = pg.SignalProxy(master.scene().sigMouseMoved, rateLimit=15, slot=partial(_inspect_pos, ax, inspector))
master.proxy_hover = pg.SignalProxy(master.scene().sigMouseMoved, rateLimit=15, slot=partial(_inspect_pos, ax, data, inspector))
elif when in ('dclick', 'double-click'):
master.proxy_dclick = pg.SignalProxy(master.scene().sigMouseClicked, slot=partial(_inspect_clicked, ax, inspector, True))
master.proxy_dclick = pg.SignalProxy(master.scene().sigMouseClicked, slot=partial(_inspect_clicked, ax, data, inspector, True))
else:
master.proxy_click = pg.SignalProxy(master.scene().sigMouseClicked, slot=partial(_inspect_clicked, ax, inspector, False))

master.proxy_click = pg.SignalProxy(master.scene().sigMouseClicked, slot=partial(_inspect_clicked, ax, data, inspector, False))

def add_crosshair_info(infofunc, ax=None):
'''Callback when crosshair updated like so: info(ax,x,y,xtext,ytext); the info()
Expand Down Expand Up @@ -2605,15 +2698,13 @@ def _wheel_event_wrapper(self, orig_func, ev):
ev = QtGui.QWheelEvent(ev.position()+d, ev.globalPosition()+d, ev.pixelDelta(), ev.angleDelta(), ev.buttons(), ev.modifiers(), ev.phase(), False)
orig_func(self, ev)


def _inspect_clicked(ax, inspector, when_double_click, evs):
def _inspect_clicked(ax, data, inspector, when_double_click, evs):
if evs[-1].accepted or when_double_click != evs[-1].double():
return
pos = evs[-1].scenePos()
return _inspect_pos(ax, inspector, (pos,))
return _inspect_pos(ax, data, inspector, (pos,))


def _inspect_pos(ax, inspector, poss):
def _inspect_pos(ax, data, inspector, poss):
if not ax.vb.datasrc:
return
point = ax.vb.mapSceneToView(poss[-1])
Expand All @@ -2624,7 +2715,7 @@ def _inspect_pos(ax, inspector, poss):
if clamp_grid:
t = ax.vb.datasrc.x.iloc[-1 if t > 0 else 0]
try:
inspector(t, point.y())
inspector(t, point.y(), ax, data ) # or directly ax.vb.datasrc ?
except OSError as e:
pass
except Exception as e:
Expand Down Expand Up @@ -2655,6 +2746,8 @@ def _get_color(ax, style, wanted_color):
return colors[index%len(colors)]




def _pdtime2epoch(t):
if isinstance(t, pd.Series):
if isinstance(t.iloc[0], pd.Timestamp):
Expand All @@ -2669,7 +2762,59 @@ def _pdtime2epoch(t):
return t.astype('int64')
return t

#
# Skinok add
# Use case :
# In case of backtesting, this function allow the user to click on a particular trade (in a trade history panel)
# and the chart will automatically move & center on the position of this trade
# This function returns the x position in the dataset, given the entry date
#
def _dateStr2x(ax, dateStr, any_end=False, require_time=False):
ts = pd.Series(pd.to_datetime(dateStr))
if isinstance(ts.iloc[0], pd.Timestamp):
ts = ts.view('int64')
else:
h = np.nanmax(ts.values)
if h < 1e7:
if require_time:
assert False, 'not a time series'
return ts
if h < 1e10: # handle s epochs
ts = ts.astype('float64') * 1e9
elif h < 1e13: # handle ms epochs
ts = ts.astype('float64') * 1e6
elif h < 1e16: # handle us epochs
ts = ts.astype('float64') * 1e3

datasrc = _get_datasrc(ax)
xs = datasrc.x

# try exact match before approximate match
exact = datasrc.index[xs.isin(ts)].to_list()
if len(exact) == len(ts):
return exact

r = []
for i,t in enumerate(ts):
xss = xs.loc[xs>t]
if len(xss) == 0:
t0 = xs.iloc[-1]
if any_end or t0 == t:
r.append(len(xs)-1)
continue
if i > 0:
continue
assert t <= t0, 'must plot this primitive in prior time-range'
i1 = xss.index[0]
i0 = i1-1
if i0 < 0:
i0,i1 = 0,1
t0,t1 = xs.loc[i0], xs.loc[i1]
dt = (t-t0) / (t1-t0)
r.append(lerp(dt, i0, i1))
return r

# ts is "time series" here, not "timestamp"
def _pdtime2index(ax, ts, any_end=False, require_time=False):
if isinstance(ts.iloc[0], pd.Timestamp):
ts = ts.view('int64')
Expand Down Expand Up @@ -2915,11 +3060,9 @@ def _makepen(color, style=None, width=1):
dash[-1] += 2
return pg.mkPen(color=color, style=QtCore.Qt.PenStyle.CustomDashLine, dash=dash, width=width)


def _round(v):
return floor(v+0.5)


try:
qtver = '%d.%d' % (QtCore.QT_VERSION//256//256, QtCore.QT_VERSION//256%256)
if qtver not in ('5.9', '5.13') and [int(i) for i in pg.__version__.split('.')] <= [0,11,0]:
Expand Down