-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathgemma3.h
More file actions
366 lines (317 loc) · 12.2 KB
/
gemma3.h
File metadata and controls
366 lines (317 loc) · 12.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
/*
* gemma3.h - Public API for Gemma 3 4B inference in pure C
*
* This library provides zero-dependency inference for Google's Gemma 3 4B IT model.
* It loads weights directly from SafeTensors format and performs text generation.
*
* Example usage:
* gemma3_ctx *ctx = gemma3_load_dir("./gemma-3-4b-it");
* if (!ctx) {
* fprintf(stderr, "Failed to load model\n");
* return 1;
* }
*
* gemma3_gen_params params = gemma3_default_params();
* char *response = gemma3_generate(ctx, "Hello, world!", &params, NULL, NULL);
* if (response) {
* printf("%s\n", response);
* free(response);
* }
*
* gemma3_free(ctx);
*/
#ifndef GEMMA3_H
#define GEMMA3_H
#include <stddef.h>
#include <stdint.h>
#ifdef __cplusplus
extern "C" {
#endif
/* ============================================================================
 * Model Configuration Constants (Gemma 3 4B)
 * ========================================================================== */

/* Vocabulary and stack dimensions */
#define GEMMA3_VOCAB_SIZE        262208
#define GEMMA3_HIDDEN_SIZE       2560
#define GEMMA3_INTERMEDIATE_SIZE 10240
#define GEMMA3_NUM_LAYERS        34

/* Attention geometry: attention heads, key/value heads, per-head dimension */
#define GEMMA3_NUM_HEADS         8
#define GEMMA3_NUM_KV_HEADS      4
#define GEMMA3_HEAD_DIM          256

/* Context length and the local/global attention pattern */
#define GEMMA3_MAX_CONTEXT       131072      /* 128K tokens */
#define GEMMA3_SLIDING_WINDOW    1024
#define GEMMA3_LOCAL_RATIO       5           /* 5 local : 1 global */

/* Numerics: RMSNorm epsilon and RoPE base frequencies (local / global) */
#define GEMMA3_RMSNORM_EPS       1e-6f
#define GEMMA3_ROPE_THETA_LOCAL  10000.0f
#define GEMMA3_ROPE_THETA_GLOBAL 1000000.0f

/* Default context size for memory allocation */
#define GEMMA3_DEFAULT_CONTEXT   8192
/* ============================================================================
 * Error Codes
 *
 * GEMMA3_OK is zero; every failure is a distinct negative value, so callers
 * can test `rc < 0` for "any error".
 * ========================================================================== */
typedef enum {
    GEMMA3_OK                    =  0,
    GEMMA3_ERR_INVALID_ARG       = -1,
    GEMMA3_ERR_FILE_NOT_FOUND    = -2,
    GEMMA3_ERR_INVALID_FORMAT    = -3,
    GEMMA3_ERR_OUT_OF_MEMORY     = -4,
    GEMMA3_ERR_MMAP_FAILED       = -5,
    GEMMA3_ERR_TOKENIZER_FAILED  = -6,
    GEMMA3_ERR_GENERATION_FAILED = -7,
    GEMMA3_ERR_CONTEXT_OVERFLOW  = -8
} gemma3_error;
/* ============================================================================
 * Forward Declarations
 *
 * Opaque handle types. Only the typedefs appear in this header; the struct
 * bodies are not visible to callers, who use these exclusively through
 * pointers passed to the functions declared below.
 * ========================================================================== */
typedef struct gemma3_ctx gemma3_ctx;
typedef struct gemma3_tokenizer gemma3_tokenizer;
typedef struct gemma3_weights gemma3_weights;
typedef struct gemma3_kv_cache gemma3_kv_cache;
/* ============================================================================
 * Model Configuration
 *
 * Hyper-parameters of a loaded model, exposed read-only through
 * gemma3_get_config(). Each field corresponds to the like-named GEMMA3_*
 * constant; a loaded context presumably fills them from config.json or from
 * those defaults when it is missing (see gemma3_load_dir) - confirm.
 * Field order is part of the struct layout; do not reorder.
 * ========================================================================== */
typedef struct {
int vocab_size;          /* cf. GEMMA3_VOCAB_SIZE */
int hidden_size;         /* cf. GEMMA3_HIDDEN_SIZE */
int intermediate_size;   /* cf. GEMMA3_INTERMEDIATE_SIZE */
int num_layers;          /* cf. GEMMA3_NUM_LAYERS */
int num_heads;           /* cf. GEMMA3_NUM_HEADS */
int num_kv_heads;        /* cf. GEMMA3_NUM_KV_HEADS */
int head_dim;            /* cf. GEMMA3_HEAD_DIM */
int max_context;         /* NOTE(review): likely the budget chosen at load time
                            (gemma3_load_dir_ex's max_context) rather than
                            GEMMA3_MAX_CONTEXT - confirm */
int sliding_window;      /* window of local (non-global) layers; cf. GEMMA3_SLIDING_WINDOW */
float rmsnorm_eps;       /* cf. GEMMA3_RMSNORM_EPS */
float rope_theta_local;  /* RoPE base for local layers; cf. GEMMA3_ROPE_THETA_LOCAL */
float rope_theta_global; /* RoPE base for global layers; cf. GEMMA3_ROPE_THETA_GLOBAL */
} gemma3_config;
/* ============================================================================
 * Generation Parameters
 *
 * Decoding options passed (by pointer) to gemma3_generate(),
 * gemma3_generate_tokens() and gemma3_chat(). Start from
 * gemma3_default_params() and override individual fields.
 * Field order is part of the struct layout; do not reorder.
 * ========================================================================== */
typedef struct {
int max_tokens; /* Maximum number of tokens to generate */
float temperature; /* Sampling temperature (0 = greedy) */
int top_k; /* Top-k sampling (0 = disabled) */
float top_p; /* Top-p (nucleus) sampling (1.0 = disabled) */
int seed; /* Random seed (-1 for random) */
int stop_on_eos; /* Flag: stop when the EOS token is generated */
int greedy; /* Flag: force greedy decoding (overrides temperature) */
int verbose_tokens; /* Flag: print token IDs during generation */
} gemma3_gen_params;
/**
 * Get default generation parameters
 * Default: max_tokens=512, temperature=0.7, top_k=50, top_p=0.9
 * @return A parameter struct by value; modify fields as needed and pass its
 *         address to the generation functions (see the file-header example).
 */
gemma3_gen_params gemma3_default_params(void);
/* ============================================================================
 * Token Callback
 * ========================================================================== */
/**
 * Streaming callback invoked for each generated token
 * @param token_id  The token ID that was generated
 * @param token_str The decoded string for this token (may be partial UTF-8,
 *                  i.e. one multi-byte character can span several calls)
 * @param user_data The user_data pointer given to the generate/chat call
 * @return 0 to continue generation, non-zero to stop early
 *
 * NOTE(review): the lifetime of token_str is not specified here; copy it if it
 * must outlive the call - confirm against the implementation.
 */
typedef int (*gemma3_token_callback)(int token_id, const char *token_str,
void *user_data);
/* ============================================================================
 * Model Loading and Context
 * ========================================================================== */
/**
 * Load Gemma 3 model from a HuggingFace model directory
 * @param model_dir Path to directory containing:
 *        - model.safetensors or model-00001-of-00002.safetensors, etc.
 *        - tokenizer.model (SentencePiece)
 *        - config.json (optional, uses defaults if missing)
 * @return Context pointer on success, NULL on failure; release with
 *         gemma3_free(). On failure gemma3_get_error() presumably describes
 *         the cause - confirm.
 * NOTE(review): the context budget used here is not stated; likely
 * GEMMA3_DEFAULT_CONTEXT - confirm. gemma3_load_dir_ex() lets you choose it.
 */
gemma3_ctx *gemma3_load_dir(const char *model_dir);
/**
 * Load model with an explicit context-length budget
 * @param model_dir   Path to model directory (same layout as gemma3_load_dir())
 * @param max_context Maximum context length to support (affects memory usage)
 * @return Context pointer on success, NULL on failure; release with gemma3_free()
 */
gemma3_ctx *gemma3_load_dir_ex(const char *model_dir, int max_context);
/**
 * Free all resources associated with a context
 * @param ctx Context from gemma3_load_dir() / gemma3_load_dir_ex()
 * NOTE(review): whether NULL is accepted is not documented here - confirm.
 */
void gemma3_free(gemma3_ctx *ctx);
/**
 * Get the last error message (thread-local)
 * @return Message string; presumably library-owned and not to be freed - confirm.
 */
const char *gemma3_get_error(void);
/**
 * Get the model configuration
 * @param ctx Loaded context
 * @return Read-only pointer to the context's configuration
 */
const gemma3_config *gemma3_get_config(const gemma3_ctx *ctx);
/**
 * Get the tokenizer from a context
 * @param ctx Loaded context
 * @return Tokenizer handle for the tokenization functions below. No separate
 *         free function is declared for it; it is presumably owned by ctx and
 *         released by gemma3_free() - confirm.
 */
gemma3_tokenizer *gemma3_get_tokenizer(gemma3_ctx *ctx);
/* ============================================================================
 * Tokenization
 * ========================================================================== */
/**
 * Encode text to token IDs
 * @param tok        Tokenizer from gemma3_get_tokenizer()
 * @param text       Input text (UTF-8)
 * @param tokens     Output array for token IDs (at least max_tokens entries)
 * @param max_tokens Maximum number of tokens to output
 * @param add_bos    Non-zero to add the beginning-of-sequence token
 * @param add_eos    Non-zero to add the end-of-sequence token
 * @return Number of tokens written, or negative error code (presumably a
 *         gemma3_error value)
 * NOTE(review): behavior when text needs more than max_tokens (truncation vs.
 * error) is not specified here - confirm.
 */
int gemma3_tokenize(gemma3_tokenizer *tok, const char *text,
int *tokens, int max_tokens, int add_bos, int add_eos);
/**
 * Decode token IDs to text
 * @param tok        Tokenizer from gemma3_get_tokenizer()
 * @param tokens     Array of token IDs
 * @param num_tokens Number of tokens
 * @return Decoded string (caller must free, presumably with free() as in the
 *         file-header example), or NULL on error
 */
char *gemma3_detokenize(gemma3_tokenizer *tok, const int *tokens, int num_tokens);
/**
 * Decode a single token ID to text
 * @param tok      Tokenizer
 * @param token_id Token ID to decode
 * @return Token string (pointer to internal storage, do not free), or NULL
 */
const char *gemma3_decode_token(gemma3_tokenizer *tok, int token_id);
/**
 * Get special token IDs
 * Each returns the ID of the named special token for this tokenizer
 * (beginning/end of sequence, padding, and the chat turn delimiters used by
 * the chat template).
 */
int gemma3_bos_token(gemma3_tokenizer *tok);
int gemma3_eos_token(gemma3_tokenizer *tok);
int gemma3_pad_token(gemma3_tokenizer *tok);
int gemma3_end_turn_token(gemma3_tokenizer *tok);
int gemma3_start_turn_token(gemma3_tokenizer *tok);
/* ============================================================================
 * Text Generation
 * ========================================================================== */
/**
 * Generate text from a prompt
 * @param ctx       Model context
 * @param prompt    Input prompt (raw text, will be tokenized)
 * @param params    Generation parameters (NULL for defaults)
 * @param callback  Optional callback for streaming output (may be NULL, as in
 *                  the file-header example)
 * @param user_data User data passed to callback
 * @return Generated text (caller must free with free(), as in the file-header
 *         example), or NULL on error
 */
char *gemma3_generate(gemma3_ctx *ctx, const char *prompt,
gemma3_gen_params *params,
gemma3_token_callback callback, void *user_data);
/**
 * Generate text with pre-tokenized input
 * @param ctx        Model context
 * @param tokens     Input token IDs
 * @param num_tokens Number of input tokens
 * @param params     Generation parameters (NULL handling not documented here;
 *                   presumably the same as gemma3_generate() - confirm)
 * @param callback   Optional callback for streaming
 * @param user_data  User data for callback
 * @return Generated text (caller must free), or NULL on error
 */
char *gemma3_generate_tokens(gemma3_ctx *ctx, const int *tokens, int num_tokens,
gemma3_gen_params *params,
gemma3_token_callback callback, void *user_data);
/* ============================================================================
 * Chat Interface
 * ========================================================================== */

/**
 * Speaker of a chat message. The numeric values are spelled out so they stay
 * stable even if new roles are appended later.
 */
typedef enum {
    GEMMA3_ROLE_USER   = 0,  /* end-user turn */
    GEMMA3_ROLE_MODEL  = 1,  /* assistant (model) turn */
    GEMMA3_ROLE_SYSTEM = 2   /* system instructions */
} gemma3_role;

/**
 * A single chat turn: who is speaking and the text of the turn.
 * NOTE(review): content is a const pointer, presumably borrowed (the caller
 * keeps ownership and it must stay valid for the call) - confirm.
 */
typedef struct {
    gemma3_role role;     /* speaker of this turn */
    const char *content;  /* NUL-terminated message text */
} gemma3_message;
/**
 * Generate chat completion using Gemma 3 chat template
 * @param ctx       Model context
 * @param messages  Array of chat messages (num_msgs entries)
 * @param num_msgs  Number of messages
 * @param params    Generation parameters
 * @param callback  Optional callback for streaming
 * @param user_data User data for callback
 * @return Generated response (caller must free), or NULL on error
 */
char *gemma3_chat(gemma3_ctx *ctx, const gemma3_message *messages, int num_msgs,
gemma3_gen_params *params,
gemma3_token_callback callback, void *user_data);
/**
 * Format messages with Gemma 3 chat template (no generation is performed)
 * @param tok      Tokenizer
 * @param messages Array of chat messages (num_msgs entries)
 * @param num_msgs Number of messages
 * @return Formatted prompt string (caller must free), or NULL on error
 */
char *gemma3_format_chat(gemma3_tokenizer *tok, const gemma3_message *messages,
int num_msgs);
/* ============================================================================
 * KV Cache Management
 * ========================================================================== */
/**
 * Reset the KV cache (start fresh generation)
 * @param ctx Model context
 */
void gemma3_reset_cache(gemma3_ctx *ctx);
/**
 * Get current cache position (number of tokens processed)
 * @param ctx Model context
 * @return Number of tokens processed since the last reset
 */
int gemma3_get_cache_position(gemma3_ctx *ctx);
/* ============================================================================
 * Low-Level Forward Pass (Advanced Use)
 * ========================================================================== */
/**
 * Run forward pass for a single token
 * @param ctx      Model context
 * @param token_id Input token ID
 * @param pos      Position in sequence
 * @param logits   Output logits array [vocab_size] (must be pre-allocated; size
 *                 it from gemma3_get_config(ctx)->vocab_size)
 * @return 0 (GEMMA3_OK) on success, negative error code on failure
 * NOTE(review): presumably pos must stay below the context budget, otherwise
 * GEMMA3_ERR_CONTEXT_OVERFLOW - confirm.
 */
int gemma3_forward(gemma3_ctx *ctx, int token_id, int pos, float *logits);
/**
 * Run forward pass for multiple tokens (prefill)
 * @param ctx        Model context
 * @param tokens     Input token IDs
 * @param num_tokens Number of input tokens
 * @param start_pos  Starting position in sequence
 * @param logits     Output logits for last token [vocab_size]
 * @return 0 (GEMMA3_OK) on success, negative error code on failure
 */
int gemma3_forward_batch(gemma3_ctx *ctx, const int *tokens, int num_tokens,
int start_pos, float *logits);
/* ============================================================================
 * Utility Functions
 * ========================================================================== */
/**
 * Get library version string
 * @return Version string; presumably static storage, do not free - confirm.
 */
const char *gemma3_version(void);
/**
 * Check if a layer uses global attention
 *
 * Gemma 3 interleaves layers 5 local : 1 global (GEMMA3_LOCAL_RATIO), so every
 * 6th layer is global: 5, 11, 17, 23, 29 (0-indexed).
 *
 * @param layer_idx Layer index (0-based)
 * @return 1 if global attention, 0 if local (sliding window); negative indices
 *         are never treated as global
 */
static inline int gemma3_is_global_layer(int layer_idx) {
    /* Period 6 == GEMMA3_LOCAL_RATIO + 1 local+global group.
     * `idx % 6 == 5` is equivalent to `(idx + 1) % 6 == 0` for valid indices
     * but cannot overflow at INT_MAX, and the explicit >= 0 guard stops -1
     * from being misreported as a global layer. */
    return layer_idx >= 0 && layer_idx % 6 == 5;
}
/**
 * Select the RoPE base frequency (theta) for a layer.
 * @param layer_idx Layer index (0-based)
 * @return GEMMA3_ROPE_THETA_GLOBAL (1M) for layers that
 *         gemma3_is_global_layer() reports as global, otherwise
 *         GEMMA3_ROPE_THETA_LOCAL (10K)
 */
static inline float gemma3_layer_rope_theta(int layer_idx) {
    if (gemma3_is_global_layer(layer_idx)) {
        return GEMMA3_ROPE_THETA_GLOBAL;
    }
    return GEMMA3_ROPE_THETA_LOCAL;
}
#ifdef __cplusplus
}
#endif
#endif /* GEMMA3_H */