Skip to content

Commit 66382e2

Browse files
committed
added colors dictionary option
1 parent 00f8162 commit 66382e2

File tree

2 files changed

+157
-13
lines changed

2 files changed

+157
-13
lines changed

plotly/tests/test_core/test_tools/test_figure_factory.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1130,6 +1130,8 @@ class TestGantt(TestCase):
11301130

11311131
def test_validate_gantt(self):
11321132

1133+
# validate the basic gantt inputs
1134+
11331135
df = [dict(Task='Job A',
11341136
Start='2009-02-01',
11351137
Finish='2009-08-30',
@@ -1169,6 +1171,8 @@ def test_validate_gantt(self):
11691171

11701172
def test_gantt_index(self):
11711173

1174+
# validate the index used for gantt
1175+
11721176
df = [dict(Task='Job A',
11731177
Start='2009-02-01',
11741178
Finish='2009-08-30',
@@ -1197,6 +1201,8 @@ def test_gantt_index(self):
11971201

11981202
def test_gantt_validate_colors(self):
11991203

1204+
# validate the gantt colors variable
1205+
12001206
df = [dict(Task='Job A', Start='2009-02-01',
12011207
Finish='2009-08-30', Complete=75),
12021208
dict(Task='Job B', Start='2009-02-01',
@@ -1228,6 +1234,17 @@ def test_gantt_validate_colors(self):
12281234
tls.FigureFactory.create_gantt, df,
12291235
index_col='Complete', colors=5)
12301236

1237+
# verify that if colors is a dictionary, its keys span all the
1238+
# values in the index column
1239+
colors_dict = {75: 'rgb(1, 2, 3)'}
1240+
1241+
pattern4 = ("If you are using colors as a dictionary, all of its "
1242+
"keys must be all the values in the index column.")
1243+
1244+
self.assertRaisesRegexp(PlotlyError, pattern4,
1245+
tls.FigureFactory.create_gantt, df,
1246+
index_col='Complete', colors=colors_dict)
1247+
12311248
def test_gantt_all_args(self):
12321249

12331250
# check if gantt chart matches with expected output

plotly/tools.py

Lines changed: 140 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1765,6 +1765,123 @@ def _gantt_colorscale(chart, colors, title, index_col, show_colorbar,
17651765
fig = dict(data=data, layout=layout)
17661766
return fig
17671767

1768+
@staticmethod
1769+
def _gantt_dict(chart, colors, title, index_col, show_colorbar, bar_width,
1770+
showgrid_x, showgrid_y, height, width, tasks=None,
1771+
task_names=None, data=None):
1772+
"""
1773+
Refer to FigureFactory.create_gantt() for docstring
1774+
"""
1775+
if tasks is None:
1776+
tasks = []
1777+
if task_names is None:
1778+
task_names = []
1779+
if data is None:
1780+
data = []
1781+
1782+
for index in range(len(chart)):
1783+
task = dict(x0=chart[index]['Start'],
1784+
x1=chart[index]['Finish'],
1785+
name=chart[index]['Task'])
1786+
tasks.append(task)
1787+
1788+
shape_template = {
1789+
'type': 'rect',
1790+
'xref': 'x',
1791+
'yref': 'y',
1792+
'opacity': 1,
1793+
'line': {
1794+
'width': 0,
1795+
},
1796+
'yref': 'y',
1797+
}
1798+
1799+
index_vals = []
1800+
for row in range(len(tasks)):
1801+
if chart[row][index_col] not in index_vals:
1802+
index_vals.append(chart[row][index_col])
1803+
1804+
index_vals.sort()
1805+
1806+
# verify each value in index column appears in colors dictionary
1807+
for key in index_vals:
1808+
if key not in colors:
1809+
raise exceptions.PlotlyError("If you are using colors as a "
1810+
"dictionary, all of its keys "
1811+
"must be all the values in the "
1812+
"index column.")
1813+
1814+
for index in range(len(tasks)):
1815+
tn = tasks[index]['name']
1816+
task_names.append(tn)
1817+
del tasks[index]['name']
1818+
tasks[index].update(shape_template)
1819+
tasks[index]['y0'] = index - bar_width
1820+
tasks[index]['y1'] = index + bar_width
1821+
1822+
tasks[index]['fillcolor'] = colors[chart[index][index_col]]
1823+
1824+
# add a line for hover text and autorange
1825+
data.append(
1826+
dict(
1827+
x=[tasks[index]['x0'], tasks[index]['x1']],
1828+
y=[index, index],
1829+
name='',
1830+
marker={'color': 'white'}
1831+
)
1832+
)
1833+
1834+
layout = dict(
1835+
title=title,
1836+
showlegend=False,
1837+
height=height,
1838+
width=width,
1839+
shapes=[],
1840+
hovermode='closest',
1841+
yaxis=dict(
1842+
showgrid=showgrid_y,
1843+
ticktext=task_names,
1844+
tickvals=list(range(len(tasks))),
1845+
range=[-1, len(tasks) + 1],
1846+
autorange=False,
1847+
zeroline=False,
1848+
),
1849+
xaxis=dict(
1850+
showgrid=showgrid_x,
1851+
zeroline=False,
1852+
rangeselector=dict(
1853+
buttons=list([
1854+
dict(count=7,
1855+
label='1w',
1856+
step='day',
1857+
stepmode='backward'),
1858+
dict(count=1,
1859+
label='1m',
1860+
step='month',
1861+
stepmode='backward'),
1862+
dict(count=6,
1863+
label='6m',
1864+
step='month',
1865+
stepmode='backward'),
1866+
dict(count=1,
1867+
label='YTD',
1868+
step='year',
1869+
stepmode='todate'),
1870+
dict(count=1,
1871+
label='1y',
1872+
step='year',
1873+
stepmode='backward'),
1874+
dict(step='all')
1875+
])
1876+
),
1877+
type='date'
1878+
)
1879+
)
1880+
layout['shapes'] = tasks
1881+
1882+
fig = dict(data=data, layout=layout)
1883+
return fig
1884+
17681885
@staticmethod
17691886
def create_gantt(df, colors=None, index_col=None, show_colorbar=False,
17701887
reverse_colors=False, title='Gantt Chart',
@@ -1776,11 +1893,11 @@ def create_gantt(df, colors=None, index_col=None, show_colorbar=False,
17761893
17771894
:param (array|list) df: input data for gantt chart. Must be either a
17781895
a dataframe or a list. If dataframe, the columns must include
1779-
'Task', 'Start' and 'Finish'; 'Complete' is optional and is used
1780-
to colorscale the bars. If a list, it must contain dictionaries
1781-
with the same required column headers, with 'Complete' optional
1782-
in the same way as the dataframe
1783-
:param (list) colors: a list of 'rgb(a, b, c)' colors where a, b and c
1896+
'Task', 'Start' and 'Finish'. Other columns can be included and
1897+
used for indexing. If a list, its elements must be dictionaries
1898+
with the same required column headers: 'Task', 'Start' and
1899+
'Finish'.
1900+
:param (str|list|dict) colors: a list of 'rgb(a, b, c)' colors where a, b and c
17841901
are between 0 and 255. Can also be a Plotly colorscale but this is
17851902
will result in only a 2-color cycle. If number of colors is less
17861903
than the total number of tasks, colors will cycle
@@ -2005,14 +2122,24 @@ def create_gantt(df, colors=None, index_col=None, show_colorbar=False,
20052122
return fig
20062123

20072124
else:
2008-
fig = FigureFactory._gantt_colorscale(chart, colors,
2009-
title, index_col,
2010-
show_colorbar,
2011-
bar_width, showgrid_x,
2012-
showgrid_y, height,
2013-
width, tasks=None,
2014-
task_names=None, data=None)
2015-
return fig
2125+
if not isinstance(colors, dict):
2126+
fig = FigureFactory._gantt_colorscale(chart, colors,
2127+
title, index_col,
2128+
show_colorbar,
2129+
bar_width, showgrid_x,
2130+
showgrid_y, height,
2131+
width, tasks=None,
2132+
task_names=None,
2133+
data=None)
2134+
return fig
2135+
else:
2136+
fig = FigureFactory._gantt_dict(chart, colors, title,
2137+
index_col, show_colorbar,
2138+
bar_width, showgrid_x,
2139+
showgrid_y, height, width,
2140+
tasks=None, task_names=None,
2141+
data=None)
2142+
return fig
20162143

20172144
@staticmethod
20182145
def _find_intermediate_color(lowcolor, highcolor, intermed):

0 commit comments

Comments
 (0)