Skip to content

Commit 08e0781

Browse files
committed
solve bug for rgba images
1 parent 34c1acf commit 08e0781

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

packages/python/plotly/plotly/express/_imshow.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,14 @@ def _array_to_b64str(img, backend="pil", compression=4):
4040
# PIL and pypng error messages are quite obscure so we catch invalid compression values
4141
if compression < 0 or compression > 9:
4242
raise ValueError("compression level must be between 0 and 9.")
43+
alpha = False
4344
if img.ndim == 2:
4445
mode = "L"
4546
elif img.ndim == 3 and img.shape[-1] == 3:
4647
mode = "RGB"
4748
elif img.ndim == 3 and img.shape[-1] == 4:
4849
mode = "RGBA"
50+
alpha = True
4951
else:
5052
raise ValueError("Invalid image shape")
5153
if backend == "auto":
@@ -55,7 +57,7 @@ def _array_to_b64str(img, backend="pil", compression=4):
5557
sh = img.shape
5658
if ndim == 3:
5759
img = img.reshape((sh[0], sh[1] * sh[2]))
58-
w = png.Writer(sh[1], sh[0], greyscale=(ndim == 2), compression=compression)
60+
w = png.Writer(sh[1], sh[0], greyscale=(ndim == 2), alpha=alpha, compression=compression)
5961
img_png = png.from_array(img, mode=mode)
6062
prefix = "data:image/png;base64,"
6163
with BytesIO() as stream:

0 commit comments

Comments
 (0)