Skip to content

Commit 3d6dbb9

Browse files
committed
Added Dendrogram class and tests
1 parent 0d7d1dd commit 3d6dbb9

File tree

2 files changed

+216
-1
lines changed

2 files changed

+216
-1
lines changed

plotly/tests/test_core/test_tools/test_figure_factory.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import datetime
55
from nose.tools import raises
66

7+
import numpy as np
8+
79
import plotly.tools as tls
810
from plotly.exceptions import PlotlyError
911
from plotly.graph_objs import graph_objs
@@ -802,3 +804,49 @@ def test_datetime_candlestick(self):
802804

803805
self.assertEqual(candle, exp_candle)
804806

807+
class TestDendrogram(TestCase):
    """Tests for the dendrogram factory exposed as
    ``tls.TraceFactory.create_dendrogram``."""

    def test_default_dendrogram(self):
        """A fixed 4x4 matrix must yield the expected leaf order and the
        expected x/y coordinates for every link of the tree."""
        dendro = tls.TraceFactory.create_dendrogram(X=[[1, 2, 3, 4],
                                                       [1, 1, 3, 4],
                                                       [1, 2, 1, 4],
                                                       [1, 2, 3, 1]])
        expected_dendro_data = [
            {'marker': {'color': 'rgb(255,133,27)'},
             'mode': 'lines',
             'xaxis': 'xs',
             'yaxis': 'ys',
             'y': np.array([0., 1., 1., 0.]),
             'x': np.array([25., 25., 35., 35.]),
             'type': u'scatter'},
            {'marker': {'color': 'rgb(255,133,27)'},
             'mode': 'lines',
             'xaxis': 'xs',
             'yaxis': 'ys',
             'y': np.array([0., 2.23606798, 2.23606798, 1.]),
             'x': np.array([15., 15., 30., 30.]),
             'type': u'scatter'},
            {'marker': {'color': 'blue'},
             'mode': 'lines',
             'xaxis': 'xs',
             'yaxis': 'ys',
             'y': np.array([0., 3.60555128, 3.60555128, 2.23606798]),
             'x': np.array([5., 5., 22.5, 22.5]),
             'type': u'scatter'},
        ]

        self.assertEqual(len(dendro.data), len(expected_dendro_data))
        self.assertTrue(np.array_equal(dendro.labels,
                                       np.array(['3', '2', '0', '1'])))

        # Only the coordinates are compared (not marker/axis metadata).
        # Every trace is checked, including the first one, which the
        # previous ``range(1, ...)`` loop skipped.
        for actual, expected in zip(dendro.data, expected_dendro_data):
            self.assertTrue(np.allclose(actual['x'], expected['x']))
            self.assertTrue(np.allclose(actual['y'], expected['y']))

    def test_dendrogram_random_matrix(self):
        """A row that is the sum of all rows must end up as the first
        label of the resulting tree."""
        # create a random uncorrelated matrix
        X = np.random.rand(5, 5)
        # variable 2 is correlated with all the other variables:
        # builtin sum with start=0 adds the rows together element-wise
        X[2, :] = sum(X, 0)

        dendro = tls.TraceFactory.create_dendrogram(X)

        # Check that 2 is in a separate cluster
        self.assertEqual(dendro.labels[0], '2')
851+
852+

plotly/tools.py

Lines changed: 168 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,16 @@
1515
import six
1616

1717
import math
18+
import scipy
1819

20+
import scipy.cluster.hierarchy as sch
1921

2022
from plotly import utils
2123
from plotly import exceptions
2224
from plotly import session
2325

2426
from plotly.graph_objs import graph_objs
25-
from plotly.graph_objs import Scatter, Marker
27+
from plotly.graph_objs import Scatter, Marker, Line, Data
2628

2729

2830
# Warning format
@@ -1760,6 +1762,7 @@ def create_streamline(x, y, u, v,
17601762
y=streamline_y + arrow_y,
17611763
mode='lines', **kwargs)
17621764

1765+
<<<<<<< HEAD
17631766
data = [streamline]
17641767
layout = graph_objs.Layout(hovermode='closest')
17651768

@@ -2275,6 +2278,23 @@ def create_candlestick(open, high, low, close,
22752278
return dict(data=data, layout=layout)
22762279

22772280

2281+
@staticmethod
def create_dendrogram(X, orientation="bottom", labels=None,
                      colorscale=None, **kwargs):
    """
    Build a dendrogram and return the object holding its traces and layout.

    X: Heatmap matrix as array of arrays
    orientation: 'top', 'right', 'bottom', or 'left'
    labels: List of axis category labels
    colorscale: Optional colorscale for dendrogram tree clusters
    kwargs: Forwarded to _Dendrogram; the keywords it accepts are
        width, height, xaxis and yaxis

    Returns the _Dendrogram instance. Its `data` and `layout` attributes
    can be combined into a figure; `labels` holds the ordered leaf labels.
    """

    # TODO: add validations of input

    # Forward the extra keyword arguments instead of silently dropping
    # them, so width/height/axis keys can be set from this entry point.
    return _Dendrogram(X, orientation, labels, colorscale, **kwargs)
2297+
22782298
class _Quiver(FigureFactory):
22792299
"""
22802300
Refer to FigureFactory.create_quiver() for docstring
@@ -2690,6 +2710,7 @@ def sum_streamlines(self):
26902710
streamline_y = sum(self.st_y, [])
26912711
return streamline_x, streamline_y
26922712

2713+
<<<<<<< HEAD
26932714

26942715
class _OHLC(FigureFactory):
26952716
"""
@@ -2871,3 +2892,149 @@ def get_candle_decrease(self):
28712892
return (decrease_x, decrease_close, decrease_dif,
28722893
stick_decrease_y, stick_decrease_x)
28732894

2895+
class _Dendrogram(TraceFactory):
    """
    Compute the traces and layout of a dendrogram tree.

    After construction the instance exposes:
        data: the tree traces wrapped in a Data object
        layout: dict with figure and axis settings
        labels: leaf labels in the order they appear along the tree
        leaves: leaf indices as returned by scipy
        zero_vals: sorted positions of the leaves on the label axis

    Example usage:
        D = _Dendrogram(X)
        fig = {'data': D.data, 'layout': D.layout}
        py.iplot(fig, filename='Dendro', validate=False)
    """

    def __init__(self, X, orientation='bottom', labels=None, colorscale=None,
                 width=700, height=700, xaxis='xaxis', yaxis='yaxis'):
        """
        Draw a 2d dendrogram tree.

        X: Heatmap matrix as array of arrays
        orientation: 'top', 'right', 'bottom', or 'left'
        labels: List of axis category labels
        colorscale: Optional colorscale for dendrogram tree clusters
        width: Figure width stored in the layout
        height: Figure height stored in the layout
        xaxis: Layout key of the x axis, e.g. 'xaxis' or 'xaxis2'
        yaxis: Layout key of the y axis, e.g. 'yaxis' or 'yaxis2'
        """
        self.orientation = orientation
        self.labels = labels
        self.xaxis = xaxis
        self.yaxis = yaxis
        self.data = []
        self.leaves = []
        self.layout = {self.xaxis: {}, self.yaxis: {}}

        # Coordinates are multiplied by these signs so the tree can be
        # mirrored according to the requested orientation.
        self.sign = {self.xaxis: 1, self.yaxis: 1}
        self.sign[self.xaxis] = (
            1 if self.orientation in ('left', 'bottom') else -1)
        self.sign[self.yaxis] = (
            1 if self.orientation in ('right', 'bottom') else -1)

        (dd_traces, xvals, yvals,
         ordered_labels, leaves) = self.get_dendrogram_traces(X, colorscale)

        self.labels = ordered_labels
        self.leaves = leaves

        # Positions whose matching height is zero are the leaves of the
        # tree; keep each position once, in ascending order.
        self.zero_vals = sorted(set(
            x_val
            for x_val, y_val in zip(xvals.flatten(), yvals.flatten())
            if y_val == 0.0))

        self.layout = self.set_figure_layout(width, height)
        self.data = Data(dd_traces)

    def get_color_dict(self, colorscale):
        """
        Return the mapping from scipy color codes to plot colors.

        colorscale: Optional list of colors; when None a built-in list of
            six colors is used. Colors are assigned to the codes in the
            dict's iteration order; codes beyond the length of the
            colorscale keep their default color name.
        """
        # These are the color codes returned for dendrograms
        # We're replacing them with nicer colors
        default_colors = {'r': 'red', 'g': 'green', 'b': 'blue', 'c': 'cyan',
                          'm': 'magenta', 'y': 'yellow', 'k': 'black',
                          'w': 'white'}

        if colorscale is None:
            colorscale = [
                "rgb(0,116,217)",
                "rgb(255,65,54)",
                "rgb(133,20,75)",
                "rgb(255,133,27)",
                "rgb(255,220,0)",
                "rgb(61,153,112)"]

        # zip stops at the shorter sequence. The key list is snapshotted
        # because dict.keys() cannot be indexed on Python 3.
        for code, color in zip(list(default_colors), colorscale):
            default_colors[code] = color

        return default_colors

    def set_axis_layout(self, axis_key):
        """
        Set and return the default axis object for the dendrogram figure.

        axis_key: "xaxis", "xaxis1", "yaxis", "yaxis1", etc

        As a side effect, when labels are available the axis that carries
        them (y axis for 'left'/'right', x axis otherwise) gets its tick
        values and tick text set.
        """
        axis_defaults = {
            'type': 'linear',
            'ticks': 'inside',
            'mirror': 'allticks',
            'rangemode': 'tozero',
            'showticklabels': True,
            'zeroline': False,
            'showgrid': False,
            'showline': True,
        }

        # Identity test: self.labels may be a numpy array, for which
        # "!= None" is an elementwise comparison.
        if self.labels is not None:
            axis_key_labels = self.xaxis
            if self.orientation in ('left', 'right'):
                axis_key_labels = self.yaxis
            label_axis = self.layout.setdefault(axis_key_labels, {})
            # Scale with the sign of the axis that carries the labels so
            # the ticks line up with the (possibly mirrored) leaves.
            label_sign = self.sign[axis_key_labels]
            label_axis['tickvals'] = [ea * label_sign
                                      for ea in self.zero_vals]
            label_axis['ticktext'] = self.labels
            label_axis['tickmode'] = 'array'

        axis = self.layout.setdefault(axis_key, {})
        axis.update(axis_defaults)

        return axis

    def set_figure_layout(self, width, height):
        """
        Set and return the default layout object for the dendrogram figure.

        width: Figure width
        height: Figure height
        """
        self.layout.update({
            'showlegend': False,
            'autosize': False,
            'hovermode': 'closest',
            'width': width,
            'height': height
        })

        self.set_axis_layout(self.xaxis)
        self.set_axis_layout(self.yaxis)

        return self.layout

    @staticmethod
    def _axis_ref(axis_key):
        """Turn a layout axis key into a trace axis reference, e.g.
        'xaxis' -> 'x' and 'yaxis2' -> 'y2'."""
        return axis_key.replace('axis', '', 1)

    def get_dendrogram_traces(self, X, colorscale):
        """
        Cluster X and build one line trace per link of the tree.

        X: Heatmap matrix as array of arrays
        colorscale: Optional colorscale, see get_color_dict

        Returns a tuple with:
            (a) List of Plotly trace objects for the dendrogram tree
            (b) icoord: All X points of the dendrogram tree as array of
                arrays with length 4
            (c) dcoord: All Y points of the dendrogram tree as array of
                arrays with length 4
            (d) ordered_labels: leaf labels in tree order
            (e) leaves: leaf indices as returned by scipy
        """
        # Imported locally: numpy is not among the imports visible at the
        # top of this module.
        import numpy as np
        from scipy.spatial.distance import pdist

        d = pdist(X)
        Z = sch.linkage(d, method='complete')
        P = sch.dendrogram(Z, orientation=self.orientation,
                           labels=self.labels, no_plot=True)

        icoord = np.array(P['icoord'])
        dcoord = np.array(P['dcoord'])
        ordered_labels = np.array(P['ivl'])
        color_list = np.array(P['color_list'])
        # NOTE(review): the lookup below assumes scipy reports the
        # single-letter codes 'r', 'g', 'b', ...; confirm against the
        # scipy version in use.
        colors = self.get_color_dict(colorscale)

        # For 'left'/'right' the tree grows along the x axis, so the leaf
        # positions and the link heights swap axes.
        vertical = self.orientation in ('top', 'bottom')
        x_sign = self.sign[self.xaxis]
        y_sign = self.sign[self.yaxis]
        x_ref = self._axis_ref(self.xaxis)
        y_ref = self._axis_ref(self.yaxis)

        trace_list = []

        for positions, heights, color_key in zip(icoord, dcoord, color_list):
            # xs and ys are arrays of 4 points that make up the
            # bracket-like shapes of the dendrogram tree
            xs = positions if vertical else heights
            ys = heights if vertical else positions
            trace = Scatter(x=np.multiply(x_sign, xs),
                            y=np.multiply(y_sign, ys),
                            mode='lines',
                            marker=Marker(color=colors[color_key]))
            trace['xaxis'] = x_ref
            trace['yaxis'] = y_ref
            trace_list.append(trace)

        return trace_list, icoord, dcoord, ordered_labels, P['leaves']
3040+

0 commit comments

Comments
 (0)