Finalize ChatRequest#model (#231811)

* Finalize ChatRequest#model
Fix #230844

* Register model to fix tests
Rob Lourens 2024-10-21 16:42:23 -07:00 committed by GitHub
parent fb42702b12
commit d9f379b135
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 57 additions and 13 deletions
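For context, a minimal sketch of what the finalized API enables inside a chat participant handler; the participant id and prompt handling here are illustrative, not part of this commit:

    import * as vscode from 'vscode';

    // The handler receives the model currently selected in the chat UI via request.model.
    const participant = vscode.chat.createChatParticipant('example.assistant', async (request, _context, stream, token) => {
        const messages = [vscode.LanguageModelChatMessage.User(request.prompt)];
        // Forward the prompt to the UI-selected model and stream the reply into the chat response.
        const response = await request.model.sendRequest(messages, {}, token);
        for await (const chunk of response.text) {
            stream.markdown(chunk);
        }
    });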

View File

@@ -5,7 +5,7 @@
import * as assert from 'assert';
import 'mocha';
import { ChatContext, ChatRequest, ChatResult, ChatVariableLevel, Disposable, Event, EventEmitter, chat, commands } from 'vscode';
import { ChatContext, ChatRequest, ChatResult, ChatVariableLevel, Disposable, Event, EventEmitter, chat, commands, lm } from 'vscode';
import { DeferredPromise, asPromise, assertNoRpc, closeAllEditors, delay, disposeAll } from '../utils';
suite('chat', () => {
@@ -13,6 +13,25 @@ suite('chat', () => {
let disposables: Disposable[] = [];
setup(() => {
disposables = [];
// Register a dummy default model which is required for a participant request to go through
disposables.push(lm.registerChatModelProvider('test-lm', {
async provideLanguageModelResponse(_messages, _options, _extensionId, _progress, _token) {
return undefined;
},
async provideTokenCount(_text, _token) {
return 1;
},
}, {
name: 'test-lm',
version: '1.0.0',
family: 'test',
vendor: 'test-lm-vendor',
maxInputTokens: 100,
maxOutputTokens: 100,
isDefault: true,
isUserSelectable: true
}));
});
teardown(async function () {

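Because a participant request now always resolves a model, these tests fail without at least one registered default; the dummy provider above fills that role. A handler in this suite could then check that the resolved model matches the registered metadata (hypothetical test code, reusing the suite's existing imports):

    const participant = chat.createChatParticipant('api-test.participant', async (request, _context, _stream, _token) => {
        // request.model should be backed by the dummy default model registered in setup().
        assert.strictEqual(request.model.vendor, 'test-lm-vendor');
        assert.strictEqual(request.model.family, 'test');
        return {};
    });
    disposables.push(participant);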
View File

@@ -369,10 +369,8 @@ export class ExtHostChatAgents2 extends Disposable implements ExtHostChatAgentsS
return undefined;
}
const extRequest = typeConvert.ChatAgentRequest.to(request, location);
if (request.userSelectedModelId && isProposedApiEnabled(detector.extension, 'chatParticipantAdditions')) {
extRequest.userSelectedModel = await this._languageModels.getLanguageModelByIdentifier(detector.extension, request.userSelectedModelId);
}
const model = await this.getModelForRequest(request, detector.extension);
const extRequest = typeConvert.ChatAgentRequest.to(request, location, model);
return detector.provider.provideParticipantDetection(
extRequest,
@@ -405,6 +403,21 @@ export class ExtHostChatAgents2 extends Disposable implements ExtHostChatAgentsS
return { request, location, history: convertedHistory };
}
private async getModelForRequest(request: IChatAgentRequest, extension: IExtensionDescription): Promise<vscode.LanguageModelChat> {
let model: vscode.LanguageModelChat | undefined;
if (request.userSelectedModelId) {
model = await this._languageModels.getLanguageModelByIdentifier(extension, request.userSelectedModelId);
}
if (!model) {
model = await this._languageModels.getDefaultLanguageModel(extension);
if (!model) {
throw new Error('Language model unavailable');
}
}
return model;
}
async $invokeAgent(handle: number, requestDto: Dto<IChatAgentRequest>, context: { history: IChatAgentHistoryEntryDto[] }, token: CancellationToken): Promise<IChatAgentResult | undefined> {
const agent = this._agents.get(handle);
if (!agent) {
@@ -425,10 +438,8 @@ export class ExtHostChatAgents2 extends Disposable implements ExtHostChatAgentsS
stream = new ChatAgentResponseStream(agent.extension, request, this._proxy, this._commands.converter, sessionDisposables);
const extRequest = typeConvert.ChatAgentRequest.to(request, location);
if (request.userSelectedModelId && isProposedApiEnabled(agent.extension, 'chatParticipantAdditions')) {
extRequest.userSelectedModel = await this._languageModels.getLanguageModelByIdentifier(agent.extension, request.userSelectedModelId);
}
const model = await this.getModelForRequest(request, agent.extension);
const extRequest = typeConvert.ChatAgentRequest.to(request, location, model);
const task = agent.invoke(
extRequest,

View File

@@ -292,6 +292,15 @@ export class ExtHostLanguageModels implements ExtHostLanguageModelsShape {
this._onDidChangeProviders.fire(undefined);
}
async getDefaultLanguageModel(extension: IExtensionDescription): Promise<vscode.LanguageModelChat | undefined> {
const defaultModelId = Iterable.find(this._allLanguageModelData.entries(), ([, value]) => !!value.metadata.isDefault)?.[0];
if (!defaultModelId) {
return;
}
return this.getLanguageModelByIdentifier(extension, defaultModelId);
}
async getLanguageModelByIdentifier(extension: IExtensionDescription, identifier: string): Promise<vscode.LanguageModelChat | undefined> {
const data = this._allLanguageModelData.get(identifier);

View File

@@ -2765,7 +2765,7 @@ export namespace ChatResponsePart {
}
export namespace ChatAgentRequest {
export function to(request: IChatAgentRequest, location2: vscode.ChatRequestEditorData | vscode.ChatRequestNotebookData | undefined): vscode.ChatRequest {
export function to(request: IChatAgentRequest, location2: vscode.ChatRequestEditorData | vscode.ChatRequestNotebookData | undefined, model: vscode.LanguageModelChat): vscode.ChatRequest {
const toolReferences = request.variables.variables.filter(v => v.isTool);
const variableReferences = request.variables.variables.filter(v => !v.isTool);
return {
@@ -2780,7 +2780,8 @@ export namespace ChatAgentRequest {
acceptedConfirmationData: request.acceptedConfirmationData,
rejectedConfirmationData: request.rejectedConfirmationData,
location2,
toolInvocationToken: Object.freeze({ sessionId: request.sessionId })
toolInvocationToken: Object.freeze({ sessionId: request.sessionId }),
model
};
}
}

View File

@@ -19086,6 +19086,12 @@ declare module 'vscode' {
* string-manipulation of the prompt.
*/
readonly references: readonly ChatPromptReference[];
/**
* This is the model that is currently selected in the UI. Extensions can use this or use {@link chat.selectChatModels} to
* pick another model. Don't hold onto this past the lifetime of the request.
*/
readonly model: LanguageModelChat;
}
/**

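The doc comment's suggestion maps to a pattern like this (a rough sketch; the preferred family is only an example, and lm.selectChatModels may return an empty list):

    import * as vscode from 'vscode';

    async function pickModel(request: vscode.ChatRequest): Promise<vscode.LanguageModelChat> {
        // Ask for a preferred model family first; the selector values are illustrative.
        const [preferred] = await vscode.lm.selectChatModels({ family: 'gpt-4o' });
        // Otherwise fall back to the model the user currently has selected in the chat UI.
        return preferred ?? request.model;
    }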
View File

@@ -223,8 +223,6 @@ declare module 'vscode' {
* The `data` for any confirmations that were rejected
*/
rejectedConfirmationData?: any[];
userSelectedModel?: LanguageModelChat;
}
// TODO@API fit this into the stream
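For extensions that were using the proposal, the optional field collapses into the always-present stable property; a rough before/after sketch inside a participant handler:

    import * as vscode from 'vscode';

    function resolveModel(request: vscode.ChatRequest): vscode.LanguageModelChat {
        // Before: request.userSelectedModel (chatParticipantAdditions proposal) was optional and
        // only populated when the proposed API was enabled for the extension.
        // After this commit: request.model is part of the stable API and always set for the request.
        return request.model;
    }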