Skip to content

Commit 4be32bb

Browse files
committed
update direction arg in ohlc
- check that direction is ‘increasing’ or ‘decreasing’
1 parent 7679f5c commit 4be32bb

File tree

1 file changed

+125
-19
lines changed

1 file changed

+125
-19
lines changed

plotly/tools.py

Lines changed: 125 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1396,11 +1396,36 @@ def return_figure_from_figure_or_data(figure_or_data, validate_figure):
13961396
class TraceFactory(dict):
13971397

13981398
@staticmethod
1399+
def validate_equal_length(*args):
13991400
"""
1401+
Validates that data lists or ndarrays are the same length.
14001402
1403+
:raises: (PlotlyError) If any data lists are not the same length.
1404+
"""
1405+
length = len(args[0])
1406+
if any(len(lst) != length for lst in args):
1407+
raise exceptions.PlotlyError("Oops! Your data lists or ndarrays "
1408+
"should be the same length.")
14011409

14021410
@staticmethod
1411+
def validate_positive_scalars(**kwargs):
1412+
"""
1413+
Validates that all values given in key/val pairs are positive.
1414+
1415+
Accepts kwargs to improve Exception messages.
1416+
1417+
:raises: (PlotlyError) If any value is < 0 or raises.
1418+
"""
1419+
for key, val in kwargs.items():
1420+
try:
1421+
if val <= 0:
1422+
raise ValueError('{} must be > 0, got {}'.format(key, val))
1423+
except TypeError:
1424+
raise exceptions.PlotlyError('{} must be a number, got {}'
1425+
.format(key, val))
14031426

1427+
@staticmethod
1428+
def validate_streamline(x, y):
14041429
"""
14051430
streamline specific validations
14061431
@@ -1416,8 +1441,7 @@ class TraceFactory(dict):
14161441
if _numpy_imported is False:
14171442
raise ImportError("TraceFactory.create_streamline requires numpy.")
14181443
for index in range(len(x) - 1):
1419-
if ((x[index + 1] - x[index]) -
1420-
(x[1] - x[0])) > .0001:
1444+
if ((x[index + 1] - x[index]) - (x[1] - x[0])) > .0001:
14211445
raise exceptions.PlotlyError("x must be a 1 dimensional, "
14221446
"evenly spaced array")
14231447
for index in range(len(y) - 1):
@@ -1427,10 +1451,21 @@ class TraceFactory(dict):
14271451
"evenly spaced array")
14281452

14291453
@staticmethod
1430-
raise exceptions.PlotlyError("Oops! Your high, open, low, and "
1431-
"close lists should all be the same "
1432-
"length.")
1454+
def validate_ohlc(open, high, low, close, direction, **kwargs):
1455+
"""
1456+
ohlc specific validations
1457+
1458+
Specifically, this checks that the high value is greatest value and the
1459+
low value is the lowest value in each unit.
14331460
1461+
See TraceFactory.create_streamline() for params
1462+
1463+
:raises: (PlotlyError) If the high value is not the greatest value in
1464+
each unit.
1465+
:raises: (PlotlyError) If the low value is not the lowest value in each
1466+
unit.
1467+
:raises: (PlotlyError) If direction is not 'increasing' or 'decreasing'
1468+
"""
14341469
for lst in [open, low, close]:
14351470
for index in range(len(high)):
14361471
if high[index] < lst[index]:
@@ -1441,7 +1476,7 @@ class TraceFactory(dict):
14411476
"Double check that your data "
14421477
"is entered in O-H-L-C order")
14431478

1444-
for lst in [high, open, close]:
1479+
for lst in [open, high, close]:
14451480
for index in range(len(low)):
14461481
if low[index] > lst[index]:
14471482
raise exceptions.PlotlyError("Oops! Looks like some of "
@@ -1451,6 +1486,12 @@ class TraceFactory(dict):
14511486
"Double check that your data "
14521487
"is entered in O-H-L-C order")
14531488

1489+
if direction is 'increasing' or direction is 'decreasing':
1490+
pass
1491+
else:
1492+
raise exceptions.PlotlyError("direction must be defined as "
1493+
"'increasing' or 'decreasing'")
1494+
14541495
@staticmethod
14551496
def flatten(array):
14561497
"""
@@ -1546,6 +1587,10 @@ def create_quiver(x, y, u, v, scale=.1, arrow_scale=.3,
15461587
py.iplot(fig, filename='quiver')
15471588
```
15481589
"""
1590+
TraceFactory.validate_equal_length(x, y, u, v)
1591+
TraceFactory.validate_positive_scalars(arrow_scale=arrow_scale,
1592+
scale=scale)
1593+
15491594
barb_x, barb_y = _Quiver(x, y, u, v, scale,
15501595
arrow_scale, angle).get_barbs()
15511596
arrow_x, arrow_y = _Quiver(x, y, u, v, scale,
@@ -1577,6 +1622,7 @@ def create_streamline(x, y, u, v,
15771622
for more information on valid kwargs call
15781623
help(plotly.graph_objs.Scatter)
15791624
1625+
:rtype (trace): returns streamline data
15801626
15811627
Example 1: Plot simple streamline and increase arrow size
15821628
```
@@ -1635,7 +1681,12 @@ def create_streamline(x, y, u, v,
16351681
py.iplot(fig, filename='streamline')
16361682
```
16371683
"""
1684+
TraceFactory.validate_equal_length(x, y)
1685+
TraceFactory.validate_equal_length(u, v)
16381686
TraceFactory.validate_streamline(x, y)
1687+
TraceFactory.validate_positive_scalars(density=density,
1688+
arrow_scale=arrow_scale)
1689+
16391690
streamline_x, streamline_y = _Streamline(x, y, u, v,
16401691
density, angle,
16411692
arrow_scale).sum_streamlines()
@@ -1649,6 +1700,7 @@ def create_streamline(x, y, u, v,
16491700
return streamline
16501701

16511702
@staticmethod
1703+
def create_ohlc(open, high, low, close, direction, **kwargs):
16521704

16531705
"""
16541706
Returns data for an ohlc chart
@@ -1657,10 +1709,18 @@ def create_streamline(x, y, u, v,
16571709
:param (list) high: high values
16581710
:param (list) low: low values
16591711
:param (list) close: closing values
1712+
:param (string) direction: direction can be 'increasing' or
1713+
'decreasing'. When the direction is 'increasing', the returned data
1714+
consists of all units where the close value is greater than the
1715+
corresponding open value, and when the direction is 'decreasing',
1716+
the returned data consists of all units where the close value is
1717+
less than the corresponding open value.
16601718
:param (class) kwargs: kwargs passed through plotly.graph_objs.Scatter
16611719
for more information on valid kwargs call
16621720
help(plotly.graph_objs.Scatter)
16631721
1722+
:rtype (trace): returns data for ohlc increasing units or decreasing
1723+
units.
16641724
16651725
Example 1: Plot ohlc chart
16661726
```
@@ -1674,48 +1734,94 @@ def create_streamline(x, y, u, v,
16741734
low_data = [32.7, 32.7, 32.8, 32.6, 32.8]
16751735
close_data = [33.0, 32.9, 33.3, 33.1, 33.1]
16761736
1737+
# Create ohlc increasing units
1738+
ohlc_increase = tls.TraceFactory.create_ohlc(open_data, high_data,
1739+
low_data, close_data,
1740+
direction='increasing')
16771741
1742+
# Create ohlc decreasing units
1743+
ohlc_decrease = tls.TraceFactory.create_ohlc(open_data, high_data,
16781744
low_data, close_data,
1745+
direction='decreasing')
16791746
16801747
# Plot
16811748
fig = Figure()
16821749
fig['data'].append(ohlc_increase)
16831750
fig['data'].append(ohlc_decrease)
1751+
url = py.plot(fig, filename='ohlc')
16841752
```
16851753
1754+
Example 2: Plot ohlc chart with date labels
16861755
```
16871756
import plotly.plotly as py
16881757
import plotly.tools as tls
16891758
from plotly.graph_objs import *
16901759
1760+
from datetime import datetime
1761+
16911762
# Add data
16921763
open_data = [33.0, 33.3, 33.5, 33.0, 34.1]
16931764
high_data = [33.1, 33.3, 33.6, 33.2, 34.8]
16941765
low_data = [32.7, 32.7, 32.8, 32.6, 32.8]
16951766
close_data = [33.0, 32.9, 33.3, 33.1, 33.1]
1696-
dates = ['3/09', '6/09', '9/09', '12/09', '3/10']
1697-
1767+
dates = [datetime(year=2013, month=10, day=10),
1768+
datetime(year=2013, month=11, day=10),
1769+
datetime(year=2013, month=12, day=10),
1770+
datetime(year=2014, month=1, day=10),
1771+
datetime(year=2015, month=2, day=10)]
1772+
1773+
# Create ohlc trace of increasing units
1774+
ohlc_increase = tls.TraceFactory.create_ohlc(open_data, high_data,
1775+
low_data, close_data,
1776+
direction='increasing')
16981777
1778+
# Create ohlc trace of decreasing units
1779+
ohlc_decrease = tls.TraceFactory.create_ohlc(open_data, high_data,
1780+
low_data, close_data,
1781+
direction='decreasing')
16991782
17001783
# Create layout with dates as x-axis labels
1784+
fig = dict(data=[ohlc_increase, ohlc_decrease],
17011785
layout=dict(xaxis = dict(ticktext = dates,
17021786
tickvals = [1, 2, 3, 4, 5 ])))
17031787
17041788
# Plot
17051789
url = py.plot(fig, filename='ohlcs_dates', validate=False)
17061790
```
17071791
"""
1708-
1709-
1710-
1711-
1712-
1713-
1714-
1715-
1716-
1717-
1718-
1792+
TraceFactory.validate_equal_length(open, high, low, close)
1793+
TraceFactory.validate_ohlc(open, high, low, close, direction,
1794+
**kwargs)
1795+
1796+
if direction is 'increasing':
1797+
(flat_increase_x,
1798+
flat_increase_y,
1799+
text_increase) = _OHLC(open, high, low, close).get_increase()
1800+
1801+
kwargs.setdefault('name', 'Increasing')
1802+
kwargs.setdefault('line', {'color': 'rgb(44, 160, 44)'})
1803+
kwargs.setdefault('text', text_increase)
1804+
1805+
ohlc = Scatter(x=flat_increase_x,
1806+
y=flat_increase_y,
1807+
mode='lines',
1808+
**kwargs)
1809+
1810+
elif direction is 'decreasing':
1811+
(flat_decrease_x,
1812+
flat_decrease_y,
1813+
text_decrease) = _OHLC(open, high, low, close).get_decrease()
1814+
1815+
kwargs.setdefault('name', 'Decreasing')
1816+
kwargs.setdefault('line', {'color': 'rgb(214, 39, 40)'})
1817+
kwargs.setdefault('text', text_decrease)
1818+
1819+
ohlc = Scatter(x=flat_decrease_x,
1820+
y=flat_decrease_y,
1821+
mode='lines',
1822+
**kwargs)
1823+
1824+
return ohlc
17191825

17201826

17211827
class _Quiver(TraceFactory):

0 commit comments

Comments
 (0)