Skip to content

Commit 91f608e

Browse files
committed
tweaks so that zmin/zmax are set less often
1 parent 08e0781 commit 91f608e

File tree

2 files changed

+36
-25
lines changed

2 files changed

+36
-25
lines changed

packages/python/plotly/plotly/express/_imshow.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,9 @@ def _array_to_b64str(img, backend="pil", compression=4):
5757
sh = img.shape
5858
if ndim == 3:
5959
img = img.reshape((sh[0], sh[1] * sh[2]))
60-
w = png.Writer(sh[1], sh[0], greyscale=(ndim == 2), alpha=alpha, compression=compression)
60+
w = png.Writer(
61+
sh[1], sh[0], greyscale=(ndim == 2), alpha=alpha, compression=compression
62+
)
6163
img_png = png.from_array(img, mode=mode)
6264
prefix = "data:image/png;base64,"
6365
with BytesIO() as stream:
@@ -81,11 +83,11 @@ def _vectorize_zvalue(z):
8183
if z is None:
8284
return z
8385
elif np.isscalar(z):
84-
return [z] * 3 + [1]
86+
return [z] * 3 + [255]
8587
elif len(z) == 1:
86-
return list(z) * 3 + [1]
88+
return list(z) * 3 + [255]
8789
elif len(z) == 3:
88-
return list(z) + [1]
90+
return list(z) + [255]
8991
elif len(z) == 4:
9092
return z
9193
else:
@@ -311,11 +313,7 @@ def imshow(
311313
raise ValueError("Binary strings cannot be used with pandas arrays")
312314
has_nans = True
313315
else:
314-
has_nans = np.any(np.isnan(img))
315-
if has_nans and binary_string:
316-
raise ValueError(
317-
"Binary strings cannot be used with arrays containing NaNs"
318-
)
316+
has_nans = False
319317

320318
# --------------- Starting from here img is always a numpy array --------
321319
img = np.asanyarray(img)
@@ -340,9 +338,9 @@ def imshow(
340338
if (zmax is not None or binary_string) and zmin is None:
341339
zmin = img.min()
342340
else:
343-
if zmax is None and (img.dtype is not np.uint8 or img.ndim == 2):
341+
if zmax is None and (img.dtype != np.uint8 or img.ndim == 2):
344342
zmax = _infer_zmax_from_type(img)
345-
if zmin is None:
343+
if zmin is None and zmax is not None:
346344
zmin = 0
347345

348346
# For 2d data, use Heatmap trace, unless binary_string is True
@@ -377,9 +375,12 @@ def imshow(
377375

378376
# For 2D+RGB data, use Image trace
379377
elif img.ndim == 3 and img.shape[-1] in [3, 4] or (img.ndim == 2 and binary_string):
380-
zmin, zmax = _vectorize_zvalue(zmin), _vectorize_zvalue(zmax)
378+
if zmin is not None and zmax is not None:
379+
zmin, zmax = _vectorize_zvalue(zmin), _vectorize_zvalue(zmax)
381380
if binary_string:
382-
if img.ndim == 2:
381+
if zmin is None and zmax is None:
382+
img_rescaled = img
383+
elif img.ndim == 2:
383384
img_rescaled = rescale_intensity(
384385
img, in_range=(zmin[0], zmax[0]), out_range=np.uint8
385386
)
@@ -403,7 +404,8 @@ def imshow(
403404
)
404405
trace = go.Image(source=img_str)
405406
else:
406-
trace = go.Image(z=img, zmin=zmin, zmax=zmax)
407+
colormodel = "rgb" if img.shape[-1] == 3 else "rgba"
408+
trace = go.Image(z=img, zmin=zmin, zmax=zmax, colormodel=colormodel)
407409
layout = {}
408410
if origin == "lower":
409411
layout["yaxis"] = dict(autorange=True)

packages/python/plotly/plotly/tests/test_core/test_px/test_imshow.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,7 @@ def decode_image_string(image_string):
2626
@pytest.mark.parametrize("binary_string", [False, True])
2727
def test_rgb_uint8(binary_string):
2828
fig = px.imshow(img_rgb, binary_string=binary_string)
29-
if not binary_string:
30-
assert fig.data[0]["zmax"] == (255, 255, 255, 1)
31-
else:
32-
assert fig.data[0]["zmax"] is None
29+
assert fig.data[0]["zmax"] is None
3330

3431

3532
def test_zmax():
@@ -39,10 +36,10 @@ def test_zmax():
3936
(100,),
4037
[100, 100, 100],
4138
(100, 100, 100),
42-
(100, 100, 100, 1),
39+
(100, 100, 100, 255),
4340
]:
4441
fig = px.imshow(img_rgb, zmax=zmax, binary_string=False)
45-
assert fig.data[0]["zmax"] == (100, 100, 100, 1)
42+
assert fig.data[0]["zmax"] == (100, 100, 100, 255)
4643

4744

4845
def test_automatic_zmax_from_dtype():
@@ -56,7 +53,11 @@ def test_automatic_zmax_from_dtype():
5653
img = np.array([0, 1], dtype=key)
5754
img = np.dstack((img,) * 3)
5855
fig = px.imshow(img, binary_string=False)
59-
assert fig.data[0]["zmax"] == (val, val, val, 1)
56+
# For uint8 in "infer" mode we don't pass zmin/zmax unless specified
57+
if key in [np.uint8, np.bool]:
58+
assert fig.data[0]["zmax"] is None
59+
else:
60+
assert fig.data[0]["zmax"] == (val, val, val, 255)
6061

6162

6263
@pytest.mark.parametrize("binary_string", [False, True])
@@ -91,15 +92,23 @@ def test_wrong_dimensions():
9192
_ = px.imshow(img)
9293

9394

94-
def test_nan_inf_data():
95+
@pytest.mark.parametrize("binary_string", [False, True])
96+
def test_nan_inf_data(binary_string):
9597
imgs = [np.ones((20, 20)), 255 * np.ones((20, 20))]
9698
zmaxs = [1, 255]
9799
for zmax, img in zip(zmaxs, imgs):
98100
img[0] = 0
99101
img[10:12] = np.nan
100102
# the case of 2d/heatmap is handled gracefully by the JS trace but I don't know how to check it
101-
fig = px.imshow(np.dstack((img,) * 3))
102-
assert fig.data[0]["zmax"] == (zmax, zmax, zmax, 1)
103+
fig = px.imshow(
104+
np.dstack((img,) * 3),
105+
binary_string=binary_string,
106+
contrast_rescaling="minxmax",
107+
)
108+
if not binary_string:
109+
assert fig.data[0]["zmax"] == (zmax, zmax, zmax, 255)
110+
else:
111+
assert fig.data[0]["zmax"] is None
103112

104113

105114
def test_zmax_floats():
@@ -113,7 +122,7 @@ def test_zmax_floats():
113122
zmaxs = [1, 1, 255, 65535]
114123
for zmax, img in zip(zmaxs, imgs):
115124
fig = px.imshow(img, binary_string=False)
116-
assert fig.data[0]["zmax"] == (zmax, zmax, zmax, 1)
125+
assert fig.data[0]["zmax"] == (zmax, zmax, zmax, 255)
117126
# single-channel
118127
imgs = [
119128
np.ones((5, 5)),

0 commit comments

Comments
 (0)