@@ -792,13 +792,40 @@ void runTest(int version, size_t M, size_t K, size_t N,
792792 }
793793
794794 // Allocate GPU buffers and copy data
795- Context ctx = createContext (
796- {}, {},
797- /* device descriptor, enabling f16 in WGSL*/
798- {
795+ WGPUDeviceDescriptor devDescriptor = {};
796+ devDescriptor.requiredFeatureCount = 1 ;
797+ devDescriptor.requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data ();
798+
799+ Context ctx;
800+ if (numtype == kf16) {
801+ ctx = createContext (
802+ {}, {},
803+ /* device descriptor, enabling f16 in WGSL*/
804+ {
799805 .requiredFeatureCount = 1 ,
800- .requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data (),
801- });
806+ .requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data ()
807+ });
808+ if (ctx.adapterStatus != WGPURequestAdapterStatus_Success) {
809+ LOG (kDefLog , kError , " Failed to create adapter with f16 support, try running an f32 test instead (`export MATMUL_VERSION=9)." );
810+ exit (1 );
811+ }
812+ if (ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
813+ LOG (kDefLog , kError , " Failed to create device with f16 support, try running an f32 test instead. (`export MATMUL_VERSION=9)" );
814+ exit (1 );
815+ }
816+ }
817+
818+ if (numtype == kf32) {
819+ ctx = createContext ({}, {}, {});
820+ if (ctx.adapterStatus != WGPURequestAdapterStatus_Success ||
821+ ctx.deviceStatus != WGPURequestDeviceStatus_Success) {
822+ LOG (kDefLog , kError , " Failed to create adapter or device" );
823+ // stop execution
824+ exit (1 );
825+ } else {
826+ LOG (kDefLog , kInfo , " Successfully created adapter and device" );
827+ }
828+ }
802829
803830 Tensor input = createTensor (ctx, Shape{M, K}, numtype, inputPtr.get ());
804831 Tensor weights = createTensor (ctx, Shape{N, K}, numtype, weightsPtr.get ()); // column-major
@@ -810,8 +837,6 @@ void runTest(int version, size_t M, size_t K, size_t N,
810837#endif
811838
812839 // Initialize Kernel and bind GPU buffers
813-
814-
815840 // pre-allocate for async dispatch
816841 std::array<std::promise<void >, nIter> promises;
817842 std::array<std::future<void >, nIter> futures;
@@ -823,10 +848,6 @@ void runTest(int version, size_t M, size_t K, size_t N,
823848 kernels[i] = selectMatmul (ctx, version, {input, weights, outputs[i]}, M, K, N, numtype);
824849 }
825850
826- #ifndef METAL_PROFILER
827- printf (" [ Press enter to start tests ... ]\n " );
828- getchar ();
829- #endif
830851 LOG (kDefLog , kInfo , " Dispatching Kernel version %d: %s, %d iterations ..." ,
831852 version, versionToStr (version).c_str (), nIter);
832853
0 commit comments