Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions packages/compass-aggregations/src/modules/data-service.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
import type { DataService as OriginalDataService } from 'mongodb-data-service';

type FetchCollectionMetadataDataServiceMethods =
| 'collectionStats'
| 'collectionInfo'
| 'listCollections'
| 'isListSearchIndexesSupported';

export type RequiredDataServiceProps =
| FetchCollectionMetadataDataServiceMethods
| 'isCancelError'
| 'estimatedCount'
| 'aggregate'
Expand Down
2 changes: 2 additions & 0 deletions packages/compass-aggregations/src/modules/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ import type {
ConnectionScopedAppRegistry,
} from '@mongodb-js/compass-connections/provider';
import type { TrackFunction } from '@mongodb-js/compass-telemetry';
import type Collection from 'mongodb-collection-model';
/**
* The main application reducer.
*
Expand Down Expand Up @@ -110,6 +111,7 @@ export type PipelineBuilderExtraArgs = {
connectionScopedAppRegistry: ConnectionScopedAppRegistry<
'open-export' | 'view-edited' | 'agg-pipeline-out-executed'
>;
collection: Collection;
};

export type PipelineBuilderThunkDispatch<A extends Action = AnyAction> =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ export const runAIPipelineGeneration = (
logger: { log, mongoLogId },
track,
connectionInfoRef,
collection,
}
) => {
const {
Expand Down Expand Up @@ -286,6 +287,9 @@ export const runAIPipelineGeneration = (
}
)) || [];
const schema = await getSimplifiedSchema(sampleDocuments);
const { isFLE } = await collection.fetchMetadata({
dataService: dataService!,
});

const { collection: collectionName, database: databaseName } =
toNS(namespace);
Expand All @@ -303,6 +307,7 @@ export const runAIPipelineGeneration = (
}
: undefined),
requestId,
enableStorage: !isFLE,
},
connectionInfo
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,11 @@ function createStore({
dataService: {} as any,
connectionInfoRef,
connectionScopedAppRegistry,
collection: {
fetchMetadata() {
return Promise.resolve({ isFLE: false });
},
} as any,
})
)
);
Expand Down
1 change: 1 addition & 0 deletions packages/compass-aggregations/src/stores/store.ts
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ export function activateAggregationsPlugin(
connectionInfoRef,
connectionScopedAppRegistry,
dataService,
collection: collectionModel,
})
)
);
Expand Down
1 change: 1 addition & 0 deletions packages/compass-aggregations/test/configure-store.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ function getMockedPluginArgs(
CompassAggregationsPlugin.provider.withMockServices({
atlasAiService,
collection: {
fetchMetadata: () => ({}),
toJSON: () => ({}),
on: () => {},
removeListener: () => {},
Expand Down
5 changes: 4 additions & 1 deletion packages/compass-crud/test/render-with-query-bar.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,14 @@ export const MockQueryBarPlugin: typeof QueryBarPlugin =
getConnectionString() {
return { hosts: [] } as any;
},
},
} as any,
instance: { on() {}, removeListener() {} } as any,
favoriteQueryStorageAccess: compassFavoriteQueryStorageAccess,
recentQueryStorageAccess: compassRecentQueryStorageAccess,
atlasAiService: {} as any,
collection: {
fetchMetadata: () => Promise.resolve({} as any),
} as any,
});

export const renderWithQueryBar = (
Expand Down
10 changes: 5 additions & 5 deletions packages/compass-e2e-tests/helpers/assistant-service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -156,16 +156,16 @@ export async function startMockAssistantServer(
getResponse: () => MockAssistantResponse;
setResponse: (response: MockAssistantResponse) => void;
getRequests: () => {
content: any;
req: any;
content: Record<string, any>;
req: http.IncomingMessage;
}[];
endpoint: string;
server: http.Server;
stop: () => Promise<void>;
}> {
let requests: {
content: any;
req: any;
content: Record<string, any>;
req: http.IncomingMessage;
}[] = [];
let response = _response;
const server = http
Expand All @@ -174,7 +174,7 @@ export async function startMockAssistantServer(
res.setHeader('Access-Control-Allow-Methods', 'POST, OPTIONS');
res.setHeader(
'Access-Control-Allow-Headers',
'Content-Type, Authorization, X-Request-Origin, User-Agent, X-CSRF-Token, X-CSRF-Time'
'Content-Type, Authorization, X-Request-Origin, User-Agent, X-CSRF-Token, X-CSRF-Time, Entrypoint, X-Client-Request-Id'
);
res.setHeader('Access-Control-Allow-Credentials', 'true');

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,10 +265,16 @@ describe('Collection ai query with chatbot (with mocked backend)', function () {
expect(requests.length).to.equal(1);

const queryRequest = requests[0];
expect(queryRequest.req.headers).to.have.property('x-client-request-id');
// TODO(COMPASS-10125): Switch the model to `mongodb-slim-latest` when
// enabling this feature.
expect(queryRequest.content.model).to.equal('mongodb-chat-latest');
expect(queryRequest.content.instructions).to.be.string;
expect(queryRequest.content.metadata).to.have.property('userId');
expect(queryRequest.content.metadata.store).to.have.equal('true');
expect(queryRequest.content.metadata.sensitiveStorage).to.have.equal(
'sensitive'
);
expect(queryRequest.content.input).to.be.an('array').of.length(1);

const message = queryRequest.content.input[0];
Expand Down
25 changes: 21 additions & 4 deletions packages/compass-generative-ai/src/atlas-ai-service.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ describe('AtlasAiService', function () {
{ _id: new ObjectId('642d766b7300158b1f22e972') },
],
requestId: 'abc',
enableStorage: false,
},
mockConnectionInfo
);
Expand All @@ -223,7 +224,7 @@ describe('AtlasAiService', function () {

expect(args[0]).to.eq(expectedEndpoints[aiEndpoint]);
expect(args[1].body).to.eq(
'{"userInput":"test","collectionName":"jam","databaseName":"peanut","schema":{"_id":{"types":[{"bsonType":"ObjectId"}]}},"sampleDocuments":[{"_id":{"$oid":"642d766b7300158b1f22e972"}}]}'
'{"userInput":"test","collectionName":"jam","databaseName":"peanut","schema":{"_id":{"types":[{"bsonType":"ObjectId"}]}},"sampleDocuments":[{"_id":{"$oid":"642d766b7300158b1f22e972"}}],"enableStorage":false}'
);
expect(res).to.deep.eq(responses.success);
});
Expand All @@ -241,6 +242,7 @@ describe('AtlasAiService', function () {
databaseName: 'peanut',
requestId: 'abc',
signal: new AbortController().signal,
enableStorage: false,
},
mockConnectionInfo
);
Expand All @@ -263,6 +265,7 @@ describe('AtlasAiService', function () {
sampleDocuments: [{ test: '4'.repeat(5120001) }],
requestId: 'abc',
signal: new AbortController().signal,
enableStorage: false,
},
mockConnectionInfo
);
Expand Down Expand Up @@ -294,6 +297,7 @@ describe('AtlasAiService', function () {
],
requestId: 'abc',
signal: new AbortController().signal,
enableStorage: false,
},
mockConnectionInfo
);
Expand All @@ -302,7 +306,7 @@ describe('AtlasAiService', function () {

expect(fetchStub).to.have.been.calledOnce;
expect(args[1].body).to.eq(
'{"userInput":"test","collectionName":"test.test","databaseName":"peanut","sampleDocuments":[{"a":"1"}]}'
'{"userInput":"test","collectionName":"test.test","databaseName":"peanut","sampleDocuments":[{"a":"1"}],"enableStorage":false}'
);
});
});
Expand Down Expand Up @@ -912,6 +916,7 @@ describe('AtlasAiService', function () {
const mockAtlasService = new MockAtlasService();
await preferences.savePreferences({
enableChatbotEndpointForGenAI: true,
telemetryAtlasUserId: '1234',
});
atlasAiService = new AtlasAiService({
apiURLPreset: 'cloud',
Expand Down Expand Up @@ -1037,6 +1042,7 @@ describe('AtlasAiService', function () {
{ _id: new ObjectId('642d766b7300158b1f22e972') },
],
requestId: 'abc',
enableStorage: true,
};

const res = await atlasAiService[functionName](
Expand All @@ -1047,10 +1053,20 @@ describe('AtlasAiService', function () {
expect(fetchStub).to.have.been.calledOnce;

const { args } = fetchStub.firstCall;
const requestBody = JSON.parse(args[1].body as string);

const requestHeaders = args[1].headers as Record<string, string>;
expect(requestHeaders['x-client-request-id']).to.equal(
input.requestId
);

const requestBody = JSON.parse(args[1].body as string);
expect(requestBody.model).to.equal('mongodb-chat-latest');
expect(requestBody.store).to.equal(false);
const { userId, ...restOfMetadata } = requestBody.metadata;
expect(restOfMetadata).to.deep.equal({
store: 'true',
sensitiveStorage: 'sensitive',
});
expect(userId).to.be.a('string').that.is.not.empty;
expect(requestBody.instructions).to.be.a('string');
expect(requestBody.input).to.be.an('array');

Expand Down Expand Up @@ -1083,6 +1099,7 @@ describe('AtlasAiService', function () {
databaseName: 'peanut',
requestId: 'abc',
signal: new AbortController().signal,
enableStorage: false,
},
mockConnectionInfo
);
Expand Down
50 changes: 47 additions & 3 deletions packages/compass-generative-ai/src/atlas-ai-service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type GenerativeAiInput = {
sampleDocuments?: Document[];
signal: AbortSignal;
requestId: string;
enableStorage: boolean;
};

// The size/token validation happens on the server, however, we do
Expand Down Expand Up @@ -259,6 +260,37 @@ export type MockDataSchemaResponse = z.infer<
typeof MockDataSchemaResponseShape
>;

async function getHashedActiveUserId(
preferences: PreferencesAccess,
logger: Logger
): Promise<string> {
const { currentUserId, telemetryAnonymousId, telemetryAtlasUserId } =
preferences.getPreferences();
const userId = currentUserId ?? telemetryAnonymousId ?? telemetryAtlasUserId;
if (!userId) {
return 'unknown';
}
try {
const data = new TextEncoder().encode(userId);
const hashBuffer = await crypto.subtle.digest('SHA-256', data);
const hashArray = Array.from(new Uint8Array(hashBuffer));
const hashHex = hashArray
.map((b) => b.toString(16).padStart(2, '0'))
.join('');
return hashHex;
} catch (e) {
logger.log.warn(
logger.mongoLogId(1_001_000_385),
'AtlasAiService',
'Failed to hash user id for AI request',
{
error: (e as Error).message,
}
);
return 'unknown';
}
}

/**
* The type of resource from the natural language query REST API
*/
Expand Down Expand Up @@ -304,7 +336,13 @@ export class AtlasAiService {
PLACEHOLDER_BASE_URL,
this.atlasService.assistantApiEndpoint()
);
return this.atlasService.authenticatedFetch(uri, init);
return this.atlasService.authenticatedFetch(uri, {
...init,
headers: {
...(init?.headers ?? {}),
entrypoint: 'natural-language-to-mql',
},
});
},
// TODO(COMPASS-10125): Switch the model to `mongodb-slim-latest` when
// enabling this feature (to use edu-chatbot for GenAI).
Expand Down Expand Up @@ -445,7 +483,10 @@ export class AtlasAiService {
connectionInfo: ConnectionInfo
) {
if (this.preferences.getPreferences().enableChatbotEndpointForGenAI) {
const message = buildAggregateQueryPrompt(input);
const message = buildAggregateQueryPrompt({
...input,
userId: await getHashedActiveUserId(this.preferences, this.logger),
});
return this.generateQueryUsingChatbot(
message,
validateAIAggregationResponse,
Expand All @@ -467,7 +508,10 @@ export class AtlasAiService {
connectionInfo: ConnectionInfo
) {
if (this.preferences.getPreferences().enableChatbotEndpointForGenAI) {
const message = buildFindQueryPrompt(input);
const message = buildFindQueryPrompt({
...input,
userId: await getHashedActiveUserId(this.preferences, this.logger),
});
return this.generateQueryUsingChatbot(message, validateAIQueryResponse, {
signal: input.signal,
type: 'find',
Expand Down
Loading
Loading