Skip to content

Commit 25d334f

Browse files
committed
Refact all our _*_imported code \o/
1 parent 7cbd355 commit 25d334f

File tree

4 files changed

+68
-131
lines changed

4 files changed

+68
-131
lines changed

plotly/graph_objs/figure_factory.py

Lines changed: 26 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,8 @@
66
import math
77
from collections import OrderedDict
88

9-
from plotly import exceptions
9+
from plotly import exceptions, optional_imports
1010
from plotly.graph_objs.graph_objs import GraphObjectFactory
11-
from plotly.tools import (_numpy_imported, _scipy_imported,
12-
_scipy__spatial_imported,
13-
_scipy__cluster__hierarchy_imported)
14-
15-
if _scipy_imported:
16-
import scipy
17-
import scipy as scp
18-
if _numpy_imported:
19-
import numpy as np
20-
if _scipy__spatial_imported:
21-
import scipy.spatial as scs
22-
if _scipy__cluster__hierarchy_imported:
23-
import scipy.cluster.hierarchy as sch
2411

2512

2613
# Default colours for finance charts
@@ -106,16 +93,13 @@ def _validate_distplot(hist_data, curve_type):
10693
:raises: (PlotlyError) If curve_type is not valid (i.e. not 'kde' or
10794
'normal').
10895
"""
109-
try:
110-
import pandas as pd
111-
_pandas_imported = True
112-
except ImportError:
113-
_pandas_imported = False
96+
pd = optional_imports.get_module('pandas')
97+
np = optional_imports.get_module('numpy')
11498

11599
hist_data_types = (list,)
116-
if _numpy_imported:
100+
if np:
117101
hist_data_types += (np.ndarray,)
118-
if _pandas_imported:
102+
if pd:
119103
hist_data_types += (pd.core.series.Series,)
120104

121105
if not isinstance(hist_data[0], hist_data_types):
@@ -131,8 +115,8 @@ def _validate_distplot(hist_data, curve_type):
131115
raise exceptions.PlotlyError("curve_type must be defined as "
132116
"'kde' or 'normal'")
133117

134-
if _scipy_imported is False:
135-
raise ImportError("FigureFactory.create_distplot requires scipy")
118+
msg = 'FigureFactory.create_distplot requires scipy'
119+
optional_imports.get_module('scipy', raise_exc=True, msg=msg)
136120

137121
@staticmethod
138122
def _validate_positive_scalars(**kwargs):
@@ -165,8 +149,8 @@ def _validate_streamline(x, y):
165149
:raises: (PlotlyError) If x is not evenly spaced.
166150
:raises: (PlotlyError) If y is not evenly spaced.
167151
"""
168-
if _numpy_imported is False:
169-
raise ImportError("FigureFactory.create_streamline requires numpy")
152+
msg = 'FigureFactory.create_streamline requires numpy'
153+
optional_imports.get_module('numpy', raise_exc=True, msg=msg)
170154
for index in range(len(x) - 1):
171155
if ((x[index + 1] - x[index]) - (x[1] - x[0])) > .0001:
172156
raise exceptions.PlotlyError("x must be a 1 dimensional, "
@@ -1085,13 +1069,10 @@ def create_dendrogram(X, orientation="bottom", labels=None,
10851069
```
10861070
10871071
"""
1088-
dependencies = (_scipy_imported and _scipy__spatial_imported and
1089-
_scipy__cluster__hierarchy_imported)
1090-
1091-
if dependencies is False:
1092-
raise ImportError("FigureFactory.create_dendrogram requires "
1093-
"scipy, scipy.spatial and scipy.hierarchy")
1094-
1072+
msg = ('FigureFactory.create_dendrogram requires scipy, scipy.spatial '
1073+
'and scipy.hierarchy')
1074+
for module in ['scipy', 'scipy.spatial', 'scipy.cluster.hierarchy']:
1075+
optional_imports.get_module(module, raise_exc=True, msg=msg)
10951076
s = X.shape
10961077
if len(s) != 2:
10971078
exceptions.PlotlyError("X should be 2-dimensional array.")
@@ -1253,6 +1234,7 @@ class _Streamline(FigureFactory):
12531234
def __init__(self, x, y, u, v,
12541235
density, angle,
12551236
arrow_scale, **kwargs):
1237+
np = optional_imports.get_module('numpy', raise_exc=True)
12561238
self.x = np.array(x)
12571239
self.y = np.array(y)
12581240
self.u = np.array(u)
@@ -1296,6 +1278,7 @@ def value_at(self, a, xi, yi):
12961278
"""
12971279
Set up for RK4 function, based on Bokeh's streamline code
12981280
"""
1281+
np = optional_imports.get_module('numpy', raise_exc=True)
12991282
if isinstance(xi, np.ndarray):
13001283
self.x = xi.astype(np.int)
13011284
self.y = yi.astype(np.int)
@@ -1412,6 +1395,7 @@ def get_streamlines(self):
14121395
"""
14131396
Get streamlines by building trajectory set.
14141397
"""
1398+
np = optional_imports.get_module('numpy', raise_exc=True)
14151399
for indent in range(self.density // 2):
14161400
for xi in range(self.density - 2 * indent):
14171401
self.traj(xi + indent, indent)
@@ -1442,6 +1426,7 @@ def get_streamline_arrows(self):
14421426
:rtype (list, list) arrows_x: x-values to create arrowhead and
14431427
arrows_y: y-values to create arrowhead
14441428
"""
1429+
np = optional_imports.get_module('numpy', raise_exc=True)
14451430
arrow_end_x = np.empty((len(self.st_x)))
14461431
arrow_end_y = np.empty((len(self.st_y)))
14471432
arrow_start_x = np.empty((len(self.st_x)))
@@ -1743,12 +1728,14 @@ def make_kde(self):
17431728
17441729
:rtype (list) curve: list of kde representations
17451730
"""
1731+
scipy__stats = optional_imports.get_module('scipy.stats',
1732+
raise_exc=True)
17461733
curve = [None] * self.trace_number
17471734
for index in range(self.trace_number):
17481735
self.curve_x[index] = [self.start[index] +
17491736
x * (self.end[index] - self.start[index]) /
17501737
500 for x in range(500)]
1751-
self.curve_y[index] = (scipy.stats.gaussian_kde
1738+
self.curve_y[index] = (scipy__stats.gaussian_kde
17521739
(self.hist_data[index])
17531740
(self.curve_x[index]))
17541741
self.curve_y[index] *= self.bin_size
@@ -1774,6 +1761,7 @@ def make_normal(self):
17741761
17751762
:rtype (list) curve: list of normal curve representations
17761763
"""
1764+
scipy = optional_imports.get_module('scipy', raise_exc=True)
17771765
curve = [None] * self.trace_number
17781766
mean = [None] * self.trace_number
17791767
sd = [None] * self.trace_number
@@ -1978,6 +1966,11 @@ def get_dendrogram_traces(self, X, colorscale):
19781966
(e) P['leaves']: left-to-right traversal of the leaves
19791967
19801968
"""
1969+
np = optional_imports.get_module('numpy', raise_exc=True)
1970+
scp = optional_imports.get_module('scipy', raise_exc=True)
1971+
scs = optional_imports.get_module('scipy.spatial', raise_exc=True)
1972+
sch = optional_imports.get_module('scipy.cluster.hierarchy',
1973+
raise_exc=True)
19811974
d = scs.distance.pdist(X)
19821975
Z = sch.linkage(d, method='complete')
19831976
P = sch.dendrogram(Z, orientation=self.orientation,

plotly/offline/offline.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import requests
1313

14-
from plotly import session, tools, utils
14+
from plotly import session, tools, utils, optional_imports
1515
from plotly.exceptions import PlotlyError
1616

1717
PLOTLY_OFFLINE_DIRECTORY = plotlyjs_path = os.path.expanduser(
@@ -51,9 +51,9 @@ def init_notebook_mode():
5151
to load the necessary javascript files for creating
5252
Plotly graphs with plotly.offline.iplot.
5353
"""
54-
if not tools._ipython_imported:
55-
raise ImportError('`iplot` can only run inside an IPython Notebook.')
56-
from IPython.display import HTML, display
54+
msg = '`iplot` can only run inside an IPython Notebook.'
55+
IPython__display = optional_imports.get_module('IPython.display',
56+
raise_exc=True, msg=msg)
5757

5858
if not os.path.exists(PLOTLY_OFFLINE_BUNDLE):
5959
raise PlotlyError('Plotly Offline source file at {source_path} '
@@ -68,8 +68,10 @@ def init_notebook_mode():
6868

6969
global __PLOTLY_OFFLINE_INITIALIZED
7070
__PLOTLY_OFFLINE_INITIALIZED = True
71-
display(HTML('<script type="text/javascript">' +
72-
open(PLOTLY_OFFLINE_BUNDLE).read() + '</script>'))
71+
IPython__display.display(IPython__display.HTML(
72+
'<script type="text/javascript">' +
73+
open(PLOTLY_OFFLINE_BUNDLE).read() + '</script>'
74+
))
7375

7476

7577
def iplot(figure_or_data, show_link=True, link_text='Export to plot.ly'):
@@ -108,10 +110,10 @@ def iplot(figure_or_data, show_link=True, link_text='Export to plot.ly'):
108110
'plotly.offline.init_notebook_mode() '
109111
'# run at the start of every ipython notebook',
110112
]))
111-
if not tools._ipython_imported:
112-
raise ImportError('`iplot` can only run inside an IPython Notebook.')
113+
msg = '`iplot` can only run inside an IPython Notebook.'
114+
IPython__display = optional_imports.get_module('IPython.display',
115+
raise_exc=True, msg=msg)
113116

114-
from IPython.display import HTML, display
115117
if isinstance(figure_or_data, dict):
116118
data = figure_or_data['data']
117119
layout = figure_or_data.get('layout', {})
@@ -152,7 +154,7 @@ def iplot(figure_or_data, show_link=True, link_text='Export to plot.ly'):
152154
.replace('http://', '')
153155
link_text = link_text.replace('plot.ly', link_domain)
154156

155-
display(HTML(
157+
IPython__display.display(IPython__display.HTML(
156158
'<script type="text/javascript">'
157159
'window.PLOTLYENV=window.PLOTLYENV || {};'
158160
'window.PLOTLYENV.BASE_URL="' + plotly_platform_url + '";'
@@ -169,17 +171,17 @@ def iplot(figure_or_data, show_link=True, link_text='Export to plot.ly'):
169171
layout=jlayout,
170172
link_text=link_text)
171173

172-
display(HTML(''
173-
'<div class="{id} loading" style="color: rgb(50,50,50);">'
174-
'Drawing...</div>'
175-
'<div id="{id}" style="height: {height}; width: {width};" '
176-
'class="plotly-graph-div">'
177-
'</div>'
178-
'<script type="text/javascript">'
179-
'{script}'
180-
'</script>'
181-
''.format(id=plotdivid, script=script,
182-
height=height, width=width)))
174+
IPython__display.display(IPython__display.HTML(
175+
'<div class="{id} loading" style="color: rgb(50,50,50);">'
176+
'Drawing...</div>'
177+
'<div id="{id}" style="height: {height}; width: {width};" '
178+
'class="plotly-graph-div">'
179+
'</div>'
180+
'<script type="text/javascript">'
181+
'{script}'
182+
'</script>'
183+
''.format(id=plotdivid, script=script, height=height, width=width)
184+
))
183185

184186

185187
def plot():

plotly/tools.py

Lines changed: 11 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,8 @@
1212
import warnings
1313
import six
1414

15-
from plotly import utils
16-
from plotly import exceptions
17-
from plotly import graph_reference
18-
from plotly import session
15+
from plotly import (exceptions, graph_reference, optional_imports, session,
16+
utils)
1917
from plotly.files import (CONFIG_FILE, CREDENTIALS_FILE, FILE_CONTENT,
2018
GRAPH_REFERENCE_FILE, check_file_permissions)
2119

@@ -29,50 +27,6 @@ def warning_on_one_line(message, category, filename, lineno,
2927
message)
3028
warnings.formatwarning = warning_on_one_line
3129

32-
try:
33-
from . import matplotlylib
34-
_matplotlylib_imported = True
35-
except ImportError:
36-
_matplotlylib_imported = False
37-
38-
try:
39-
import IPython
40-
import IPython.core.display
41-
_ipython_imported = True
42-
except ImportError:
43-
_ipython_imported = False
44-
45-
try:
46-
import numpy as np
47-
_numpy_imported = True
48-
except ImportError:
49-
_numpy_imported = False
50-
51-
try:
52-
import scipy as scp
53-
_scipy_imported = True
54-
except ImportError:
55-
_scipy_imported = False
56-
57-
try:
58-
import scipy.spatial as scs
59-
_scipy__spatial_imported = True
60-
except ImportError:
61-
_scipy__spatial_imported = False
62-
63-
try:
64-
import scipy.cluster.hierarchy as sch
65-
_scipy__cluster__hierarchy_imported = True
66-
except ImportError:
67-
_scipy__cluster__hierarchy_imported = False
68-
69-
try:
70-
import scipy
71-
import scipy.stats
72-
_scipy_imported = True
73-
except ImportError:
74-
_scipy_imported = False
75-
7630

7731
def get_config_defaults():
7832
"""
@@ -383,11 +337,12 @@ def embed(file_owner_or_url, file_id=None, width="100%", height=525):
383337
height=height)
384338

385339
# see if we are in the SageMath Cloud
386-
from sage_salvus import html
387-
return html(s, hide=False)
340+
sage_salvus = optional_imports.get_module('sage_salvus')
341+
if sage_salvus:
342+
return sage_salvus.html(s, hide=False)
388343
except:
389344
pass
390-
if _ipython_imported:
345+
if optional_imports.get_module('IPython.core.display'):
391346
if file_id:
392347
plotly_domain = (
393348
session.get_session_config().get('plotly_domain') or
@@ -466,7 +421,8 @@ def mpl_to_plotly(fig, resize=False, strip_style=False, verbose=False):
466421
{plotly_domain}/python/getting-started
467422
468423
"""
469-
if _matplotlylib_imported:
424+
matplotlylib = optional_imports.get_module('plotly.matplotlylib')
425+
if matplotlylib:
470426
renderer = matplotlylib.PlotlyRenderer()
471427
matplotlylib.Exporter(renderer).run(fig)
472428
if resize:
@@ -1361,8 +1317,9 @@ def _replace_newline(obj):
13611317
return obj # we return the actual reference... but DON'T mutate.
13621318

13631319

1364-
if _ipython_imported:
1365-
class PlotlyDisplay(IPython.core.display.HTML):
1320+
IPython__core__display = optional_imports.get_module('IPython.core.display')
1321+
if IPython__core__display:
1322+
class PlotlyDisplay(IPython__core__display.HTML):
13661323
"""An IPython display object for use with plotly urls
13671324
13681325
PlotlyDisplay objects should be instantiated with a url for a plot.

plotly/utils.py

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,8 @@
1515

1616
import pytz
1717

18-
19-
from . exceptions import PlotlyError
20-
21-
try:
22-
import numpy
23-
_numpy_imported = True
24-
except ImportError:
25-
_numpy_imported = False
26-
27-
try:
28-
import pandas
29-
_pandas_imported = True
30-
except ImportError:
31-
_pandas_imported = False
32-
33-
try:
34-
import sage.all
35-
_sage_imported = True
36-
except ImportError:
37-
_sage_imported = False
18+
from plotly import optional_imports
19+
from plotly.exceptions import PlotlyError
3820

3921

4022
### incase people are using threading, we lock file reads
@@ -229,7 +211,8 @@ def encode_as_list(obj):
229211
@staticmethod
230212
def encode_as_sage(obj):
231213
"""Attempt to convert sage.all.RR to floats and sage.all.ZZ to ints"""
232-
if not _sage_imported:
214+
sage = optional_imports.get_module('sage')
215+
if not sage:
233216
raise NotEncodable
234217

235218
if obj in sage.all.RR:
@@ -242,7 +225,8 @@ def encode_as_sage(obj):
242225
@staticmethod
243226
def encode_as_pandas(obj):
244227
"""Attempt to convert pandas.NaT"""
245-
if not _pandas_imported:
228+
pandas = optional_imports.get_module('pandas')
229+
if not pandas:
246230
raise NotEncodable
247231

248232
if obj is pandas.NaT:
@@ -253,7 +237,8 @@ def encode_as_pandas(obj):
253237
@staticmethod
254238
def encode_as_numpy(obj):
255239
"""Attempt to convert numpy.ma.core.masked"""
256-
if not _numpy_imported:
240+
numpy = optional_imports.get_module('numpy')
241+
if not numpy:
257242
raise NotEncodable
258243

259244
if obj is numpy.ma.core.masked:

0 commit comments

Comments
 (0)