4747 torch ._C ._jit_set_profiling_executor (True )
4848 torch ._C ._jit_set_profiling_mode (False )
4949
50+ # models with forward_intermediates() and support for FeatureGetterNet features_only wrapper
51+ FEAT_INTER_FILTERS = [
52+ 'vit_*' , 'twins_*' , 'deit*' , 'beit*' , 'mvitv2*' , 'eva*' , 'samvit_*' , 'flexivit*'
53+ ]
54+
5055# transformer models don't support many of the spatial / feature based model functionalities
5156NON_STD_FILTERS = [
5257 'vit_*' , 'tnt_*' , 'pit_*' , 'coat_*' , 'cait_*' , '*mixer_*' , 'gmlp_*' , 'resmlp_*' , 'twins_*' ,
53- 'convit_*' , 'levit*' , 'visformer*' , 'deit*' , 'jx_nest_*' , 'nest_*' , ' xcit_*' , 'crossvit_*' , 'beit*' ,
54- 'poolformer_*' , 'volo_*' , 'sequencer2d_*' , 'pvt_v2*' , ' mvitv2*' , 'gcvit*' , 'efficientformer*' ,
58+ 'convit_*' , 'levit*' , 'visformer*' , 'deit*' , 'xcit_*' , 'crossvit_*' , 'beit*' ,
59+ 'poolformer_*' , 'volo_*' , 'sequencer2d_*' , 'mvitv2*' , 'gcvit*' , 'efficientformer*' ,
5560 'eva_*' , 'flexivit*' , 'eva02*' , 'samvit_*' , 'efficientvit_m*' , 'tiny_vit_*'
5661]
5762NUM_NON_STD = len (NON_STD_FILTERS )
@@ -351,15 +356,46 @@ def test_model_forward_torchscript(model_name, batch_size):
351356
352357@pytest .mark .features
353358@pytest .mark .timeout (120 )
354- @pytest .mark .parametrize ('model_name' , list_models (exclude_filters = EXCLUDE_FILTERS + EXCLUDE_FEAT_FILTERS , include_tags = True ))
359+ @pytest .mark .parametrize ('model_name' , list_models (exclude_filters = EXCLUDE_FILTERS + EXCLUDE_FEAT_FILTERS ))
355360@pytest .mark .parametrize ('batch_size' , [1 ])
356361def test_model_forward_features (model_name , batch_size ):
357362 """Run a single forward pass with each model in feature extraction mode"""
358363 model = create_model (model_name , pretrained = False , features_only = True )
359364 model .eval ()
360365 expected_channels = model .feature_info .channels ()
361366 expected_reduction = model .feature_info .reduction ()
362- assert len (expected_channels ) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6
367+ assert len (expected_channels ) >= 3 # all models here should have at least 3 default feat levels
368+
369+ input_size = _get_input_size (model = model , target = TARGET_FFEAT_SIZE )
370+ if max (input_size ) > MAX_FFEAT_SIZE :
371+ pytest .skip ("Fixed input size model > limit." )
372+ output_fmt = getattr (model , 'output_fmt' , 'NCHW' )
373+ feat_axis = get_channel_dim (output_fmt )
374+ spatial_axis = get_spatial_dim (output_fmt )
375+ import math
376+
377+ outputs = model (torch .randn ((batch_size , * input_size )))
378+ assert len (expected_channels ) == len (outputs )
379+ spatial_size = input_size [- 2 :]
380+ for e , r , o in zip (expected_channels , expected_reduction , outputs ):
381+ assert e == o .shape [feat_axis ]
382+ assert o .shape [spatial_axis [0 ]] <= math .ceil (spatial_size [0 ] / r ) + 1
383+ assert o .shape [spatial_axis [1 ]] <= math .ceil (spatial_size [1 ] / r ) + 1
384+ assert o .shape [0 ] == batch_size
385+ assert not torch .isnan (o ).any ()
386+
387+
388+ @pytest .mark .features
389+ @pytest .mark .timeout (120 )
390+ @pytest .mark .parametrize ('model_name' , list_models (FEAT_INTER_FILTERS , exclude_filters = EXCLUDE_FILTERS ))
391+ @pytest .mark .parametrize ('batch_size' , [1 ])
392+ def test_model_forward_intermediates_features (model_name , batch_size ):
393+ """Run a single forward pass with each model in feature extraction mode"""
394+ model = create_model (model_name , pretrained = False , features_only = True )
395+ model .eval ()
396+ print (model .feature_info .out_indices )
397+ expected_channels = model .feature_info .channels ()
398+ expected_reduction = model .feature_info .reduction ()
363399
364400 input_size = _get_input_size (model = model , target = TARGET_FFEAT_SIZE )
365401 if max (input_size ) > MAX_FFEAT_SIZE :
@@ -373,6 +409,41 @@ def test_model_forward_features(model_name, batch_size):
373409 assert len (expected_channels ) == len (outputs )
374410 spatial_size = input_size [- 2 :]
375411 for e , r , o in zip (expected_channels , expected_reduction , outputs ):
412+ print (o .shape )
413+ assert e == o .shape [feat_axis ]
414+ assert o .shape [spatial_axis [0 ]] <= math .ceil (spatial_size [0 ] / r ) + 1
415+ assert o .shape [spatial_axis [1 ]] <= math .ceil (spatial_size [1 ] / r ) + 1
416+ assert o .shape [0 ] == batch_size
417+ assert not torch .isnan (o ).any ()
418+
419+
420+ @pytest .mark .features
421+ @pytest .mark .timeout (120 )
422+ @pytest .mark .parametrize ('model_name' , list_models (FEAT_INTER_FILTERS , exclude_filters = EXCLUDE_FILTERS ))
423+ @pytest .mark .parametrize ('batch_size' , [1 ])
424+ def test_model_forward_intermediates (model_name , batch_size ):
425+ """Run a single forward pass with each model in feature extraction mode"""
426+ model = create_model (model_name , pretrained = False )
427+ model .eval ()
428+ feature_info = timm .models .FeatureInfo (model .feature_info , len (model .feature_info ))
429+ expected_channels = feature_info .channels ()
430+ expected_reduction = feature_info .reduction ()
431+ assert len (expected_channels ) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6
432+
433+ input_size = _get_input_size (model = model , target = TARGET_FFEAT_SIZE )
434+ if max (input_size ) > MAX_FFEAT_SIZE :
435+ pytest .skip ("Fixed input size model > limit." )
436+ output_fmt = getattr (model , 'output_fmt' , 'NCHW' )
437+ feat_axis = get_channel_dim (output_fmt )
438+ spatial_axis = get_spatial_dim (output_fmt )
439+ import math
440+
441+ output , intermediates = model .forward_intermediates (
442+ torch .randn ((batch_size , * input_size )),
443+ )
444+ assert len (expected_channels ) == len (intermediates )
445+ spatial_size = input_size [- 2 :]
446+ for e , r , o in zip (expected_channels , expected_reduction , intermediates ):
376447 assert e == o .shape [feat_axis ]
377448 assert o .shape [spatial_axis [0 ]] <= math .ceil (spatial_size [0 ] / r ) + 1
378449 assert o .shape [spatial_axis [1 ]] <= math .ceil (spatial_size [1 ] / r ) + 1
0 commit comments