Skip to content

Commit 98ec943

Browse files
committed
llm.c kernels + wasm targeting (WIP)
1 parent 2693fc7 commit 98ec943

File tree

5 files changed

+212
-190
lines changed

5 files changed

+212
-190
lines changed

experimental/kernels/Makefile

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,11 @@ else
1111
STDLIB := -stdlib=libc++
1212
endif
1313

14+
# ASYNCIFY lets Emscripten-compiled code pause and resume (e.g. emscripten_sleep for async GPU readbacks)
15+
# EMFLAGS=-std=c++17 -I$(GPUCPP) -I$(GPUCPP)/third_party/headers/wasm -I. -Iunittest_llmc -Illm.c -s USE_WEBGPU=1 -s -s STACK_SIZE=100000 -s MEMORY64=1 -s ALLOW_MEMORY_GROWTH=1
1416
EMFLAGS=-std=c++17 -I$(GPUCPP) -I$(GPUCPP)/third_party/headers/wasm -I. -Iunittest_llmc -Illm.c -s USE_WEBGPU=1 -s ASYNCIFY=1 -s STACK_SIZE=100000 -s MEMORY64=1 -s ALLOW_MEMORY_GROWTH=1
1517
CXXFLAGS=-std=c++17 -I$(GPUCPP) -I$(GPUCPP)/third_party/headers -I. -Iunittest_llmc
18+
CXXFLAGS=-std=c++17 -I$(GPUCPP) -I$(GPUCPP)/third_party/headers -I. -Iunittest_llmc
1619
CFLAGS=-Ofast -march=native -I. -Iunittest_llmc
1720

1821
LDFLAGS=$(STDLIB) -L$(GPUCPP)/third_party/lib -ldl -ldawn
@@ -121,7 +124,10 @@ build/gpt2_gpucpp.html: check-emsdk run.cpp term.html build/train_gpt2
121124
$(EMFLAGS) \
122125
--shell-file term.html \
123126

124-
server: build/train_gpt2.html build/test_gpt2.html
127+
watch:
128+
ls *.cpp *.c *.hpp *.h | entr -c make build/gpt2_gpucpp.html
129+
130+
server: build/train_gpt2.html build/test_gpt2.html build/gpt2_gpucpp.html
125131
@echo "\n┌───────────────────────────────────────────────────────────────────────────────────┐"
126132
@echo "│ Open http://localhost:8000/build/run.html in your browser to see the output. │"
127133
@echo "│ │"

experimental/kernels/gpt2_wasm.c

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1+
#include "gpu.hpp"
2+
#ifdef __EMSCRIPTEN__
3+
#include "unittest_kernels.h" // TODO: replace with ops.hpp once we figure out how to make the GPU context persist under wasm
4+
#else
15
#include "ops.hpp"
6+
#endif
27
/*
38
This file trains the GPT-2 model.
49
This version is the clean, minimal, reference. As such:
@@ -18,6 +23,7 @@ There will be other versions of this code that specialize it and make it fast.
1823
#include <time.h>
1924
#include <string.h>
2025
#include <unistd.h>
26+
#include <memory>
2127
#ifdef OMP
2228
#include <omp.h>
2329
#endif
@@ -722,8 +728,11 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {
722728
size_t maxT, V, Vp, L, NH, C; // size_t to prevent int overflow
723729
model->config.max_seq_len = maxT = model_header[2];
724730
model->config.vocab_size = V = model_header[3];
725-
// model->config.num_layers = L = model_header[4];
726-
model->config.num_layers = L = 3; // TODO(avh): Debugging only hack - revert this
731+
#ifdef __EMSCRIPTEN__
732+
model->config.num_layers = L = 12; // TODO(avh): Debugging only hack - revert this
733+
#else
734+
model->config.num_layers = L = model_header[4];
735+
#endif
727736
model->config.num_heads = NH = model_header[5];
728737
model->config.channels = C = model_header[6];
729738
model->config.padded_vocab_size = Vp = model_header[7];
@@ -827,6 +836,7 @@ void gpt2_forward(GPT2 *model, int* inputs, int* targets, size_t B, size_t T) {
827836
ParameterTensors params = model->params; // for brevity
828837
ActivationTensors acts = model->acts;
829838
float* residual;
839+
printf("Encoding\n");
830840
encoder_forward(acts.encoded, inputs, params.wte, params.wpe, B, T, C); // encoding goes into residual[0]
831841
for (int l = 0; l < L; l++) {
832842
printf("Forward Pass Layer %d\n", l);
@@ -1106,7 +1116,6 @@ int sample_mult(float* probabilities, int n, float coin) {
11061116
// ----------------------------------------------------------------------------
11071117
// main training loop
11081118
int main() {
1109-
initRuntime();
11101119

11111120
// build the GPT-2 model from a checkpoint
11121121
GPT2 model;
@@ -1137,9 +1146,22 @@ int main() {
11371146
int* gen_tokens = (int*)mallocCheck(B * T * sizeof(int));
11381147
const int genT = 64; // number of steps of inference we will do
11391148

1149+
#ifdef __EMSCRIPTEN__
1150+
#else
1151+
printf("Creating GPU context\n");
1152+
WGPURequiredLimits requiredLimits = LIMITS_BUFFER_SIZE_1GB;
1153+
kCtx = static_cast<gpu::Context*>(mallocCheck(sizeof(gpu::Context) * 32));
1154+
*kCtx = gpu::createContext({}, {}, {
1155+
.requiredLimits = &requiredLimits
1156+
});
1157+
printf("GPU context created\n");
1158+
#endif
1159+
11401160
// train
11411161
struct timespec start, end;
1162+
printf("Starting training\n");
11421163
for (int step = 0; step <= 40; step++) {
1164+
printf("Step %d\n", step);
11431165

11441166
// once in a while estimate the validation loss
11451167
if (step % 10 == 0) {

0 commit comments

Comments
 (0)