@@ -434,7 +434,11 @@ def imshow(
434434 zmin = 0
435435
436436 # For 2d data, use Heatmap trace, unless binary_string is True
437- if (img .ndim == 2 or (img .ndim == 3 and slice_through )) and not binary_string :
437+ if (
438+ img .ndim == 2
439+ or (img .ndim == 3 and slice_through )
440+ or (img .ndim == 4 and double_slice_through )
441+ ) and not binary_string :
438442 y_index = 1 if slice_through else 0
439443 if y is not None and img .shape [y_index ] != len (y ):
440444 raise ValueError (
@@ -447,20 +451,16 @@ def imshow(
447451 "The length of the x vector must match the length of the second "
448452 + "dimension of the img matrix."
449453 )
454+ iterables = ()
450455 if slice_through :
451- iterables = ()
452456 if animation_frame is not None :
453457 iterables += (range (nslices_animation ),)
454458 if facet_col is not None :
455459 iterables += (range (nslices_facet ),)
456- traces = [
457- go .Heatmap (
458- x = x , y = y , z = img [index_tup ], coloraxis = "coloraxis1" , name = str (i )
459- )
460- for i , index_tup in enumerate (itertools .product (* iterables ))
461- ]
462- else :
463- traces = [go .Heatmap (x = x , y = y , z = img , coloraxis = "coloraxis1" )]
460+ traces = [
461+ go .Heatmap (x = x , y = y , z = img [index_tup ], coloraxis = "coloraxis1" , name = str (i ))
462+ for i , index_tup in enumerate (itertools .product (* iterables ))
463+ ]
464464 autorange = True if origin == "lower" else "reversed"
465465 layout = dict (yaxis = dict (autorange = autorange ))
466466 if aspect == "equal" :
@@ -488,8 +488,8 @@ def imshow(
488488 _vectorize_zvalue (zmin , mode = "min" ),
489489 _vectorize_zvalue (zmax , mode = "max" ),
490490 )
491+ iterables = ()
491492 if slice_through :
492- iterables = ()
493493 if animation_frame is not None :
494494 iterables += (range (nslices_animation ),)
495495 if facet_col is not None :
@@ -518,42 +518,26 @@ def imshow(
518518 ],
519519 axis = - 1 ,
520520 )
521- if slice_through :
522- tuples = [index_tup for index_tup in itertools .product (* iterables )]
523- img_str = [
524- _array_to_b64str (
525- img_rescaled [index_tup ],
526- backend = binary_backend ,
527- compression = binary_compression_level ,
528- ext = binary_format ,
529- )
530- for index_tup in itertools .product (* iterables )
531- ]
521+ img_str = [
522+ _array_to_b64str (
523+ img_rescaled [index_tup ],
524+ backend = binary_backend ,
525+ compression = binary_compression_level ,
526+ ext = binary_format ,
527+ )
528+ for index_tup in itertools .product (* iterables )
529+ ]
532530
533- else :
534- img_str = [
535- _array_to_b64str (
536- img_rescaled ,
537- backend = binary_backend ,
538- compression = binary_compression_level ,
539- ext = binary_format ,
540- )
541- ]
542531 traces = [
543532 go .Image (source = img_str_slice , name = str (i ))
544533 for i , img_str_slice in enumerate (img_str )
545534 ]
546535 else :
547536 colormodel = "rgb" if img .shape [- 1 ] == 3 else "rgba256"
548- if slice_through :
549- traces = [
550- go .Image (
551- z = img [index_tup ], zmin = zmin , zmax = zmax , colormodel = colormodel
552- )
553- for index_tup in itertools .product (* iterables )
554- ]
555- else :
556- traces = [go .Image (z = img , zmin = zmin , zmax = zmax , colormodel = colormodel )]
537+ traces = [
538+ go .Image (z = img [index_tup ], zmin = zmin , zmax = zmax , colormodel = colormodel )
539+ for index_tup in itertools .product (* iterables )
540+ ]
557541 layout = {}
558542 if origin == "lower" :
559543 layout ["yaxis" ] = dict (autorange = True )
0 commit comments