@@ -134,7 +134,7 @@ def imshow(
134134 x = None ,
135135 y = None ,
136136 animation_frame = False ,
137- facet_col = False ,
137+ facet_col = None ,
138138 facet_col_wrap = None ,
139139 color_continuous_scale = None ,
140140 color_continuous_midpoint = None ,
@@ -189,6 +189,14 @@ def imshow(
189189 their lengths must match the lengths of the second and first dimensions of the
190190 img argument. They are auto-populated if the input is an xarray.
191191
192+ facet_col: int, optional (default None)
193+ axis number along which the image array is slices to create a facetted plot.
194+
195+ facet_col_wrap: int
196+ Maximum number of facet columns. Wraps the column variable at this width,
197+ so that the column facets span multiple rows.
198+ Ignored if `facet_col` is None.
199+
192200 color_continuous_scale : str or list of str
193201 colormap used to map scalar data to colors (for a 2D image). This parameter is
194202 not used for RGB or RGBA images. If a string is provided, it should be the name
@@ -280,14 +288,14 @@ def imshow(
280288 args = locals ()
281289 apply_default_cascade (args )
282290 labels = labels .copy ()
283- if facet_col :
284- nslices = img .shape [- 1 ]
285- ncols = facet_col_wrap
286- nrows = nslices / ncols
291+ if facet_col is not None :
292+ nslices = img .shape [facet_col ]
293+ ncols = int ( facet_col_wrap )
294+ nrows = nslices // ncols + 1 if nslices % ncols else nslices // ncols
287295 else :
288296 nrows = 1
289297 ncols = 1
290- fig = init_figure (args , 'xy' , [], nrows , ncols , [], [])
298+ fig = init_figure (args , "xy" , [], nrows , ncols , [], [])
291299 # ----- Define x and y, set labels if img is an xarray -------------------
292300 if xarray_imported and isinstance (img , xarray .DataArray ):
293301 if binary_string :
@@ -345,10 +353,16 @@ def imshow(
345353
346354 # --------------- Starting from here img is always a numpy array --------
347355 img = np .asanyarray (img )
356+ if facet_col is not None :
357+ img = np .moveaxis (img , facet_col , 0 )
358+ facet_col = True
348359
349360 # Default behaviour of binary_string: True for RGB images, False for 2D
350361 if binary_string is None :
351- binary_string = img .ndim >= 3 and not is_dataframe
362+ if facet_col :
363+ binary_string = img .ndim >= 4 and not is_dataframe
364+ else :
365+ binary_string = img .ndim >= 3 and not is_dataframe
352366
353367 # Cast bools to uint8 (also one byte)
354368 if img .dtype == np .bool :
@@ -377,7 +391,7 @@ def imshow(
377391 zmin = 0
378392
379393 # For 2d data, use Heatmap trace, unless binary_string is True
380- if img .ndim == 2 and not binary_string :
394+ if ( img .ndim == 2 or ( img . ndim == 3 and facet_col )) and not binary_string :
381395 if y is not None and img .shape [0 ] != len (y ):
382396 raise ValueError (
383397 "The length of the y vector must match the length of the first "
@@ -388,7 +402,13 @@ def imshow(
388402 "The length of the x vector must match the length of the second "
389403 + "dimension of the img matrix."
390404 )
391- trace = go .Heatmap (x = x , y = y , z = img , coloraxis = "coloraxis1" )
405+ if facet_col :
406+ traces = [
407+ go .Heatmap (x = x , y = y , z = img_slice , coloraxis = "coloraxis1" )
408+ for img_slice in img
409+ ]
410+ else :
411+ traces = [go .Heatmap (x = x , y = y , z = img , coloraxis = "coloraxis1" )]
392412 autorange = True if origin == "lower" else "reversed"
393413 layout = dict (yaxis = dict (autorange = autorange ))
394414 if aspect == "equal" :
@@ -407,7 +427,11 @@ def imshow(
407427 layout ["coloraxis1" ]["colorbar" ] = dict (title_text = labels ["color" ])
408428
409429 # For 2D+RGB data, use Image trace
410- elif img .ndim == 3 and img .shape [- 1 ] in [3 , 4 ] or (img .ndim == 2 and binary_string ):
430+ elif (
431+ img .ndim == 3
432+ and (img .shape [- 1 ] in [3 , 4 ] or (facet_col and binary_string ))
433+ or (img .ndim == 2 and binary_string )
434+ ):
411435 rescale_image = True # to check whether image has been modified
412436 if zmin is not None and zmax is not None :
413437 zmin , zmax = (
@@ -418,7 +442,7 @@ def imshow(
418442 if zmin is None and zmax is None : # no rescaling, faster
419443 img_rescaled = img
420444 rescale_image = False
421- elif img .ndim == 2 :
445+ elif img .ndim == 2 or ( img . ndim == 3 and facet_col ) :
422446 img_rescaled = rescale_intensity (
423447 img , in_range = (zmin [0 ], zmax [0 ]), out_range = np .uint8
424448 )
@@ -433,16 +457,30 @@ def imshow(
433457 for ch in range (img .shape [- 1 ])
434458 ]
435459 )
436- img_str = _array_to_b64str (
437- img_rescaled ,
438- backend = binary_backend ,
439- compression = binary_compression_level ,
440- ext = binary_format ,
441- )
442- trace = go .Image (source = img_str )
460+ if facet_col :
461+ img_str = [
462+ _array_to_b64str (
463+ img_rescaled_slice ,
464+ backend = binary_backend ,
465+ compression = binary_compression_level ,
466+ ext = binary_format ,
467+ )
468+ for img_rescaled_slice in img_rescaled
469+ ]
470+
471+ else :
472+ img_str = [
473+ _array_to_b64str (
474+ img_rescaled ,
475+ backend = binary_backend ,
476+ compression = binary_compression_level ,
477+ ext = binary_format ,
478+ )
479+ ]
480+ traces = [go .Image (source = img_str_slice ) for img_str_slice in img_str ]
443481 else :
444482 colormodel = "rgb" if img .shape [- 1 ] == 3 else "rgba256"
445- trace = go .Image (z = img , zmin = zmin , zmax = zmax , colormodel = colormodel )
483+ traces = [ go .Image (z = img , zmin = zmin , zmax = zmax , colormodel = colormodel )]
446484 layout = {}
447485 if origin == "lower" :
448486 layout ["yaxis" ] = dict (autorange = True )
@@ -460,7 +498,8 @@ def imshow(
460498 layout_patch ["title_text" ] = args ["title" ]
461499 elif args ["template" ].layout .margin .t is None :
462500 layout_patch ["margin" ] = {"t" : 60 }
463- fig .add_trace (trace )
501+ for index , trace in enumerate (traces ):
502+ fig .add_trace (trace , row = nrows - index // ncols , col = index % ncols + 1 )
464503 fig .update_layout (layout )
465504 fig .update_layout (layout_patch )
466505 # Hover name, z or color
0 commit comments