Skip to content

Commit 926923b

Browse files
joshistoast and lstein authored
feat(prompts): hotkey controlled prompt weighting (#8647)
* feat(prompts): add abstract syntax tree (AST) builder for prompts
* fix(prompts): add escaped parens to AST
* test(prompts): add AST tests
* fix(prompts): appease the linter
* perf(prompts): break up tokenize function into subroutines
* feat(prompts): add hotkey controlled prompt attention adjust
* fix(hotkeys): 🩹 add translations for hotkey dialog
* fix: 🏷️ remove unused exports
* fix(keybinds): 🐛 use `arrowup`/`arrowdown` over `up`/`down`
* refactor(prompts): ♻️ use better language for attention direction
* style: 🚨 appease the linter

---------

Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
1 parent 6f9f8e5 commit 926923b

File tree

7 files changed

+1267
-3
lines changed

7 files changed

+1267
-3
lines changed

invokeai/frontend/web/public/locales/en.json

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -501,6 +501,14 @@
501501
"title": "Next Prompt in History",
502502
"desc": "When the prompt is focused, move to the next (newer) prompt in your history."
503503
},
504+
"promptWeightUp": {
505+
"title": "Increase Weight of Prompt Selection",
506+
"desc": "When the prompt is focused and text is selected, increase the weight of the selected prompt."
507+
},
508+
"promptWeightDown": {
509+
"title": "Decrease Weight of Prompt Selection",
510+
"desc": "When the prompt is focused and text is selected, decrease the weight of the selected prompt."
511+
},
504512
"toggleLeftPanel": {
505513
"title": "Toggle Left Panel",
506514
"desc": "Show or hide the left panel."
Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
import { describe, expect, it } from 'vitest';
2+
3+
import { parseTokens, serialize, tokenize } from './promptAST';
4+
5+
describe('promptAST', () => {
6+
describe('tokenize', () => {
7+
it('should tokenize basic text', () => {
8+
const tokens = tokenize('a cat');
9+
expect(tokens).toEqual([
10+
{ type: 'word', value: 'a' },
11+
{ type: 'whitespace', value: ' ' },
12+
{ type: 'word', value: 'cat' },
13+
]);
14+
});
15+
16+
it('should tokenize groups with parentheses', () => {
17+
const tokens = tokenize('(a cat)');
18+
expect(tokens).toEqual([
19+
{ type: 'lparen' },
20+
{ type: 'word', value: 'a' },
21+
{ type: 'whitespace', value: ' ' },
22+
{ type: 'word', value: 'cat' },
23+
{ type: 'rparen' },
24+
]);
25+
});
26+
27+
it('should tokenize escaped parentheses', () => {
28+
const tokens = tokenize('\\(medium\\)');
29+
expect(tokens).toEqual([
30+
{ type: 'escaped_paren', value: '(' },
31+
{ type: 'word', value: 'medium' },
32+
{ type: 'escaped_paren', value: ')' },
33+
]);
34+
});
35+
36+
it('should tokenize mixed escaped and unescaped parentheses', () => {
37+
const tokens = tokenize('colored pencil \\(medium\\) (enhanced)');
38+
expect(tokens).toEqual([
39+
{ type: 'word', value: 'colored' },
40+
{ type: 'whitespace', value: ' ' },
41+
{ type: 'word', value: 'pencil' },
42+
{ type: 'whitespace', value: ' ' },
43+
{ type: 'escaped_paren', value: '(' },
44+
{ type: 'word', value: 'medium' },
45+
{ type: 'escaped_paren', value: ')' },
46+
{ type: 'whitespace', value: ' ' },
47+
{ type: 'lparen' },
48+
{ type: 'word', value: 'enhanced' },
49+
{ type: 'rparen' },
50+
]);
51+
});
52+
53+
it('should tokenize groups with weights', () => {
54+
const tokens = tokenize('(a cat)1.2');
55+
expect(tokens).toEqual([
56+
{ type: 'lparen' },
57+
{ type: 'word', value: 'a' },
58+
{ type: 'whitespace', value: ' ' },
59+
{ type: 'word', value: 'cat' },
60+
{ type: 'rparen' },
61+
{ type: 'weight', value: 1.2 },
62+
]);
63+
});
64+
65+
it('should tokenize words with weights', () => {
66+
const tokens = tokenize('cat+');
67+
expect(tokens).toEqual([
68+
{ type: 'word', value: 'cat' },
69+
{ type: 'weight', value: '+' },
70+
]);
71+
});
72+
73+
it('should tokenize embeddings', () => {
74+
const tokens = tokenize('<embedding_name>');
75+
expect(tokens).toEqual([{ type: 'lembed' }, { type: 'word', value: 'embedding_name' }, { type: 'rembed' }]);
76+
});
77+
});
78+
79+
describe('parseTokens', () => {
80+
it('should parse basic text', () => {
81+
const tokens = tokenize('a cat');
82+
const ast = parseTokens(tokens);
83+
expect(ast).toEqual([
84+
{ type: 'word', text: 'a' },
85+
{ type: 'whitespace', value: ' ' },
86+
{ type: 'word', text: 'cat' },
87+
]);
88+
});
89+
90+
it('should parse groups', () => {
91+
const tokens = tokenize('(a cat)');
92+
const ast = parseTokens(tokens);
93+
expect(ast).toEqual([
94+
{
95+
type: 'group',
96+
children: [
97+
{ type: 'word', text: 'a' },
98+
{ type: 'whitespace', value: ' ' },
99+
{ type: 'word', text: 'cat' },
100+
],
101+
},
102+
]);
103+
});
104+
105+
it('should parse escaped parentheses', () => {
106+
const tokens = tokenize('\\(medium\\)');
107+
const ast = parseTokens(tokens);
108+
expect(ast).toEqual([
109+
{ type: 'escaped_paren', value: '(' },
110+
{ type: 'word', text: 'medium' },
111+
{ type: 'escaped_paren', value: ')' },
112+
]);
113+
});
114+
115+
it('should parse mixed escaped and unescaped parentheses', () => {
116+
const tokens = tokenize('colored pencil \\(medium\\) (enhanced)');
117+
const ast = parseTokens(tokens);
118+
expect(ast).toEqual([
119+
{ type: 'word', text: 'colored' },
120+
{ type: 'whitespace', value: ' ' },
121+
{ type: 'word', text: 'pencil' },
122+
{ type: 'whitespace', value: ' ' },
123+
{ type: 'escaped_paren', value: '(' },
124+
{ type: 'word', text: 'medium' },
125+
{ type: 'escaped_paren', value: ')' },
126+
{ type: 'whitespace', value: ' ' },
127+
{
128+
type: 'group',
129+
children: [{ type: 'word', text: 'enhanced' }],
130+
},
131+
]);
132+
});
133+
134+
it('should parse groups with attention', () => {
135+
const tokens = tokenize('(a cat)1.2');
136+
const ast = parseTokens(tokens);
137+
expect(ast).toEqual([
138+
{
139+
type: 'group',
140+
attention: 1.2,
141+
children: [
142+
{ type: 'word', text: 'a' },
143+
{ type: 'whitespace', value: ' ' },
144+
{ type: 'word', text: 'cat' },
145+
],
146+
},
147+
]);
148+
});
149+
150+
it('should parse words with attention', () => {
151+
const tokens = tokenize('cat+');
152+
const ast = parseTokens(tokens);
153+
expect(ast).toEqual([{ type: 'word', text: 'cat', attention: '+' }]);
154+
});
155+
156+
it('should parse embeddings', () => {
157+
const tokens = tokenize('<embedding_name>');
158+
const ast = parseTokens(tokens);
159+
expect(ast).toEqual([{ type: 'embedding', value: 'embedding_name' }]);
160+
});
161+
});
162+
163+
describe('serialize', () => {
164+
it('should serialize basic text', () => {
165+
const tokens = tokenize('a cat');
166+
const ast = parseTokens(tokens);
167+
const result = serialize(ast);
168+
expect(result).toBe('a cat');
169+
});
170+
171+
it('should serialize groups', () => {
172+
const tokens = tokenize('(a cat)');
173+
const ast = parseTokens(tokens);
174+
const result = serialize(ast);
175+
expect(result).toBe('(a cat)');
176+
});
177+
178+
it('should serialize escaped parentheses', () => {
179+
const tokens = tokenize('\\(medium\\)');
180+
const ast = parseTokens(tokens);
181+
const result = serialize(ast);
182+
expect(result).toBe('\\(medium\\)');
183+
});
184+
185+
it('should serialize mixed escaped and unescaped parentheses', () => {
186+
const tokens = tokenize('colored pencil \\(medium\\) (enhanced)');
187+
const ast = parseTokens(tokens);
188+
const result = serialize(ast);
189+
expect(result).toBe('colored pencil \\(medium\\) (enhanced)');
190+
});
191+
192+
it('should serialize groups with attention', () => {
193+
const tokens = tokenize('(a cat)1.2');
194+
const ast = parseTokens(tokens);
195+
const result = serialize(ast);
196+
expect(result).toBe('(a cat)1.2');
197+
});
198+
199+
it('should serialize words with attention', () => {
200+
const tokens = tokenize('cat+');
201+
const ast = parseTokens(tokens);
202+
const result = serialize(ast);
203+
expect(result).toBe('cat+');
204+
});
205+
206+
it('should serialize embeddings', () => {
207+
const tokens = tokenize('<embedding_name>');
208+
const ast = parseTokens(tokens);
209+
const result = serialize(ast);
210+
expect(result).toBe('<embedding_name>');
211+
});
212+
});
213+
214+
describe('compel compatibility examples', () => {
215+
it('should handle escaped parentheses for literal text', () => {
216+
const prompt = 'A bear \\(with razor-sharp teeth\\) in a forest.';
217+
const tokens = tokenize(prompt);
218+
const ast = parseTokens(tokens);
219+
const result = serialize(ast);
220+
expect(result).toBe(prompt);
221+
});
222+
223+
it('should handle unescaped parentheses as grouping syntax', () => {
224+
const prompt = 'A bear (with razor-sharp teeth) in a forest.';
225+
const tokens = tokenize(prompt);
226+
const ast = parseTokens(tokens);
227+
const result = serialize(ast);
228+
expect(result).toBe(prompt);
229+
});
230+
231+
it('should handle colored pencil medium example', () => {
232+
const prompt = 'colored pencil \\(medium\\)';
233+
const tokens = tokenize(prompt);
234+
const ast = parseTokens(tokens);
235+
const result = serialize(ast);
236+
expect(result).toBe(prompt);
237+
});
238+
239+
it('should distinguish between escaped and unescaped in same prompt', () => {
240+
const prompt = 'portrait \\(realistic\\) (high quality)1.2';
241+
const tokens = tokenize(prompt);
242+
const ast = parseTokens(tokens);
243+
244+
// Should have escaped parens as nodes and a group with attention
245+
expect(ast).toEqual([
246+
{ type: 'word', text: 'portrait' },
247+
{ type: 'whitespace', value: ' ' },
248+
{ type: 'escaped_paren', value: '(' },
249+
{ type: 'word', text: 'realistic' },
250+
{ type: 'escaped_paren', value: ')' },
251+
{ type: 'whitespace', value: ' ' },
252+
{
253+
type: 'group',
254+
attention: 1.2,
255+
children: [
256+
{ type: 'word', text: 'high' },
257+
{ type: 'whitespace', value: ' ' },
258+
{ type: 'word', text: 'quality' },
259+
],
260+
},
261+
]);
262+
263+
const result = serialize(ast);
264+
expect(result).toBe(prompt);
265+
});
266+
});
267+
});

0 commit comments

Comments (0)