diff --git a/shell/AIShell.Abstraction/NamedPipe.cs b/shell/AIShell.Abstraction/NamedPipe.cs index 624b12ed..a9da0022 100644 --- a/shell/AIShell.Abstraction/NamedPipe.cs +++ b/shell/AIShell.Abstraction/NamedPipe.cs @@ -33,6 +33,21 @@ public enum MessageType : int /// A message from AIShell to command-line shell to send code block. /// PostCode = 4, + + /// + /// A message from AIShell to command-line shell to run a command. + /// + RunCommand = 5, + + /// + /// A message from AIShell to command-line shell to ask for the result of a previous command run. + /// + AskCommandOutput = 6, + + /// + /// A message from command-line shell to AIShell to post the result of a command. + /// + PostResult = 7, } /// @@ -201,6 +216,95 @@ public PostCodeMessage(List codeBlocks) } } +/// +/// Message for . +/// +public sealed class RunCommandMessage : PipeMessage +{ + /// + /// Gets the command to run. + /// + public string Command { get; } + + /// + /// Gets whether the command should be run in blocking mode. + /// + public bool Blocking { get; } + + /// + /// Creates an instance of . + /// + public RunCommandMessage(string command, bool blocking) + : base(MessageType.RunCommand) + { + ArgumentException.ThrowIfNullOrEmpty(command); + + Command = command; + Blocking = blocking; + } +} + +/// +/// Message for . +/// +public sealed class AskCommandOutputMessage : PipeMessage +{ + /// + /// Gets the id of the command to retrieve the output for. + /// + public string CommandId { get; } + + /// + /// Creates an instance of . + /// + public AskCommandOutputMessage(string commandId) + : base(MessageType.AskCommandOutput) + { + ArgumentException.ThrowIfNullOrEmpty(commandId); + CommandId = commandId; + } +} + +/// +/// Message for . +/// +public sealed class PostResultMessage : PipeMessage +{ + /// + /// Gets the result of the command for a blocking 'run_command' too call. + /// Or, for a non-blocking call, gets the id for retrieving the result later. + /// + public string Output { get; } + + /// + /// Gets whether the command execution had any error. + /// i.e. a native command returned a non-zero exit code, or a powershell command threw any errors. + /// + public bool HadError { get; } + + /// + /// Gets a value indicating whether the operation was canceled by the user. + /// + public bool UserCancelled { get; } + + /// + /// Gets the internal exception message that is thrown when trying to run the command. + /// + public string Exception { get; } + + /// + /// Creates an instance of . + /// + public PostResultMessage(string output, bool hadError, bool userCancelled, string exception) + : base(MessageType.PostResult) + { + Output = output; + HadError = hadError; + UserCancelled = userCancelled; + Exception = exception; + } +} + /// /// The base type for common pipe operations. /// @@ -301,7 +405,7 @@ protected async Task GetMessageAsync(CancellationToken cancellation return null; } - if (type > (int)MessageType.PostCode) + if (type > (int)MessageType.PostResult) { _pipeStream.Close(); throw new IOException($"Unknown message type received: {type}. Connection was dropped."); @@ -344,9 +448,12 @@ private static PipeMessage DeserializePayload(int type, ReadOnlySpan bytes { (int)MessageType.PostQuery => JsonSerializer.Deserialize(bytes), (int)MessageType.AskConnection => JsonSerializer.Deserialize(bytes), - (int)MessageType.PostContext => JsonSerializer.Deserialize(bytes), (int)MessageType.AskContext => JsonSerializer.Deserialize(bytes), + (int)MessageType.PostContext => JsonSerializer.Deserialize(bytes), (int)MessageType.PostCode => JsonSerializer.Deserialize(bytes), + (int)MessageType.RunCommand => JsonSerializer.Deserialize(bytes), + (int)MessageType.AskCommandOutput => JsonSerializer.Deserialize(bytes), + (int)MessageType.PostResult => JsonSerializer.Deserialize(bytes), _ => throw new NotSupportedException("Unreachable code"), }; } @@ -465,6 +572,16 @@ public async Task StartProcessingAsync(int timeout, CancellationToken cancellati InvokeOnPostCode((PostCodeMessage)message); break; + case MessageType.RunCommand: + var result = InvokeOnRunCommand((RunCommandMessage)message); + SendMessage(result); + break; + + case MessageType.AskCommandOutput: + var output = InvokeOnAskCommandOutput((AskCommandOutputMessage)message); + SendMessage(output); + break; + default: // Log: unexpected messages ignored. break; @@ -537,6 +654,66 @@ private PostContextMessage InvokeOnAskContext(AskContextMessage message) return null; } + /// + /// Helper to invoke the event. + /// + private PostResultMessage InvokeOnRunCommand(RunCommandMessage message) + { + if (OnRunCommand is null) + { + // Log: event handler not set. + return new PostResultMessage( + output: "Command execution is not supported.", + hadError: true, + userCancelled: false, + exception: null); + } + + try + { + return OnRunCommand(message); + } + catch (Exception e) + { + // Log: exception when invoking 'OnRunCommand' + return new PostResultMessage( + output: "Failed to execute the command due to an internal error.", + hadError: true, + userCancelled: false, + exception: e.Message); + } + } + + /// + /// Helper to invoke the event. + /// + private PostResultMessage InvokeOnAskCommandOutput(AskCommandOutputMessage message) + { + if (OnAskCommandOutput is null) + { + // Log: event handler not set. + return new PostResultMessage( + output: "Retrieving command output is not supported.", + hadError: true, + userCancelled: false, + exception: null); + } + + try + { + return OnAskCommandOutput(message); + } + catch (Exception e) + { + // Log: exception when invoking 'OnAskCommandOutput' + return new PostResultMessage( + output: "Failed to retrieve the command output due to an internal error.", + hadError: true, + userCancelled: false, + exception: e.Message); + } + } + /// /// Event for handling the message. /// @@ -551,6 +728,16 @@ private PostContextMessage InvokeOnAskContext(AskContextMessage message) /// Event for handling the message. /// public event Func OnAskContext; + + /// + /// Event for handling the message. + /// + public event Func OnRunCommand; + + /// + /// Event for handling the message. + /// + public event Func OnAskCommandOutput; } /// @@ -771,4 +958,52 @@ public async Task AskContext(AskContextMessage message, Canc return postContext; } + + /// + /// Run a command in the connected PowerShell session. + /// + /// The message. + /// A cancellation token. + /// A message as the response. + /// Throws when the pipe is closed by the other side. + public async Task RunCommand(RunCommandMessage message, CancellationToken cancellationToken) + { + // Send the request message to the shell. + SendMessage(message); + + // Receiving response from the shell. + var response = await GetMessageAsync(cancellationToken); + if (response is not PostResultMessage postResult) + { + // Log: unexpected message. drop connection. + _client.Close(); + throw new IOException($"Expecting '{MessageType.PostResult}' response, but received '{message.Type}' message."); + } + + return postResult; + } + + /// + /// Ask for the output of a previously run command in the connected PowerShell session. + /// + /// The message. + /// A cancellation token. + /// A message as the response. + /// Throws when the pipe is closed by the other side. + public async Task AskCommandOutput(AskCommandOutputMessage message, CancellationToken cancellationToken) + { + // Send the request message to the shell. + SendMessage(message); + + // Receiving response from the shell. + var response = await GetMessageAsync(cancellationToken); + if (response is not PostResultMessage postResult) + { + // Log: unexpected message. drop connection. + _client.Close(); + throw new IOException($"Expecting '{MessageType.PostResult}' response, but received '{message.Type}' message."); + } + + return postResult; + } } diff --git a/shell/AIShell.Integration/AIShell.psd1 b/shell/AIShell.Integration/AIShell.psd1 index 162a7171..3d40b6d6 100644 --- a/shell/AIShell.Integration/AIShell.psd1 +++ b/shell/AIShell.Integration/AIShell.psd1 @@ -10,9 +10,9 @@ PowerShellVersion = '7.4.6' PowerShellHostName = 'ConsoleHost' FunctionsToExport = @() - CmdletsToExport = @('Start-AIShell','Invoke-AIShell','Resolve-Error') + CmdletsToExport = @('Start-AIShell','Invoke-AIShell', 'Invoke-AICommand', 'Resolve-Error') VariablesToExport = '*' - AliasesToExport = @('aish', 'askai', 'fixit') + AliasesToExport = @('aish', 'askai', 'fixit', 'airun') HelpInfoURI = 'https://aka.ms/aishell-help' PrivateData = @{ PSData = @{ Prerelease = 'preview5'; ProjectUri = 'https://github.com/PowerShell/AIShell' } } } diff --git a/shell/AIShell.Integration/Channel.cs b/shell/AIShell.Integration/Channel.cs index 7561167f..28e63678 100644 --- a/shell/AIShell.Integration/Channel.cs +++ b/shell/AIShell.Integration/Channel.cs @@ -25,7 +25,8 @@ public class Channel : IDisposable private readonly object _psrlSingleton; private readonly ManualResetEvent _connSetupWaitHandler; private readonly Predictor _predictor; - private readonly ScriptBlock _onIdleAction; + private readonly ScriptBlock _onIdlePostAction; + private readonly ScriptBlock _onIdleRunAction; private readonly List _commandHistory; private PathInfo _currentLocation; @@ -35,6 +36,8 @@ public class Channel : IDisposable private Exception _exception; private Thread _serverThread; private CodePostData _pendingPostCodeData; + private RunCommandRequest _runCommandRequest; + private PowerShell _pwsh; private Channel(Runspace runspace, EngineIntrinsics intrinsics, Type psConsoleReadLineType) { @@ -70,7 +73,8 @@ private Channel(Runspace runspace, EngineIntrinsics intrinsics, Type psConsoleRe _commandHistory = []; _predictor = new Predictor(); - _onIdleAction = ScriptBlock.Create("[AIShell.Integration.Channel]::Singleton.OnIdleHandler()"); + _onIdlePostAction = ScriptBlock.Create("[AIShell.Integration.Channel]::Singleton.OnIdlePostHandler()"); + _onIdleRunAction = ScriptBlock.Create("[AIShell.Integration.Channel]::Singleton.OnIdleRunHandler()"); } public static Channel CreateSingleton(Runspace runspace, EngineIntrinsics intrinsics, Type psConsoleReadLineType) @@ -106,6 +110,29 @@ internal bool CheckConnection(bool blocking, out bool setupInProgress) return false; } + /// + /// A 'run_command' tool call request will set '_runCommandRequest' properly with the 'Result' property being null. + /// For a blocking call, it will set '_runCommandRequest' back to null once the call result has been set. + /// For an unblocking call, '_runCommandRequest' will remain as is until: + /// 1. a 'get_output' request comes to collect the result, or + /// 2. another 'run_command' request comes to run a new command. + /// So we consider there is a pending request only if '_runCommandRequest' is not null and its 'Result' property is null. + /// + internal string GetRunCommandRequest() => + _runCommandRequest is { Result: null } ? _runCommandRequest.Command : null; + + /// + /// Set the command result for a 'run_command' tool call request. + /// + internal void SetRunCommandResult(bool hadErrors, bool userCancelled, List errorAndOutput) + { + if (_runCommandRequest is { Result: null }) + { + _runCommandRequest.Result = new(hadErrors, userCancelled, errorAndOutput); + _runCommandRequest.Event?.Set(); + } + } + public string StartChannelSetup() { if (_serverPipe is not null) @@ -123,12 +150,14 @@ public string StartChannelSetup() _serverPipe.OnAskConnection += OnAskConnection; _serverPipe.OnAskContext += OnAskContext; _serverPipe.OnPostCode += OnPostCode; + _serverPipe.OnRunCommand += OnRunCommand; + _serverPipe.OnAskCommandOutput += OnAskCommandOutput; _serverThread = new Thread(ThreadProc) - { - IsBackground = true, - Name = "pwsh channel thread" - }; + { + IsBackground = true, + Name = "pwsh channel thread" + }; _serverThread.Start(); return _shellPipeName; @@ -183,6 +212,10 @@ private void RunspaceAvailableAction(object sender, RunspaceAvailabilityEventArg _commandHistory.AddRange(results); } } + catch + { + // Ignore unexpected exceptions. + } finally { pwshRunspace.AvailabilityChanged += RunspaceAvailableAction; @@ -255,6 +288,8 @@ private void Reset() _serverPipe.OnAskConnection -= OnAskConnection; _serverPipe.OnAskContext -= OnAskContext; _serverPipe.OnPostCode -= OnPostCode; + _serverPipe.OnRunCommand -= OnRunCommand; + _serverPipe.OnAskCommandOutput -= OnAskCommandOutput; } _serverPipe = null; @@ -284,7 +319,7 @@ private void ThrowIfNotConnected() } [Hidden()] - public void OnIdleHandler() + public void OnIdlePostHandler() { if (_pendingPostCodeData is not null) { @@ -294,6 +329,17 @@ public void OnIdleHandler() } } + [Hidden()] + public void OnIdleRunHandler() + { + if (_pendingPostCodeData is not null) + { + PSRLInsert(_pendingPostCodeData.CodeToInsert); + PSRLAcceptLine(); + _pendingPostCodeData = null; + } + } + private void OnPostCode(PostCodeMessage postCodeMessage) { // Ignore 'code post' request when a posting operation is on-going. @@ -351,7 +397,7 @@ private void OnPostCode(PostCodeMessage postCodeMessage) eventName: null, sourceIdentifier: PSEngineEvent.OnIdle, data: null, - action: _onIdleAction, + action: _onIdlePostAction, supportEvent: true, forwardEvent: false, maxTriggerCount: 1); @@ -448,6 +494,130 @@ private void OnAskConnection(ShellClientPipe clientPipe, Exception exception) _connSetupWaitHandler.Set(); } + private PostResultMessage OnRunCommand(RunCommandMessage runCommandMessage) + { + // Ignore 'run_command' request when a code posting operation is on-going. + if (_pendingPostCodeData is not null) + { + return new PostResultMessage( + output: "Cannot run command at the moment. Try again later.", + hadError: true, + userCancelled: false, + exception: null); + } + + string command = runCommandMessage.Command.Replace("\r\n", "\n"); + _runCommandRequest = new(command, runCommandMessage.Blocking); + + string codeToInsert = command.Contains('\n') + ? $$""" + airun { + {{command}} + } + """ + : $"airun {{ {command} }}"; + + // When PSReadLine is actively running, its '_readLineReady' field should be set to 'true'. + // When the value is 'false', it means PowerShell is still busy running scripts or commands. + if (_psrlReadLineReady.GetValue(_psrlSingleton) is true) + { + PSRLRevertLine(); + } + + _pendingPostCodeData = new CodePostData(codeToInsert, null); + // We use script block handler instead of a delegate handler because the latter will run + // in a background thread, while the former will run in the pipeline thread, which is way + // more predictable. + _runspace.Events.SubscribeEvent( + source: null, + eventName: null, + sourceIdentifier: PSEngineEvent.OnIdle, + data: null, + action: _onIdleRunAction, + supportEvent: true, + forwardEvent: false, + maxTriggerCount: 1); + + if (runCommandMessage.Blocking) + { + // Wait for the call to finish. + _runCommandRequest.Event.Wait(); + RunCommandResult result = _runCommandRequest.Result; + + string output = result.ErrorAndOutput.Count is 0 + ? string.Empty + : (_pwsh ??= PowerShell.Create()) + .AddCommand("Out-String") + .AddParameter("InputObject", result.ErrorAndOutput) + .AddParameter("Width", 120) + .InvokeAndCleanup()[0]; + + PostResultMessage response = new( + output: output, + hadError: result.HadErrors, + userCancelled: result.UserCancelled, + exception: null); + + _runCommandRequest.Dispose(); + _runCommandRequest = null; + + return response; + } + + return new PostResultMessage(output: _runCommandRequest.Id, hadError: false, userCancelled: false, exception: null); + } + + private PostResultMessage OnAskCommandOutput(AskCommandOutputMessage askOutputMessage) + { + if (_runCommandRequest is null) + { + return new PostResultMessage( + output: "No command was previously run in background, or the output of a background command was already retrieved.", + hadError: true, + userCancelled: false, + exception: null); + } + + string commandId = askOutputMessage.CommandId; + if (!string.Equals(commandId, _runCommandRequest.Id, StringComparison.OrdinalIgnoreCase)) + { + return new PostResultMessage( + output: $"The specified command id '{commandId}' cannot be found.", + hadError: true, + userCancelled: false, + exception: null); + } + + if (_runCommandRequest.Result is null) + { + return new PostResultMessage( + output: "Command output is not yet available.", + hadError: true, + userCancelled: false, + exception: null); + } + + RunCommandResult result = _runCommandRequest.Result; + string output = result.ErrorAndOutput.Count is 0 + ? string.Empty + : (_pwsh ??= PowerShell.Create()) + .AddCommand("Out-String") + .AddParameter("InputObject", result.ErrorAndOutput) + .AddParameter("Width", 120) + .InvokeAndCleanup()[0]; + + PostResultMessage response = new( + output: output, + hadError: result.HadErrors, + userCancelled: result.UserCancelled, + exception: null); + + _runCommandRequest.Dispose(); + _runCommandRequest = null; + + return response; + } + private void PSRLInsert(string text) { using var _ = new NoWindowResizingCheck(); diff --git a/shell/AIShell.Integration/Commands/InvokeAiCommand.cs b/shell/AIShell.Integration/Commands/InvokeAiCommand.cs new file mode 100644 index 00000000..c5d2dc90 --- /dev/null +++ b/shell/AIShell.Integration/Commands/InvokeAiCommand.cs @@ -0,0 +1,90 @@ +namespace AIShell.Integration.Commands; + +using System.Management.Automation; + +[Alias("airun")] +[Cmdlet(VerbsLifecycle.Invoke, "AICommand")] +public sealed class InvokeAICommand : PSCmdlet, IDisposable +{ + private readonly PowerShell _pwsh; + private readonly PSDataCollection _output; + + private bool _disposed, _hadErrors, _cancelled; + private List _capturedContent; + + [Parameter(Mandatory = true, Position = 0)] + public ScriptBlock Command { get; set; } + + public InvokeAICommand() + { + _pwsh = PowerShell.Create(RunspaceMode.CurrentRunspace); + _pwsh.Streams.Error.DataAdding += DataAddingHandler; + + _output = []; + _output.DataAdding += DataAddingHandler; + _capturedContent = null; + } + + /// + /// The handler for both 'OutputDataAdding' and 'ErrorDataAdding' events. + /// + /// + /// The handler is called on the pipeline thread, so it's safe to call 'WriteObject' in it. + /// + private void DataAddingHandler(object sender, DataAddingEventArgs e) + { + object item = e.ItemAdded; + _capturedContent?.Add(item); + WriteObject(item); + } + + protected override void EndProcessing() + { + string commandToRun = Command.ToString(); + string requestedCommand = Channel.Singleton.GetRunCommandRequest(); + + if (requestedCommand is not null && commandToRun.Contains(requestedCommand)) + { + // Only capture output when this is a tool call invoked by AI. + _capturedContent = []; + } + + try + { + _pwsh.AddScript(commandToRun, useLocalScope: false); + _pwsh.Invoke(input: null, _output, settings: null); + } + finally + { + _hadErrors = _pwsh.HadErrors; + } + } + + protected override void StopProcessing() + { + _pwsh.Stop(); + _cancelled = true; + } + + /// + /// Dispose the resources. + /// + public void Dispose() + { + if (_disposed) + { + return; + } + + if (_capturedContent is { }) + { + Channel.Singleton.SetRunCommandResult(_hadErrors, _cancelled, _capturedContent); + } + + _output.DataAdding -= DataAddingHandler; + _output.Dispose(); + _pwsh.Streams.Error.DataAdding -= DataAddingHandler; + _pwsh.Dispose(); + _disposed = true; + } +} diff --git a/shell/AIShell.Integration/RunInTerminal.cs b/shell/AIShell.Integration/RunInTerminal.cs new file mode 100644 index 00000000..da9d28d5 --- /dev/null +++ b/shell/AIShell.Integration/RunInTerminal.cs @@ -0,0 +1,40 @@ +namespace AIShell.Integration; + +internal class RunCommandRequest : IDisposable +{ + internal string Id { get; } + internal string Command { get; } + internal ManualResetEventSlim Event { get; } + internal RunCommandResult Result { get; set; } + + internal RunCommandRequest(string command, bool blockingCall) + { + ArgumentException.ThrowIfNullOrEmpty(command); + + Id = Guid.NewGuid().ToString(); + Command = command; + Event = blockingCall ? new() : null; + Result = null; + } + + public void Dispose() + { + Event?.Dispose(); + } +} + +internal class RunCommandResult +{ + internal bool HadErrors { get; } + internal bool UserCancelled { get; } + internal List ErrorAndOutput { get; } + + internal RunCommandResult(bool hadErrors, bool userCancelled, List errorAndOutput) + { + ArgumentNullException.ThrowIfNull(errorAndOutput); + + HadErrors = hadErrors; + UserCancelled = userCancelled; + ErrorAndOutput = errorAndOutput; + } +} diff --git a/shell/AIShell.Kernel/Host.cs b/shell/AIShell.Kernel/Host.cs index 7199f4cd..d1101254 100644 --- a/shell/AIShell.Kernel/Host.cs +++ b/shell/AIShell.Kernel/Host.cs @@ -570,9 +570,9 @@ internal void RenderReferenceText(string header, string content) /// /// The MCP tool. /// The arguments in JSON form to be sent for the tool call. - internal void RenderToolCallRequest(McpTool tool, string jsonArgs) + internal void RenderMcpToolCallRequest(McpTool tool, string jsonArgs) { - RequireStdoutOrStderr(operation: "render tool call request"); + RequireStdoutOrStderr(operation: "render MCP tool call request"); IAnsiConsole ansiConsole = _outputRedirected ? _stderrConsole : AnsiConsole.Console; bool hasArgs = !string.IsNullOrEmpty(jsonArgs); @@ -610,6 +610,44 @@ internal void RenderToolCallRequest(McpTool tool, string jsonArgs) FancyStreamRender.ConsoleUpdated(); } + /// + /// Render the built-in tool call request. + /// + internal void RenderBuiltInToolCallRequest(string toolName, string description, Tuple argument) + { + RequireStdoutOrStderr(operation: "render built-in tool call request"); + IAnsiConsole ansiConsole = _outputRedirected ? _stderrConsole : AnsiConsole.Console; + + bool hasArgs = argument is not null; + string argLine = hasArgs ? $"{argument.Item1}:" : $"Input: "; + IRenderable content = new Markup($""" + + [bold]Run [olive]{toolName}[/] from [olive]{McpManager.BuiltInServerName}[/] (Built-in tool)[/] + + {description} + + {argLine} + """); + + if (hasArgs) + { + content = new Grid() + .AddColumn(new GridColumn()) + .AddRow(content) + .AddRow(argument.Item2.EscapeMarkup()); + } + + var panel = new Panel(content) + .Expand() + .RoundedBorder() + .Header("[green] Tool Call Request [/]") + .BorderColor(Color.Grey); + + ansiConsole.WriteLine(); + ansiConsole.Write(panel); + FancyStreamRender.ConsoleUpdated(); + } + /// /// Render a table with information about available MCP servers and tools. /// @@ -672,7 +710,13 @@ internal void RenderMcpServersAndTools(McpManager mcpManager) toolTable.AddRow($"[olive underline]{McpManager.BuiltInServerName}[/]", "[green]\u2713 Ready[/]", string.Empty); foreach (var item in mcpManager.BuiltInTools) { - toolTable.AddRow(string.Empty, item.Key.EscapeMarkup(), item.Value.Description.EscapeMarkup()); + string description = item.Value.Description; + int index = description.IndexOf('\n'); + if (index > 0) + { + description = description[..index].Trim(); + } + toolTable.AddRow(string.Empty, item.Key.EscapeMarkup(), description.EscapeMarkup()); } } diff --git a/shell/AIShell.Kernel/MCP/BuiltInTool.cs b/shell/AIShell.Kernel/MCP/BuiltInTool.cs index 45dd3da2..49580da7 100644 --- a/shell/AIShell.Kernel/MCP/BuiltInTool.cs +++ b/shell/AIShell.Kernel/MCP/BuiltInTool.cs @@ -1,6 +1,7 @@ using AIShell.Abstraction; using Microsoft.Extensions.AI; using System.Diagnostics; +using System.Text; using System.Text.Json; namespace AIShell.Kernel.Mcp; @@ -16,7 +17,7 @@ private enum ToolType : int copy_text_to_clipboard = 4, post_code_to_terminal = 5, run_command_in_terminal = 6, - get_terminal_output = 7, + get_command_output = 7, NumberOfBuiltInTools = 8 }; @@ -58,15 +59,15 @@ private enum ToolType : int Background Processes: - For long-running tasks (e.g., servers), set `isBackground=true`. - - Returns a terminal ID for checking status and runtime later. + - Returns a command ID for checking status and output later. Important Notes: - If the command may produce excessively large output, use head or tail to reduce the output. - If a command may use a pager, you must add something to disable it. For example, you can use `git --no-pager`. Otherwise you should add something like ` | cat`. Examples: git, less, man, etc. """, - // get_terminal_output - "Get the output of a command previous started with `run_command_in_terminal`" + // get_command_output + "Get the output of a command previously started with `run_command_in_terminal`." ]; private static readonly string[] s_toolSchema = @@ -171,7 +172,7 @@ private enum ToolType : int }, "isBackground": { "type": "boolean", - "description": "Whether the command starts a background process. If true, the command will run in the background and you will not see the output. If false, the tool call will block on the command finishing, and then you will get the output. Examples of backgrond processes: building in watch mode, starting a server. You can check the output of a backgrond process later on by using get_terminal_output." + "description": "Whether the command starts a background process. If true, the command will run in the background and you will not see the output. If false, the tool call will block on the command finishing, and then you will get the output. Examples of backgrond processes: building in watch mode, starting a server. You can check the output of a backgrond process later on by using `get_command_output`." } }, "required": [ @@ -184,7 +185,7 @@ private enum ToolType : int } """, - // get_terminal_output + // get_command_output """ { "type": "object", @@ -248,6 +249,22 @@ protected override async ValueTask InvokeCoreAsync(AIFunctionArguments a return postContextMsg.ContextInfo; } + if (response is PostResultMessage postResultMsg) + { + StringBuilder strb = new(postResultMsg.Output.Length + 40); + strb.AppendLine("### Status") + .AppendLine(postResultMsg.UserCancelled + ? "Execution was cancelled by the user." + : postResultMsg.HadError ? "Had error." : "Succeeded.") + .AppendLine() + .AppendLine("### Output") + .AppendLine("```") + .AppendLine(postResultMsg.Output.Trim()) + .AppendLine("```"); + + return strb.ToString(); + } + return response is null ? "Success: Function completed." : JsonSerializer.SerializeToElement(response); } @@ -267,6 +284,7 @@ internal async Task CallAsync( IReadOnlyDictionary arguments = null, CancellationToken cancellationToken = default) { + PipeMessage response = null; AskContextMessage contextRequest = _toolType switch { ToolType.get_working_directory => new(ContextType.CurrentLocation), @@ -280,12 +298,8 @@ internal async Task CallAsync( _ => null }; - bool succeeded = false; - PostContextMessage response = null; - if (contextRequest is not null) { - succeeded = true; response = await _shell.Host.RunWithSpinnerAsync( async () => await _shell.Channel.AskContext(contextRequest, cancellationToken), status: $"Running '{_toolName}'", @@ -294,39 +308,74 @@ internal async Task CallAsync( else if (_toolType is ToolType.copy_text_to_clipboard) { TryGetArgumentValue(arguments, "content", out string content); - if (string.IsNullOrEmpty(content)) { throw new ArgumentException("The 'content' argument is required for the 'copy_text_to_clipboard' tool."); } - succeeded = true; Clipboard.SetText(content); } else if (_toolType is ToolType.post_code_to_terminal) { TryGetArgumentValue(arguments, "command", out string command); - if (string.IsNullOrEmpty(command)) { throw new ArgumentException("The 'command' argument is required for the 'post_code_to_terminal' tool."); } - succeeded = true; _shell.Channel.PostCode(new PostCodeMessage([command])); } + else if (_toolType is ToolType.run_command_in_terminal) + { + TryGetArgumentValue(arguments, "command", out string command); + TryGetArgumentValue(arguments, "explanation", out string explanation); + TryGetArgumentValue(arguments, "isBackground", out bool isBackground); - if (succeeded) + if (string.IsNullOrEmpty(command)) + { + throw new ArgumentException("The 'command' argument is required for the 'run_command_in_terminal' tool."); + } + if (string.IsNullOrEmpty(explanation)) + { + throw new ArgumentException("The 'explanation' argument is required for the 'run_command_in_terminal' tool."); + } + + _shell.Host.RenderBuiltInToolCallRequest(OriginalName, explanation, Tuple.Create("command", command)); + // Prompt for user's approval to call the tool. + const string title = "\n\u26A0 Malicious conversation content may attempt to misuse 'AIShell' through the built-in tools. Please carefully review any requested actions to decide if you want to proceed."; + string choice = await _shell.Host.PromptForSelectionAsync( + title: title, + choices: McpTool.UserChoices, + cancellationToken: cancellationToken); + + if (choice is "Cancel") + { + _shell.Host.MarkupLine($"\n [red]\u2717[/] Cancelled '{OriginalName}'"); + throw new OperationCanceledException("The call was rejected by user."); + } + + response = await _shell.Host.RunWithSpinnerAsync( + async () => await _shell.Channel.RunCommand(new RunCommandMessage(command, blocking: !isBackground), cancellationToken), + status: $"Running '{_toolName}'", + spinnerKind: SpinnerKind.Processing); + } + else if (_toolType is ToolType.get_command_output) { - // Notify the user about this tool call. - _shell.Host.MarkupLine($"\n [green]\u2713[/] Ran '{_toolName}'"); - // Signal any active stream reander about the output - FancyStreamRender.ConsoleUpdated(); + TryGetArgumentValue(arguments, "id", out string id); + if (string.IsNullOrEmpty(id)) + { + throw new ArgumentException("The 'id' argument is required for the 'get_command_output' tool."); + } - return response; + response = await _shell.Channel.AskCommandOutput(new AskCommandOutputMessage(id), cancellationToken); } - throw new NotSupportedException($"Tool type '{_toolType}' is not yet supported."); + // Notify the user about this tool call. + _shell.Host.MarkupLine($"\n [green]\u2713[/] Ran '{_toolName}'"); + + // Signal any active stream render about the output. + FancyStreamRender.ConsoleUpdated(); + return response; } private static bool TryGetArgumentValue(IReadOnlyDictionary arguments, string argName, out T value) @@ -398,8 +447,7 @@ internal static Dictionary GetBuiltInTools(Shell shell) { ArgumentNullException.ThrowIfNull(shell); - // We don't have the 'run_command' and 'get_terminal_output' tools yet. Will use 'ToolType.NumberOfBuiltInTools' when all tools are ready. - int toolCount = (int)ToolType.run_command_in_terminal; + int toolCount = (int)ToolType.NumberOfBuiltInTools; Debug.Assert(s_toolDescription.Length == (int)ToolType.NumberOfBuiltInTools, "Number of tool descriptions doesn't match the number of tools."); Debug.Assert(s_toolSchema.Length == (int)ToolType.NumberOfBuiltInTools, "Number of tool schemas doesn't match the number of tools."); diff --git a/shell/AIShell.Kernel/MCP/McpTool.cs b/shell/AIShell.Kernel/MCP/McpTool.cs index 97142485..a78a21ac 100644 --- a/shell/AIShell.Kernel/MCP/McpTool.cs +++ b/shell/AIShell.Kernel/MCP/McpTool.cs @@ -113,10 +113,10 @@ internal async ValueTask CallAsync( string jsonArgs = arguments is { Count: > 0 } ? JsonSerializer.Serialize(arguments, serializerOptions ?? JsonSerializerOptions) : null; - _host.RenderToolCallRequest(this, jsonArgs); + _host.RenderMcpToolCallRequest(this, jsonArgs); // Prompt for user's approval to call the tool. - const string title = "\n\u26A0 MCP servers or malicious converstaion content may attempt to misuse 'AIShell' through the installed tools. Please carefully review any requested actions to decide if you want to proceed."; + const string title = "\n\u26A0 MCP servers or malicious conversation content may attempt to misuse 'AIShell' through the installed tools. Please carefully review any requested actions to decide if you want to proceed."; string choice = await _host.PromptForSelectionAsync( title: title, choices: UserChoices, diff --git a/shell/AIShell.Kernel/ShellIntegration/Channel.cs b/shell/AIShell.Kernel/ShellIntegration/Channel.cs index dfdc7054..78249645 100644 --- a/shell/AIShell.Kernel/ShellIntegration/Channel.cs +++ b/shell/AIShell.Kernel/ShellIntegration/Channel.cs @@ -212,6 +212,24 @@ internal async Task AskContext(AskContextMessage message, Ca return await _clientPipe.AskContext(message, cancellationToken); } + /// + /// Run command in the connected shell. + /// + internal async Task RunCommand(RunCommandMessage message, CancellationToken cancellationToken) + { + ThrowIfNotConnected(); + return await _clientPipe.RunCommand(message, cancellationToken); + } + + /// + /// Ask for the output of a command that was previously run in the connected shell. + /// + internal async Task AskCommandOutput(AskCommandOutputMessage message, CancellationToken cancellationToken) + { + ThrowIfNotConnected(); + return await _clientPipe.AskCommandOutput(message, cancellationToken); + } + public void Dispose() { if (_disposed) diff --git a/shell/agents/AIShell.OpenAI.Agent/Helpers.cs b/shell/agents/AIShell.OpenAI.Agent/Helpers.cs index c4df2e5e..93e2f253 100644 --- a/shell/agents/AIShell.OpenAI.Agent/Helpers.cs +++ b/shell/agents/AIShell.OpenAI.Agent/Helpers.cs @@ -195,7 +195,7 @@ internal static class Prompt internal static string SystemPromptWithConnectedPSSession = $""" You are a virtual assistant in **AIShell**, specializing in PowerShell and other command-line tools. - You are connected to an interactive PowerShell session and can retrieve session context and interact with the session using built-in tools. When user queries are ambiguous or minimal, rely on session context to better understand intent and deliver accurate, helpful responses.. + You are connected to an interactive PowerShell session and can retrieve session context and run commands in the session using built-in tools. When user queries are ambiguous or minimal, rely on session context to better understand intent and deliver accurate, helpful responses. Your primary function is to assist users with accomplishing tasks and troubleshooting errors in the command line. Autonomously resolve the user's query to the best of your ability before returning with a response.