diff --git a/.skills/csharp-async-best-practices/SKILL.md b/.skills/csharp-async-best-practices/SKILL.md new file mode 100644 index 0000000000..fd4f6fa159 --- /dev/null +++ b/.skills/csharp-async-best-practices/SKILL.md @@ -0,0 +1,120 @@ +--- +name: csharp-async-best-practices +description: Use when reviewing, writing, refactoring, or designing c# async code that uses task, task-generic, valuetask, cancellationtoken, task.whenall, task.whenany, task.run, configureawait, async void, or fire-and-forget patterns. Trigger on `.result`, `.wait()`, deadlocks, cancellation propagation, asp.net core background work, ui responsiveness, exception flow, and performance-sensitive async api design. +metadata: + category: technique + triggers: + - c# + - async + - task + - valuetask + - cancellationtoken + - configureawait + - .result + - .wait() + - async void + - fire-and-forget + - task.run + - whenall + - whenany + - asp.net core + - deadlock +--- + +# C# Async Best Practices + +## Overview + +Apply evidence-backed async guidance with this priority order: + +1. correctness and cancellation semantics +2. context-specific API design +3. concurrency behavior and failure handling +4. performance tuning only when the hot path is real + +Treat blanket advice as suspect. Separate official behavior from expert interpretation and from your own recommendation. + +## Workflow + +1. Classify the code before judging it. + - **I/O-bound async**: network, file, database, timers, async waits + - **CPU-bound work**: expensive computation + - **Context**: library, UI app, ASP.NET Core app, background service, test code + - **Pressure**: hot path or ordinary path +2. Prefer the least surprising correct design. +3. Only optimize allocations or scheduling after the correctness story is sound. +4. Load the matching reference file before making strong claims. + - **General rules and code review defaults**: `references/core-guidance.md` + - **Context-sensitive rules**: `references/context-and-tradeoffs.md` + - **Source notes and authority breakdown**: `references/source-notes.md` + +## Review defaults + +Start from these defaults unless the case-specific evidence says otherwise: + +| Topic | Default judgment | +|---|---| +| Blocking on async | usually a defect or interop boundary smell | +| `async void` | only acceptable for event handlers | +| `ValueTask` | avoid by default; justify with measurements or a very hot path | +| `ConfigureAwait(false)` | good library default, not an app-wide default | +| `Task.Run` | use to offload CPU work when needed, not to fake async I/O | +| Fire-and-forget | assume unsafe until lifecycle, scope, and exception handling are explicit | +| `Task.WhenAll` | prefer for independent concurrent operations | +| `Task.WhenAny` | always inspect winner and define what happens to losers | +| Cancellation | accept and propagate token until the point of no cancellation | + +## Output contract + +When you review or design code, label your reasoning like this: + +- **Fact**: official runtime or API behavior +- **Expert guidance**: interpretation from strong experts when it adds design meaning +- **Synthesis**: your recommendation for this exact case + +Do not present contextual advice as a universal law. + +## Common traps + +- Calling `.Result`, `.Wait()`, or `GetAwaiter().GetResult()` inside normal async-capable code +- Recommending `ConfigureAwait(false)` everywhere because “it is .NET Core” or “it prevents deadlocks” +- Recommending `Task.Run` inside ASP.NET Core request code just to make code “more async” +- Recommending `ValueTask` for every hot-looking method without checking completion behavior, call frequency, or single-consumer assumptions +- Ignoring cancellation after plumbing a `CancellationToken` +- Using `Task.WhenAny` without awaiting the returned winner task or handling the remaining tasks +- Treating fire-and-forget as harmless when it touches scoped services, `HttpContext`, or unobserved failures + +## Rationalization traps + +| Rationalization | Better reasoning | +|---|---| +| “It works, so `.Result` is fine.” | Lack of failure under one context does not make blocking safe or scalable. | +| “`ValueTask` is always faster.” | It trades simplicity for niche allocation wins and stricter consumption rules. | +| “`ConfigureAwait(false)` everywhere is modern guidance.” | Library and app code have different constraints. Blanket rules are weak. | +| “`Task.Run` makes server code asynchronous.” | It only queues work; it does not turn blocking I/O into true async I/O. | +| “Fire-and-forget is okay because logging exists.” | Logging does not solve scope lifetime, shutdown, retries, or error propagation. | + +## Deliverable shape + +For code review or implementation help, prefer: + +1. a short context classification +2. the concrete problem +3. the corrected pattern +4. the context-dependent tradeoff, if any +5. the smallest safe code change + +## API shape and testability + +- Prefer `Async` suffixes for awaitable-returning methods unless an established contract or event pattern dictates otherwise. +- Prefer `Task`-returning seams over hidden background work so tests can await completion, faults, and cancellation. +- For timers, queues, retries, or background pipelines, recommend abstractions that let tests control time and observe completion. +- When reviewing an async API, ask whether callers can compose it, cancel it, await it, and assert its failure behavior. + +## Hard boundaries + +- Do not endorse sync-over-async as a normal design choice. +- Do not suggest `async void` except for event handlers. +- Do not suggest `ValueTask` unless the constraints are understood. +- Do not claim `ConfigureAwait(false)` is always needed or always unnecessary. +- Do not approve fire-and-forget unless ownership, exception handling, and lifetime are explicit. diff --git a/.skills/csharp-async-best-practices/references/context-and-tradeoffs.md b/.skills/csharp-async-best-practices/references/context-and-tradeoffs.md new file mode 100644 index 0000000000..f6bdaedb79 --- /dev/null +++ b/.skills/csharp-async-best-practices/references/context-and-tradeoffs.md @@ -0,0 +1,82 @@ +--- +description: >- + Context-specific async guidance for library code, ui apps, asp.net core, + background work, task.run, configureawait, and performance-sensitive design. +metadata: + tags: [configureawait, task.run, asp.net core, ui, library, performance] + source: mixed +--- + +# Context and Tradeoffs + +## Library code versus app code + +### General-purpose library code +- Prefer APIs that expose true async for I/O-bound work. +- Do not add async wrappers around purely compute-bound methods just to look modern. Expose sync compute APIs and let callers decide whether to offload. +- `ConfigureAwait(false)` is a strong default when the library does not need the caller’s context. +- Avoid ambient assumptions about a UI thread, request context, or test framework behavior. + +### App code +- Prefer the style that fits the app model. +- UI code often needs the original context after `await`. +- ASP.NET Core request code normally does not need `Task.Run` just to stay responsive, because it already runs on thread pool threads. +- Do not present “ASP.NET Core has no synchronization context” as proof that every `ConfigureAwait(false)` discussion is obsolete. + +## `Task.Run` boundaries + +### Good uses +- Offload CPU-bound work so a UI thread can stay responsive. +- Offload CPU work from a caller when that scheduling boundary is deliberate. + +### Weak uses +- Wrapping synchronous I/O to pretend it is true async I/O. +- Calling `Task.Run` and immediately awaiting it in ASP.NET Core request handling when no CPU offload goal exists. +- Using `Task.Run` to hide blocking APIs instead of fixing the underlying API choice. + +## Fire-and-forget + +### Assume unsafe until proven otherwise +A background task needs answers for all of these: +- Who owns its lifetime? +- How are exceptions observed? +- How does shutdown cancel it? +- Does it touch scoped services or request-bound objects? +- Does work need retries, backpressure, or queueing? + +### Safer alternatives +- Await the task normally. +- Queue work to an owned background component. +- In ASP.NET Core, prefer hosted services or a dedicated background queue pattern for long-lived work. +- If scoped services are required in background processing, create an explicit scope instead of capturing request scope objects. + +## `ConfigureAwait` + +### Strong recommendation +- In general-purpose libraries, use `ConfigureAwait(false)` unless the continuation must run in the captured context. + +### Weak recommendation +- “Always use it in app code.” +- “Never use it on .NET Core.” +- “Use it once at the first await and you are done.” + +### Review note +If code after the `await` needs a specific context, say so explicitly. If it does not, the recommendation depends on whether the code is app-level or general-purpose library code. + +## Performance guidance + +### Correctness first +Do not trade API clarity for speculative micro-optimizations. + +### `ValueTask` is performance-specialized +Recommend it only when most of these are true: +1. the method is called very frequently +2. it often completes synchronously or from a reusable source +3. allocation reduction matters on measurements +4. consumers can respect single-consumer semantics +5. task combinator ergonomics are not central to the API + +### Throttling and concurrency control +- `Task.WhenAll` expresses concurrency; it does not limit it. +- For bounded concurrency, use an async gate such as `SemaphoreSlim.WaitAsync`, or platform helpers such as `Parallel.ForEachAsync` when the workload fits. +- Always define what happens to remaining work after the first completion or first failure. diff --git a/.skills/csharp-async-best-practices/references/core-guidance.md b/.skills/csharp-async-best-practices/references/core-guidance.md new file mode 100644 index 0000000000..114ad50d0e --- /dev/null +++ b/.skills/csharp-async-best-practices/references/core-guidance.md @@ -0,0 +1,105 @@ +--- +description: >- + Source-backed core guidance for task, valuetask, cancellation, exception flow, + blocking, and concurrency in c# async code reviews and implementations. +metadata: + tags: [csharp, async, task, valuetask, cancellation, exceptions, concurrency] + source: mixed +--- + +# Core Guidance + +## Facts from official .NET documentation + +### 1. Return types and `async void` +- Async methods should normally return `Task` or `Task`. +- `async void` is intended for event handlers; callers cannot await it and exception handling differs. +- TAP methods that return awaitable types conventionally use the `Async` suffix. + +### 2. Blocking on async +- `Task.Result` is blocking. Prefer `await` in most cases. +- Blocking can deadlock in context-bound environments and reduces scalability even when it does not deadlock. +- `await` on a faulted task rethrows one exception directly; `.Wait()` and `.Result` wrap failures in `AggregateException`. + +### 3. `Task` versus `ValueTask` +- Default to `Task` or `Task` unless there is a demonstrated reason not to. +- `ValueTask` has stricter usage rules. A given instance should generally be awaited only once. +- Do not await the same `ValueTask` multiple times, call `AsTask()` multiple times, or mix consumption techniques on the same instance. +- For synchronously successful `Task`-returning methods, `Task.CompletedTask` is the normal zero-result completion value. + +### 4. Cancellation +- If a TAP method supports cancellation, expose a `CancellationToken`. +- Pass the token to nested operations that should participate in cancellation. +- If an async method throws `OperationCanceledException` associated with the method’s token, the returned task transitions to `Canceled`. +- After a method has completed its work successfully, do not report cancellation instead of success. + +### 5. Exception flow and task combinators +- `Task.WhenAll` does not block the calling thread. +- If any supplied task faults, the `WhenAll` task faults and aggregates the unwrapped exceptions from the component tasks. +- If none fault and at least one is canceled, the `WhenAll` task is canceled. +- `Task.WhenAny` returns a task that completes successfully with the first completed task as its result, even when that winning task itself is faulted or canceled. +- After `WhenAny`, await the returned winner task to propagate its outcome. +- The remaining tasks continue unless you cancel or otherwise handle them. + +## Expert guidance that is strong and technically grounded + +### Stephen Toub +- Use `ConfigureAwait(false)` as the general default for general-purpose library code, because library code should not depend on an app model’s context. +- App-level code is different. UI code often needs the captured context. ASP.NET Core also changes the deadlock discussion because it does not install the classic ASP.NET style synchronization context, but that does not make blanket `ConfigureAwait` advice strong. +- `ValueTask` exists mainly to avoid allocations on frequently synchronous success paths. It is not a general replacement for `Task` because `Task` is more flexible for multiple awaits, caching, and combinators. + +### Andrew Arnott +- Propagate the token until the point of no cancellation. +- Validate arguments before cancellation checks when argument validation should always run. +- Prefer catching `OperationCanceledException` rather than `TaskCanceledException` in general-purpose logic. +- Keep `CancellationToken` last in the parameter list; make it optional mainly on public APIs, not necessarily on internal methods. + +### Stephen Cleary +- “Async all the way” is a strong design guideline, not an absolute law of physics. Sync bridges exist, but they are specialized boundary decisions, not a normal code review recommendation. +- `async void` and sync-over-async both create real observability and composition problems even when a sample appears to work. + +## Naming and testability + +### Naming +- TAP methods that return awaitable types conventionally use the `Async` suffix. Do not force renames when an interface, base class, or event pattern already dictates the name. + +### Testability +- Favor awaitable APIs over hidden work so tests can await completion, assert faults, and drive cancellation deterministically. +- Prefer explicit background components, injected clocks, and owned queues over ad hoc fire-and-forget logic that tests cannot observe. + +## Synthesis for agents + +### Code review defaults +- Treat `.Result`, `.Wait()`, and `GetAwaiter().GetResult()` as likely defects unless the code is a deliberate sync boundary and the caller explicitly cannot be async. +- Prefer `Task`/`Task` for API design. Require an explicit reason before recommending `ValueTask`. +- Require cancellation behavior to be coherent: accepted, propagated, and not silently dropped. +- Prefer `await Task.WhenAll(...)` for independent operations started before awaiting. +- Treat `Task.WhenAny(...)` as incomplete until the winner is awaited and losers are canceled, observed, or intentionally left running. + +### Minimal examples + +#### Avoid sync-over-async +```csharp +// bad +var user = client.GetUserAsync(id).Result; + +// better +var user = await client.GetUserAsync(id); +``` + +#### Use `Task.WhenAll` for parallel I/O +```csharp +var userTask = repo.GetUserAsync(id, ct); +var ordersTask = repo.GetOrdersAsync(id, ct); +await Task.WhenAll(userTask, ordersTask); +return new Dashboard(await userTask, await ordersTask); +``` + +#### Be conservative with `ValueTask` +```csharp +// default +Task GetAsync(string key, CancellationToken ct); + +// specialized hot path only when justified +ValueTask TryGetCachedAsync(string key); +``` diff --git a/.skills/csharp-async-best-practices/references/source-notes.md b/.skills/csharp-async-best-practices/references/source-notes.md new file mode 100644 index 0000000000..0d34589d1d --- /dev/null +++ b/.skills/csharp-async-best-practices/references/source-notes.md @@ -0,0 +1,59 @@ +--- +description: >- + Authority notes and citations for the c# async best practices skill, separating + official documentation, expert interpretation, and synthesized guidance. +metadata: + tags: [sources, citations, authority, notes] + source: external +--- + +# Source Notes + +## Official facts + +- Microsoft Learn, "Implementing the Task-based Asynchronous Pattern" + - https://learn.microsoft.com/en-us/dotnet/standard/asynchronous-programming-patterns/implementing-the-task-based-asynchronous-pattern + - Return types, cancellation behavior, `Task.Run` boundaries, and TAP implementation guidance. +- Microsoft Learn, "Consuming the Task-based Asynchronous Pattern" + - https://learn.microsoft.com/en-us/dotnet/standard/asynchronous-programming-patterns/consuming-the-task-based-asynchronous-pattern + - `await`, `WhenAll`, `WhenAny`, cancellation propagation, and exception behavior. +- Microsoft Learn, "Async return types" + - https://learn.microsoft.com/en-us/dotnet/csharp/asynchronous-programming/async-return-types + - `Task`, `Task`, `async void`, generalized async return types. +- Microsoft Learn, `ValueTask` API reference + - https://learn.microsoft.com/en-us/dotnet/api/system.threading.tasks.valuetask + - single-consumer warnings and default-to-`Task` guidance. +- Microsoft Learn, ASP.NET Core best practices + - https://learn.microsoft.com/en-us/aspnet/core/fundamentals/best-practices + - avoid blocking calls, avoid unnecessary `Task.Run`, background-work cautions. +- Microsoft Learn, hosted services in ASP.NET Core + - https://learn.microsoft.com/en-us/aspnet/core/fundamentals/host/hosted-services + - safe long-lived background work and cancellation during shutdown. + +## Expert guidance used only when technically grounded + +- Stephen Toub, ".NET Blog: ConfigureAwait FAQ" + - https://devblogs.microsoft.com/dotnet/configureawait-faq/ + - best source for context capture semantics and library-vs-app guidance. +- Stephen Toub, ".NET Blog: Understanding the Whys, Whats, and Whens of ValueTask" + - https://devblogs.microsoft.com/dotnet/understanding-the-whys-whats-and-whens-of-valuetask/ + - performance rationale and tradeoffs behind `ValueTask`. +- Stephen Toub, ".NET Blog: Await, and UI, and deadlocks! Oh my!" + - https://devblogs.microsoft.com/dotnet/await-and-ui-and-deadlocks-oh-my/ + - canonical deadlock explanation for context-bound code. +- Stephen Toub, ".NET Blog: Task Exception Handling in .NET 4.5" + - https://devblogs.microsoft.com/dotnet/task-exception-handling-in-net-4-5/ + - explains `await` versus blocking exception shape and why `WhenAll` matters. +- Andrew Arnott, "Recommended patterns for CancellationToken" + - https://devblogs.microsoft.com/premier-developer/recommended-patterns-for-cancellationtoken/ + - practical cancellation design heuristics; useful, but not treated as a language/runtime spec. +- Stephen Cleary, "Async/Await - Best Practices in Asynchronous Programming" + - https://learn.microsoft.com/en-us/archive/msdn-magazine/2013/march/async-await-best-practices-in-asynchronous-programming + - useful design interpretation, but older and treated as contextual guidance rather than current official policy. + +## Where the skill is intentionally cautious + +- `ConfigureAwait`: strong guidance exists for libraries, weaker guidance for app code. Blanket rules are rejected. +- `Task.Run`: valid for deliberate CPU offload, weak as a server-side patch for blocking I/O. +- `ValueTask`: supported and useful, but easy to misuse. The skill defaults to `Task` unless evidence is present. +- Fire-and-forget: acceptable only with explicit ownership and lifecycle design, especially in server code. diff --git a/DebugTools/MccMcpStdioHarness/Program.cs b/DebugTools/MccMcpStdioHarness/Program.cs index 1b90ddbb40..a6eca2d118 100644 --- a/DebugTools/MccMcpStdioHarness/Program.cs +++ b/DebugTools/MccMcpStdioHarness/Program.cs @@ -409,6 +409,9 @@ public MccMcpResult DigBlock(double x, double y, double z, double durationSecond playerLocation = new { x = C(0.5), y = C(80.0), z = C(0.5) } }); + public Task DigBlockAsync(double x, double y, double z, double durationSeconds) => + Task.FromResult(DigBlock(x, y, z, durationSeconds)); + public MccMcpResult PlaceBlock(int x, int y, int z, string face, string hand, bool lookAtBlock) => MccMcpResult.Ok(new { success = true, x, y, z, face, hand, lookAtBlock, action = "place_block" }); @@ -567,6 +570,9 @@ public MccMcpResult MoveTo(double x, double y, double z, bool allowUnsafe, bool timeoutMs }); + public Task MoveToAsync(double x, double y, double z, bool allowUnsafe, bool allowDirectTeleport, int maxOffset, int minOffset, int timeoutMs) => + Task.FromResult(MoveTo(x, y, z, allowUnsafe, allowDirectTeleport, maxOffset, minOffset, timeoutMs)); + public MccMcpResult MoveToPlayer(string playerName, bool allowUnsafe, bool allowDirectTeleport, int maxOffset, int minOffset, int timeoutMs) => MccMcpResult.Ok(new { @@ -593,6 +599,9 @@ public MccMcpResult MoveToPlayer(string playerName, bool allowUnsafe, bool allow timeoutMs }); + public Task MoveToPlayerAsync(string playerName, bool allowUnsafe, bool allowDirectTeleport, int maxOffset, int minOffset, int timeoutMs) => + Task.FromResult(MoveToPlayer(playerName, allowUnsafe, allowDirectTeleport, maxOffset, minOffset, timeoutMs)); + public MccMcpResult LookAt(double x, double y, double z) => MccMcpResult.Ok(new { looked = true, x = C(x), y = C(y), z = C(z) }); @@ -755,6 +764,9 @@ public MccMcpResult OpenContainerAt(int x, int y, int z, int timeoutMs, bool clo }); } + public Task OpenContainerAtAsync(int x, int y, int z, int timeoutMs, bool closeCurrent) => + Task.FromResult(OpenContainerAt(x, y, z, timeoutMs, closeCurrent)); + public MccMcpResult CloseContainer(int inventoryId, int timeoutMs) { int resolvedInventoryId = inventoryId <= 0 ? 1 : inventoryId; @@ -768,6 +780,9 @@ public MccMcpResult CloseContainer(int inventoryId, int timeoutMs) }); } + public Task CloseContainerAsync(int inventoryId, int timeoutMs) => + Task.FromResult(CloseContainer(inventoryId, timeoutMs)); + public MccMcpResult InventoryWindowAction(int inventoryId, int slotId, string actionType) => MccMcpResult.Ok(new { success = true, inventoryId, slotId, actionType }); @@ -785,6 +800,9 @@ public MccMcpResult DropInventoryItem(string itemType, int count, int inventoryI preferStack }); + public Task DropInventoryItemAsync(string itemType, int count, int inventoryId, bool preferStack) => + Task.FromResult(DropInventoryItem(itemType, count, inventoryId, preferStack)); + public MccMcpResult DepositContainerItem(string itemType, int count, int inventoryId, bool preferLargestStack) => MccMcpResult.Ok(new { @@ -803,6 +821,9 @@ public MccMcpResult DepositContainerItem(string itemType, int count, int invento touchedTargetSlots = new[] { 0 } }); + public Task DepositContainerItemAsync(string itemType, int count, int inventoryId, bool preferLargestStack) => + Task.FromResult(DepositContainerItem(itemType, count, inventoryId, preferLargestStack)); + public MccMcpResult WithdrawContainerItem(string itemType, int count, int inventoryId, bool preferLargestStack) => MccMcpResult.Ok(new { @@ -821,6 +842,9 @@ public MccMcpResult WithdrawContainerItem(string itemType, int count, int invent touchedTargetSlots = new[] { 36 } }); + public Task WithdrawContainerItemAsync(string itemType, int count, int inventoryId, bool preferLargestStack) => + Task.FromResult(WithdrawContainerItem(itemType, count, inventoryId, preferLargestStack)); + public MccMcpResult QueryEntities(int maxCount) => MccMcpResult.Ok(new { @@ -963,6 +987,9 @@ public MccMcpResult PickupItems(string itemType, double radius, int maxItems, bo } }); + public Task PickupItemsAsync(string itemType, double radius, int maxItems, bool allowUnsafe, int timeoutMs) => + Task.FromResult(PickupItems(itemType, radius, maxItems, allowUnsafe, timeoutMs)); + public MccMcpResult Respawn() { health = 20.0f; diff --git a/MinecraftClient/AutoTimeout.cs b/MinecraftClient/AutoTimeout.cs index 786be4db13..f08dc5e590 100644 --- a/MinecraftClient/AutoTimeout.cs +++ b/MinecraftClient/AutoTimeout.cs @@ -1,5 +1,6 @@ using System; using System.Threading; +using System.Threading.Tasks; namespace MinecraftClient { @@ -22,6 +23,11 @@ public static bool Perform(Action action, int timeout) return Perform(action, TimeSpan.FromMilliseconds(timeout)); } + public static Task PerformAsync(Action action, int timeout, CancellationToken cancellationToken = default) + { + return PerformAsync(action, TimeSpan.FromMilliseconds(timeout), cancellationToken); + } + /// /// Perform the specified action with specified timeout /// @@ -30,14 +36,26 @@ public static bool Perform(Action action, int timeout) /// True if the action finished whithout timing out public static bool Perform(Action action, TimeSpan timeout) { - Thread thread = new(new ThreadStart(action)); - thread.Start(); + return PerformAsync(action, timeout).GetAwaiter().GetResult(); + } - bool success = thread.Join(timeout); - if (!success) - thread.Interrupt(); + public static async Task PerformAsync(Action action, TimeSpan timeout, CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(action); - return success; + try + { + await Task.Run(action, cancellationToken).WaitAsync(timeout, cancellationToken); + return true; + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + return false; + } + catch (TimeoutException) + { + return false; + } } } -} \ No newline at end of file +} diff --git a/MinecraftClient/ClassicConsoleBackend.cs b/MinecraftClient/ClassicConsoleBackend.cs index ac4912c914..bc8dedd822 100644 --- a/MinecraftClient/ClassicConsoleBackend.cs +++ b/MinecraftClient/ClassicConsoleBackend.cs @@ -1,4 +1,6 @@ using System; +using System.Threading; +using System.Threading.Tasks; namespace MinecraftClient { @@ -50,6 +52,13 @@ public string RequestImmediateInput() return ConsoleInteractive.ConsoleReader.RequestImmediateInput(); } + public Task RequestImmediateInputAsync(CancellationToken cancellationToken) + { + // ConsoleInteractive only exposes a blocking immediate-read API. + // Keep the compatibility boundary here so the wider startup/runtime path can await it. + return Task.Run(ConsoleInteractive.ConsoleReader.RequestImmediateInput, cancellationToken); + } + public string? ReadPassword() { ConsoleInteractive.ConsoleReader.SetInputVisible(false); @@ -58,6 +67,19 @@ public string RequestImmediateInput() return input; } + public async Task ReadPasswordAsync(CancellationToken cancellationToken) + { + ConsoleInteractive.ConsoleReader.SetInputVisible(false); + try + { + return await RequestImmediateInputAsync(cancellationToken); + } + finally + { + ConsoleInteractive.ConsoleReader.SetInputVisible(true); + } + } + public void ClearInputBuffer() { ConsoleInteractive.ConsoleReader.ClearBuffer(); diff --git a/MinecraftClient/ConsoleIO.cs b/MinecraftClient/ConsoleIO.cs index 8a77aaeefa..8245af1324 100644 --- a/MinecraftClient/ConsoleIO.cs +++ b/MinecraftClient/ConsoleIO.cs @@ -76,6 +76,13 @@ public static void SetAutoCompleteEngine(IAutoComplete engine) return Backend.ReadPassword(); } + public static Task ReadPasswordAsync(CancellationToken cancellationToken = default) + { + if (BasicIO) + return Task.FromResult(Console.ReadLine()); + return Backend.ReadPasswordAsync(cancellationToken); + } + /// /// Read a line from the standard input /// @@ -86,6 +93,13 @@ public static string ReadLine() return Backend.RequestImmediateInput(); } + public static Task ReadLineAsync(CancellationToken cancellationToken = default) + { + if (BasicIO) + return Task.FromResult(Console.ReadLine() ?? string.Empty); + return Backend.RequestImmediateInputAsync(cancellationToken); + } + /// /// Debug routine: print all keys pressed in the console /// @@ -233,84 +247,120 @@ private static void MccAutocompleteHandler(ConsoleInputBuffer buffer) DoClearSuggestions(); return; } + _cancellationTokenSource?.Cancel(); - using var cts = new CancellationTokenSource(); + var cts = new CancellationTokenSource(); _cancellationTokenSource = cts; - var previousTask = _latestTask; - var newTask = new Task(async () => + Task newTask = UpdateSuggestionsAsync(fullCommand, offset, buffer.CursorPosition, cts.Token); + _latestTask = newTask; + _ = ObserveAutocompleteTaskAsync(newTask, cts); + } + else + { + DoClearSuggestions(); + return; + } + } + + private static async Task UpdateSuggestionsAsync(string fullCommand, int offset, int cursorPosition, CancellationToken cancellationToken) + { + string command = fullCommand[offset..]; + if (command.Length == 0) + { + List suggestionList = new() { - string command = fullCommand[offset..]; - if (command.Length == 0) - { - List sugList = new(); + new("/") + }; - sugList.Add(new("/")); + var childs = McClient.dispatcher.GetRoot().Children; + if (childs is not null) + { + foreach (var child in childs) + suggestionList.Add(new(child.Name)); + } - var childs = McClient.dispatcher.GetRoot().Children; - if (childs is not null) - foreach (var child in childs) - sugList.Add(new(child.Name)); + foreach (var cmd in Commands) + suggestionList.Add(new(cmd)); - foreach (var cmd in Commands) - sugList.Add(new(cmd)); + if (cancellationToken.IsCancellationRequested) + return; - SendSuggestions(sugList.ToArray(), new(offset, offset)); - } - else if (command.Length > 0 && command[0] == '/' && !command.Contains(' ')) - { - var sorted = Process.ExtractSorted(command[1..], Commands); - var sugList = new ConsoleInteractive.ConsoleSuggestion.Suggestion[sorted.Count()]; + SendSuggestions(suggestionList.ToArray(), new(offset, offset)); + return; + } - int index = 0; - foreach (var sug in sorted) - sugList[index++] = new(sug.Value); - SendSuggestions(sugList, new(offset, offset + command.Length)); - } - else - { - CommandDispatcher? dispatcher = McClient.dispatcher; - if (dispatcher is null) - return; + if (command[0] == '/' && !command.Contains(' ')) + { + var sorted = Process.ExtractSorted(command[1..], Commands); + var suggestionList = new ConsoleInteractive.ConsoleSuggestion.Suggestion[sorted.Count()]; - ParseResults parse = dispatcher.Parse(command, CmdResult.Empty); + int index = 0; + foreach (var suggestion in sorted) + suggestionList[index++] = new(suggestion.Value); - Brigadier.NET.Suggestion.Suggestions suggestions = await dispatcher.GetCompletionSuggestions(parse, buffer.CursorPosition - offset); + if (cancellationToken.IsCancellationRequested) + return; - int sugLen = suggestions.List.Count; - if (sugLen == 0) - { - DoClearSuggestions(); - return; - } + SendSuggestions(suggestionList, new(offset, offset + command.Length)); + return; + } - Dictionary dictionary = new(); - foreach (var sug in suggestions.List) - dictionary.Add(sug.Text, sug.Tooltip?.String); + CommandDispatcher? dispatcher = McClient.dispatcher; + if (dispatcher is null) + return; - var sugList = new ConsoleInteractive.ConsoleSuggestion.Suggestion[sugLen]; - if (cts.IsCancellationRequested) - return; + ParseResults parse = dispatcher.Parse(command, CmdResult.Empty); + Brigadier.NET.Suggestion.Suggestions suggestions = + await dispatcher.GetCompletionSuggestions(parse, cursorPosition - offset); - Tuple range = new(suggestions.Range.Start + offset, suggestions.Range.End + offset); - var sorted = Process.ExtractSorted(fullCommand[range.Item1..range.Item2], dictionary.Keys); - if (cts.IsCancellationRequested) - return; + if (cancellationToken.IsCancellationRequested) + return; - int index = 0; - foreach (var sug in sorted) - sugList[index++] = new(sug.Value, dictionary[sug.Value] ?? string.Empty); + int suggestionCount = suggestions.List.Count; + if (suggestionCount == 0) + { + DoClearSuggestions(); + return; + } - SendSuggestions(sugList, range); - } - }, cts.Token); - _latestTask = newTask; - try { newTask.Start(); } catch { } - if (_cancellationTokenSource == cts) _cancellationTokenSource = null; + Dictionary tooltips = new(); + foreach (var suggestion in suggestions.List) + tooltips.Add(suggestion.Text, suggestion.Tooltip?.String); + + Tuple range = new(suggestions.Range.Start + offset, suggestions.Range.End + offset); + var sortedSuggestions = Process.ExtractSorted(fullCommand[range.Item1..range.Item2], tooltips.Keys); + if (cancellationToken.IsCancellationRequested) + return; + + var suggestionListWithTooltips = new ConsoleInteractive.ConsoleSuggestion.Suggestion[suggestionCount]; + int suggestionIndex = 0; + foreach (var suggestion in sortedSuggestions) + suggestionListWithTooltips[suggestionIndex++] = new(suggestion.Value, tooltips[suggestion.Value] ?? string.Empty); + + SendSuggestions(suggestionListWithTooltips, range); + } + + private static async Task ObserveAutocompleteTaskAsync(Task task, CancellationTokenSource cancellationTokenSource) + { + try + { + await task; } - else + catch (OperationCanceledException) when (cancellationTokenSource.IsCancellationRequested) + { + } + catch (Exception e) { + if (Settings.Config.Logging.DebugMessages) + WriteLogLine(e.ToString(), acceptnewlines: true); DoClearSuggestions(); - return; + } + finally + { + if (ReferenceEquals(_cancellationTokenSource, cancellationTokenSource)) + _cancellationTokenSource = null; + + cancellationTokenSource.Dispose(); } } diff --git a/MinecraftClient/Crypto/AesCfb8Stream.cs b/MinecraftClient/Crypto/AesCfb8Stream.cs index b60eb84528..381674a53a 100644 --- a/MinecraftClient/Crypto/AesCfb8Stream.cs +++ b/MinecraftClient/Crypto/AesCfb8Stream.cs @@ -1,7 +1,10 @@ using System; +using System.Buffers; using System.IO; using System.Runtime.CompilerServices; -using System.Security.Cryptography; +using System.Threading; +using System.Threading.Tasks; +using MinecraftClient.Crypto.AesHandler; namespace MinecraftClient.Crypto { @@ -9,8 +12,7 @@ public class AesCfb8Stream : Stream { public const int blockSize = 16; - private readonly Aes? Aes = null; - private readonly FastAes? FastAes = null; + private readonly IAesHandler aesHandler; private bool inStreamEnded = false; @@ -22,18 +24,7 @@ public class AesCfb8Stream : Stream public AesCfb8Stream(Stream stream, byte[] key) { BaseStream = stream; - - if (FastAes.IsSupported()) - FastAes = new FastAes(key); - else - { - Aes = Aes.Create(); - Aes.BlockSize = 128; - Aes.KeySize = 128; - Aes.Key = key; - Aes.Mode = CipherMode.ECB; - Aes.Padding = PaddingMode.None; - } + aesHandler = AesHandlerFactory.Create(key); Array.Copy(key, ReadStreamIV, 16); Array.Copy(key, WriteStreamIV, 16); @@ -59,6 +50,11 @@ public override void Flush() BaseStream.Flush(); } + public override Task FlushAsync(CancellationToken cancellationToken) + { + return BaseStream.FlushAsync(cancellationToken); + } + public override long Length { get { throw new NotSupportedException(); } @@ -89,10 +85,7 @@ public override int ReadByte() } Span blockOutput = stackalloc byte[blockSize]; - if (FastAes is not null) - FastAes.EncryptEcb(ReadStreamIV, blockOutput); - else - Aes!.EncryptEcb(ReadStreamIV, blockOutput, PaddingMode.None); + aesHandler.EncryptEcb(ReadStreamIV, blockOutput); // Shift left Array.Copy(ReadStreamIV, 1, ReadStreamIV, 0, blockSize - 1); @@ -101,6 +94,12 @@ public override int ReadByte() return (byte)(blockOutput[0] ^ inputBuf); } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private void EncryptBlock(ReadOnlySpan blockInput, Span blockOutput) + { + aesHandler.EncryptEcb(blockInput, blockOutput); + } + [MethodImpl(MethodImplOptions.AggressiveOptimization)] public override int Read(byte[] buffer, int outOffset, int required) { @@ -108,43 +107,39 @@ public override int Read(byte[] buffer, int outOffset, int required) return 0; Span blockOutput = stackalloc byte[blockSize]; + byte[] inputBuf = ArrayPool.Shared.Rent(blockSize + required); - byte[] inputBuf = new byte[blockSize + required]; - Array.Copy(ReadStreamIV, inputBuf, blockSize); - - for (int readed = 0, curRead; readed < required; readed += curRead) + try { - curRead = BaseStream.Read(inputBuf, blockSize + readed, required - readed); - if (curRead == 0) - { - inStreamEnded = true; - return readed; - } + Array.Copy(ReadStreamIV, inputBuf, blockSize); - int processEnd = readed + curRead; - if (FastAes is not null) + for (int readed = 0, curRead; readed < required; readed += curRead) { - for (int idx = readed; idx < processEnd; idx++) + curRead = BaseStream.Read(inputBuf, blockSize + readed, required - readed); + if (curRead == 0) { - ReadOnlySpan blockInput = new(inputBuf, idx, blockSize); - FastAes.EncryptEcb(blockInput, blockOutput); - buffer[outOffset + idx] = (byte)(blockOutput[0] ^ inputBuf[idx + blockSize]); + inStreamEnded = true; + Array.Copy(inputBuf, readed, ReadStreamIV, 0, blockSize); + return readed; } - } - else - { + + int processEnd = readed + curRead; for (int idx = readed; idx < processEnd; idx++) { ReadOnlySpan blockInput = new(inputBuf, idx, blockSize); - Aes!.EncryptEcb(blockInput, blockOutput, PaddingMode.None); + EncryptBlock(blockInput, blockOutput); buffer[outOffset + idx] = (byte)(blockOutput[0] ^ inputBuf[idx + blockSize]); } } - } - Array.Copy(inputBuf, required, ReadStreamIV, 0, blockSize); + Array.Copy(inputBuf, required, ReadStreamIV, 0, blockSize); - return required; + return required; + } + finally + { + ArrayPool.Shared.Return(inputBuf); + } } public override long Seek(long offset, SeekOrigin origin) @@ -161,10 +156,7 @@ public override void WriteByte(byte b) { Span blockOutput = stackalloc byte[blockSize]; - if (FastAes is not null) - FastAes.EncryptEcb(WriteStreamIV, blockOutput); - else - Aes!.EncryptEcb(WriteStreamIV, blockOutput, PaddingMode.None); + EncryptBlock(WriteStreamIV, blockOutput); byte outputBuf = (byte)(blockOutput[0] ^ b); @@ -178,22 +170,129 @@ public override void WriteByte(byte b) [MethodImpl(MethodImplOptions.AggressiveOptimization)] public override void Write(byte[] input, int offset, int required) { - byte[] outputBuf = new byte[blockSize + required]; - Array.Copy(WriteStreamIV, outputBuf, blockSize); + byte[] outputBuf = ArrayPool.Shared.Rent(blockSize + required); + + try + { + Array.Copy(WriteStreamIV, outputBuf, blockSize); + + Span blockOutput = stackalloc byte[blockSize]; + for (int written = 0; written < required; ++written) + { + ReadOnlySpan blockInput = new(outputBuf, written, blockSize); + EncryptBlock(blockInput, blockOutput); + outputBuf[blockSize + written] = (byte)(blockOutput[0] ^ input[offset + written]); + } + + BaseStream.Write(outputBuf, blockSize, required); + Array.Copy(outputBuf, required, WriteStreamIV, 0, blockSize); + } + finally + { + ArrayPool.Shared.Return(outputBuf); + } + } + + [MethodImpl(MethodImplOptions.AggressiveOptimization)] + public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + if (inStreamEnded || buffer.Length == 0) + return 0; + byte[] inputBuf = ArrayPool.Shared.Rent(blockSize + buffer.Length); + + try + { + Array.Copy(ReadStreamIV, inputBuf, blockSize); + + for (int readed = 0; readed < buffer.Length;) + { + int curRead = await BaseStream.ReadAsync(inputBuf.AsMemory(blockSize + readed, buffer.Length - readed), cancellationToken); + if (curRead == 0) + { + inStreamEnded = true; + Array.Copy(inputBuf, readed, ReadStreamIV, 0, blockSize); + return readed; + } + + int processEnd = readed + curRead; + DecryptToOutputBuffer(inputBuf, buffer, readed, processEnd); + readed = processEnd; + } + + Array.Copy(inputBuf, buffer.Length, ReadStreamIV, 0, blockSize); + return buffer.Length; + } + finally + { + ArrayPool.Shared.Return(inputBuf); + } + } + + [MethodImpl(MethodImplOptions.AggressiveOptimization)] + public override async ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + if (buffer.Length == 0) + return; + + byte[] outputBuf = ArrayPool.Shared.Rent(blockSize + buffer.Length); + + try + { + Array.Copy(WriteStreamIV, outputBuf, blockSize); + EncryptToOutputBuffer(buffer, outputBuf); + + await BaseStream.WriteAsync(outputBuf.AsMemory(blockSize, buffer.Length), cancellationToken); + Array.Copy(outputBuf, buffer.Length, WriteStreamIV, 0, blockSize); + } + finally + { + ArrayPool.Shared.Return(outputBuf); + } + } + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return ReadAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask(); + } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return WriteAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask(); + } + + protected override void Dispose(bool disposing) + { + if (disposing) + { + aesHandler.Dispose(); + } + + base.Dispose(disposing); + } + + [MethodImpl(MethodImplOptions.AggressiveOptimization)] + private void DecryptToOutputBuffer(byte[] inputBuf, Memory output, int start, int end) + { Span blockOutput = stackalloc byte[blockSize]; - for (int wirtten = 0; wirtten < required; ++wirtten) + for (int idx = start; idx < end; idx++) { - ReadOnlySpan blockInput = new(outputBuf, wirtten, blockSize); - if (FastAes is not null) - FastAes.EncryptEcb(blockInput, blockOutput); - else - Aes!.EncryptEcb(blockInput, blockOutput, PaddingMode.None); - outputBuf[blockSize + wirtten] = (byte)(blockOutput[0] ^ input[offset + wirtten]); + ReadOnlySpan blockInput = new(inputBuf, idx, blockSize); + EncryptBlock(blockInput, blockOutput); + output.Span[idx] = (byte)(blockOutput[0] ^ inputBuf[idx + blockSize]); } - BaseStream.WriteAsync(outputBuf, blockSize, required); + } - Array.Copy(outputBuf, required, WriteStreamIV, 0, blockSize); + [MethodImpl(MethodImplOptions.AggressiveOptimization)] + private void EncryptToOutputBuffer(ReadOnlyMemory input, byte[] outputBuf) + { + Span blockOutput = stackalloc byte[blockSize]; + for (int written = 0; written < input.Length; ++written) + { + ReadOnlySpan blockInput = new(outputBuf, written, blockSize); + EncryptBlock(blockInput, blockOutput); + outputBuf[blockSize + written] = (byte)(blockOutput[0] ^ input.Span[written]); + } } } } diff --git a/MinecraftClient/Crypto/AesHandler/BasicAes.cs b/MinecraftClient/Crypto/AesHandler/BasicAes.cs new file mode 100644 index 0000000000..8c08572fcf --- /dev/null +++ b/MinecraftClient/Crypto/AesHandler/BasicAes.cs @@ -0,0 +1,31 @@ +using System; +using System.Security.Cryptography; + +namespace MinecraftClient.Crypto.AesHandler; + +public sealed class BasicAes : IAesHandler +{ + private readonly Aes aes; + + public BasicAes(byte[] key) + { + ArgumentNullException.ThrowIfNull(key); + + aes = Aes.Create(); + aes.BlockSize = 128; + aes.KeySize = 128; + aes.Key = key; + aes.Mode = CipherMode.ECB; + aes.Padding = PaddingMode.None; + } + + public override void EncryptEcb(ReadOnlySpan plaintext, Span destination) + { + aes.EncryptEcb(plaintext, destination, PaddingMode.None); + } + + public override void Dispose() + { + aes.Dispose(); + } +} diff --git a/MinecraftClient/Crypto/AesHandler/FasterAesArm.cs b/MinecraftClient/Crypto/AesHandler/FasterAesArm.cs new file mode 100644 index 0000000000..a5b8621d6b --- /dev/null +++ b/MinecraftClient/Crypto/AesHandler/FasterAesArm.cs @@ -0,0 +1,162 @@ +using System; +using System.Runtime.CompilerServices; +using System.Runtime.Intrinsics; +using System.Runtime.Intrinsics.Arm; + +namespace MinecraftClient.Crypto.AesHandler; + +public sealed class FasterAesArm : IAesHandler +{ + private const int BlockSize = 16; + private const int Rounds = 10; + + private readonly byte[] enc; + + public FasterAesArm(ReadOnlySpan key) + { + enc = new byte[(Rounds + 1) * BlockSize]; + + int[] intKey = GenerateKeyExpansion(key); + for (int i = 0; i < intKey.Length; ++i) + { + enc[i * 4 + 0] = (byte)((intKey[i] >> 0) & 0xFF); + enc[i * 4 + 1] = (byte)((intKey[i] >> 8) & 0xFF); + enc[i * 4 + 2] = (byte)((intKey[i] >> 16) & 0xFF); + enc[i * 4 + 3] = (byte)((intKey[i] >> 24) & 0xFF); + } + } + + public static bool IsSupported() + { + return Aes.IsSupported && AdvSimd.IsSupported; + } + + [MethodImpl(MethodImplOptions.AggressiveOptimization)] + public override void EncryptEcb(ReadOnlySpan plaintext, Span destination) + { + int position = 0; + int left = plaintext.Length; + + Vector128 key0 = Unsafe.ReadUnaligned>(ref enc[0 * BlockSize]); + Vector128 key1 = Unsafe.ReadUnaligned>(ref enc[1 * BlockSize]); + Vector128 key2 = Unsafe.ReadUnaligned>(ref enc[2 * BlockSize]); + Vector128 key3 = Unsafe.ReadUnaligned>(ref enc[3 * BlockSize]); + Vector128 key4 = Unsafe.ReadUnaligned>(ref enc[4 * BlockSize]); + Vector128 key5 = Unsafe.ReadUnaligned>(ref enc[5 * BlockSize]); + Vector128 key6 = Unsafe.ReadUnaligned>(ref enc[6 * BlockSize]); + Vector128 key7 = Unsafe.ReadUnaligned>(ref enc[7 * BlockSize]); + Vector128 key8 = Unsafe.ReadUnaligned>(ref enc[8 * BlockSize]); + Vector128 key9 = Unsafe.ReadUnaligned>(ref enc[9 * BlockSize]); + Vector128 key10 = Unsafe.ReadUnaligned>(ref enc[10 * BlockSize]); + + while (left >= BlockSize) + { + Vector128 block = Unsafe.ReadUnaligned>(ref Unsafe.AsRef(in plaintext[position])); + + block = Aes.Encrypt(block, key0); + block = Aes.MixColumns(block); + + block = Aes.Encrypt(block, key1); + block = Aes.MixColumns(block); + + block = Aes.Encrypt(block, key2); + block = Aes.MixColumns(block); + + block = Aes.Encrypt(block, key3); + block = Aes.MixColumns(block); + + block = Aes.Encrypt(block, key4); + block = Aes.MixColumns(block); + + block = Aes.Encrypt(block, key5); + block = Aes.MixColumns(block); + + block = Aes.Encrypt(block, key6); + block = Aes.MixColumns(block); + + block = Aes.Encrypt(block, key7); + block = Aes.MixColumns(block); + + block = Aes.Encrypt(block, key8); + block = Aes.MixColumns(block); + + block = Aes.Encrypt(block, key9); + block = AdvSimd.Xor(block, key10); + + Unsafe.WriteUnaligned(ref destination[position], block); + + position += BlockSize; + left -= BlockSize; + } + } + + private static int[] GenerateKeyExpansion(ReadOnlySpan rgbKey) + { + int[] encryptKeyExpansion = new int[4 * (Rounds + 1)]; + + int index = 0; + for (int i = 0; i < 4; ++i) + { + int i0 = rgbKey[index++]; + int i1 = rgbKey[index++]; + int i2 = rgbKey[index++]; + int i3 = rgbKey[index++]; + encryptKeyExpansion[i] = i3 << 24 | i2 << 16 | i1 << 8 | i0; + } + + for (int i = 4; i < 4 * (Rounds + 1); ++i) + { + int temp = encryptKeyExpansion[i - 1]; + + if (i % 4 == 0) + { + temp = SubWord(Rot3(temp)); + temp ^= Rcon[(i / 4) - 1]; + } + + encryptKeyExpansion[i] = encryptKeyExpansion[i - 4] ^ temp; + } + + return encryptKeyExpansion; + } + + private static int SubWord(int value) + { + return Sbox[value & 0xFF] + | Sbox[(value >> 8) & 0xFF] << 8 + | Sbox[(value >> 16) & 0xFF] << 16 + | Sbox[(value >> 24) & 0xFF] << 24; + } + + private static int Rot3(int value) + { + return (value << 24 & unchecked((int)0xFF000000)) | (value >> 8 & unchecked((int)0x00FFFFFF)); + } + + private static ReadOnlySpan Sbox => + [ + 99, 124, 119, 123, 242, 107, 111, 197, 48, 1, 103, 43, 254, 215, 171, 118, + 202, 130, 201, 125, 250, 89, 71, 240, 173, 212, 162, 175, 156, 164, 114, 192, + 183, 253, 147, 38, 54, 63, 247, 204, 52, 165, 229, 241, 113, 216, 49, 21, + 4, 199, 35, 195, 24, 150, 5, 154, 7, 18, 128, 226, 235, 39, 178, 117, + 9, 131, 44, 26, 27, 110, 90, 160, 82, 59, 214, 179, 41, 227, 47, 132, + 83, 209, 0, 237, 32, 252, 177, 91, 106, 203, 190, 57, 74, 76, 88, 207, + 208, 239, 170, 251, 67, 77, 51, 133, 69, 249, 2, 127, 80, 60, 159, 168, + 81, 163, 64, 143, 146, 157, 56, 245, 188, 182, 218, 33, 16, 255, 243, 210, + 205, 12, 19, 236, 95, 151, 68, 23, 196, 167, 126, 61, 100, 93, 25, 115, + 96, 129, 79, 220, 34, 42, 144, 136, 70, 238, 184, 20, 222, 94, 11, 219, + 224, 50, 58, 10, 73, 6, 36, 92, 194, 211, 172, 98, 145, 149, 228, 121, + 231, 200, 55, 109, 141, 213, 78, 169, 108, 86, 244, 234, 101, 122, 174, 8, + 186, 120, 37, 46, 28, 166, 180, 198, 232, 221, 116, 31, 75, 189, 139, 138, + 112, 62, 181, 102, 72, 3, 246, 14, 97, 53, 87, 185, 134, 193, 29, 158, + 225, 248, 152, 17, 105, 217, 142, 148, 155, 30, 135, 233, 206, 85, 40, 223, + 140, 161, 137, 13, 191, 230, 66, 104, 65, 153, 45, 15, 176, 84, 187, 22 + ]; + + private static ReadOnlySpan Rcon => + [ + 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1B, 0x36, + 0x6C, 0xD8, 0xAB, 0x4D, 0x9A, 0x2F, 0x5E, 0xBC, 0x63, 0xC6, + 0x97, 0x35, 0x6A, 0xD4, 0xB3, 0x7D, 0xFA, 0xEF, 0xC5, 0x91 + ]; +} diff --git a/MinecraftClient/Crypto/AesHandler/FasterAesX86.cs b/MinecraftClient/Crypto/AesHandler/FasterAesX86.cs new file mode 100644 index 0000000000..715f8d9ef6 --- /dev/null +++ b/MinecraftClient/Crypto/AesHandler/FasterAesX86.cs @@ -0,0 +1,89 @@ +using System; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Runtime.Intrinsics; +using System.Runtime.Intrinsics.X86; + +namespace MinecraftClient.Crypto.AesHandler; + +public sealed class FasterAesX86 : IAesHandler +{ + private Vector128[] RoundKeys { get; } + + public FasterAesX86(ReadOnlySpan key) + { + RoundKeys = KeyExpansion(key); + } + + public static bool IsSupported() + { + return Sse2.IsSupported && Aes.IsSupported; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)] + public override void EncryptEcb(ReadOnlySpan plaintext, Span destination) + { + Vector128[] keys = RoundKeys; + + ReadOnlySpan> blocks = MemoryMarshal.Cast>(plaintext); + Span> dest = MemoryMarshal.Cast>(destination); + + _ = keys[10]; + + for (int i = 0; i < blocks.Length; i++) + { + Vector128 b = blocks[i]; + + b = Sse2.Xor(b, keys[0]); + b = Aes.Encrypt(b, keys[1]); + b = Aes.Encrypt(b, keys[2]); + b = Aes.Encrypt(b, keys[3]); + b = Aes.Encrypt(b, keys[4]); + b = Aes.Encrypt(b, keys[5]); + b = Aes.Encrypt(b, keys[6]); + b = Aes.Encrypt(b, keys[7]); + b = Aes.Encrypt(b, keys[8]); + b = Aes.Encrypt(b, keys[9]); + b = Aes.EncryptLast(b, keys[10]); + + dest[i] = b; + } + } + + private static Vector128[] KeyExpansion(ReadOnlySpan key) + { + Vector128[] keys = new Vector128[20]; + + keys[0] = Unsafe.ReadUnaligned>(ref MemoryMarshal.GetReference(key)); + + MakeRoundKey(keys, 1, 0x01); + MakeRoundKey(keys, 2, 0x02); + MakeRoundKey(keys, 3, 0x04); + MakeRoundKey(keys, 4, 0x08); + MakeRoundKey(keys, 5, 0x10); + MakeRoundKey(keys, 6, 0x20); + MakeRoundKey(keys, 7, 0x40); + MakeRoundKey(keys, 8, 0x80); + MakeRoundKey(keys, 9, 0x1B); + MakeRoundKey(keys, 10, 0x36); + + for (int i = 1; i < 10; i++) + keys[10 + i] = Aes.InverseMixColumns(keys[i]); + + return keys; + } + + private static void MakeRoundKey(Vector128[] keys, int index, byte rcon) + { + Vector128 s = keys[index - 1]; + Vector128 t = keys[index - 1]; + + t = Aes.KeygenAssist(t, rcon); + t = Sse2.Shuffle(t.AsUInt32(), 0xFF).AsByte(); + + s = Sse2.Xor(s, Sse2.ShiftLeftLogical128BitLane(s, 4)); + s = Sse2.Xor(s, Sse2.ShiftLeftLogical128BitLane(s, 8)); + + keys[index] = Sse2.Xor(s, t); + } +} diff --git a/MinecraftClient/Crypto/AesHandlerFactory.cs b/MinecraftClient/Crypto/AesHandlerFactory.cs new file mode 100644 index 0000000000..0c2e7b3044 --- /dev/null +++ b/MinecraftClient/Crypto/AesHandlerFactory.cs @@ -0,0 +1,20 @@ +using System; +using MinecraftClient.Crypto.AesHandler; + +namespace MinecraftClient.Crypto; + +internal static class AesHandlerFactory +{ + public static IAesHandler Create(ReadOnlySpan key) + { + byte[] ownedKey = key.ToArray(); + + if (FasterAesX86.IsSupported()) + return new FasterAesX86(ownedKey); + + if (FasterAesArm.IsSupported()) + return new FasterAesArm(ownedKey); + + return new BasicAes(ownedKey); + } +} diff --git a/MinecraftClient/Crypto/IAesHandler.cs b/MinecraftClient/Crypto/IAesHandler.cs new file mode 100644 index 0000000000..1f34ec9001 --- /dev/null +++ b/MinecraftClient/Crypto/IAesHandler.cs @@ -0,0 +1,12 @@ +using System; + +namespace MinecraftClient.Crypto; + +public abstract class IAesHandler : IDisposable +{ + public abstract void EncryptEcb(ReadOnlySpan plaintext, Span destination); + + public virtual void Dispose() + { + } +} diff --git a/MinecraftClient/FileMonitor.cs b/MinecraftClient/FileMonitor.cs index 16f590a4f1..d67610afa1 100644 --- a/MinecraftClient/FileMonitor.cs +++ b/MinecraftClient/FileMonitor.cs @@ -3,6 +3,7 @@ using System.IO; using System.Text; using System.Threading; +using System.Threading.Tasks; namespace MinecraftClient { @@ -12,7 +13,7 @@ namespace MinecraftClient public class FileMonitor : IDisposable { private readonly Tuple? monitor = null; - private readonly Tuple? polling = null; + private readonly Tuple? polling = null; /// /// Create a new FileMonitor and start monitoring @@ -48,9 +49,9 @@ public FileMonitor(string folder, string filename, FileSystemEventHandler handle monitor = null; var cancellationTokenSource = new CancellationTokenSource(); - polling = new Tuple(new Thread(() => PollingThread(folder, filename, handler, cancellationTokenSource.Token)), cancellationTokenSource); - polling.Item1.Name = String.Format("{0} Polling thread: {1}", GetType().Name, Path.Combine(folder, filename)); - polling.Item1.Start(); + polling = new Tuple( + Task.Run(() => PollingLoopAsync(folder, filename, handler, cancellationTokenSource.Token), cancellationTokenSource.Token), + cancellationTokenSource); } } @@ -66,25 +67,29 @@ public void Dispose() } /// - /// Fallback polling thread for use when operating system does not support FileSystemWatcher + /// Fallback polling loop for use when operating system does not support FileSystemWatcher /// /// Folder to monitor /// File name to monitor /// Callback when file changes - private void PollingThread(string folder, string filename, FileSystemEventHandler handler, CancellationToken cancellationToken) + private async Task PollingLoopAsync(string folder, string filename, FileSystemEventHandler handler, CancellationToken cancellationToken) { string filePath = Path.Combine(folder, filename); DateTime lastWrite = GetLastWrite(filePath); - while (!cancellationToken.IsCancellationRequested) + using PeriodicTimer periodicTimer = new(TimeSpan.FromSeconds(5)); + try { - Thread.Sleep(5000); - DateTime lastWriteNew = GetLastWrite(filePath); - if (lastWriteNew != lastWrite) + while (await periodicTimer.WaitForNextTickAsync(cancellationToken)) { - lastWrite = lastWriteNew; - handler(this, new FileSystemEventArgs(WatcherChangeTypes.Changed, folder, filename)); + DateTime lastWriteNew = GetLastWrite(filePath); + if (lastWriteNew != lastWrite) + { + lastWrite = lastWriteNew; + handler(this, new FileSystemEventArgs(WatcherChangeTypes.Changed, folder, filename)); + } } } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) { } } /// diff --git a/MinecraftClient/IConsoleBackend.cs b/MinecraftClient/IConsoleBackend.cs index b4d52ebf57..4759e0f865 100644 --- a/MinecraftClient/IConsoleBackend.cs +++ b/MinecraftClient/IConsoleBackend.cs @@ -1,4 +1,6 @@ using System; +using System.Threading; +using System.Threading.Tasks; namespace MinecraftClient { @@ -56,8 +58,12 @@ public interface IConsoleBackend string RequestImmediateInput(); + Task RequestImmediateInputAsync(CancellationToken cancellationToken); + string? ReadPassword(); + Task ReadPasswordAsync(CancellationToken cancellationToken); + void ClearInputBuffer(); bool DisplayUserInput { get; set; } diff --git a/MinecraftClient/MainThreadExecutionScope.cs b/MinecraftClient/MainThreadExecutionScope.cs new file mode 100644 index 0000000000..226e41ebdb --- /dev/null +++ b/MinecraftClient/MainThreadExecutionScope.cs @@ -0,0 +1,45 @@ +using System; +using System.Threading; + +namespace MinecraftClient +{ + internal static class MainThreadExecutionScope + { + private sealed class ScopeNode(object owner, ScopeNode? parent) : IDisposable + { + public object Owner { get; } = owner; + public ScopeNode? Parent { get; } = parent; + + public void Dispose() + { + if (!ReferenceEquals(s_currentScope.Value, this)) + throw new InvalidOperationException("Main-thread execution scope disposed out of order."); + + s_currentScope.Value = Parent; + } + } + + private static readonly AsyncLocal s_currentScope = new(); + + public static IDisposable Enter(object owner) + { + ScopeNode scopeNode = new(owner, s_currentScope.Value); + s_currentScope.Value = scopeNode; + return scopeNode; + } + + public static bool IsActive(object owner) + { + ScopeNode? scopeNode = s_currentScope.Value; + while (scopeNode is not null) + { + if (ReferenceEquals(scopeNode.Owner, owner)) + return true; + + scopeNode = scopeNode.Parent; + } + + return false; + } + } +} diff --git a/MinecraftClient/Mapping/Movement.cs b/MinecraftClient/Mapping/Movement.cs index 0e972e09fd..2bc3807c7a 100644 --- a/MinecraftClient/Mapping/Movement.cs +++ b/MinecraftClient/Mapping/Movement.cs @@ -1,7 +1,6 @@ using System; using System.Collections.Generic; using System.Threading; -using System.Threading.Tasks; namespace MinecraftClient.Mapping { @@ -150,17 +149,8 @@ public static Queue Move2Steps(Location start, Location goal, ref doub public static Queue? CalculatePath(World world, Location start, Location goal, bool allowUnsafe, int maxOffset, int minOffset, TimeSpan timeout) { - CancellationTokenSource cts = new(); - Task?> pathfindingTask = Task.Factory.StartNew(() => - CalculatePath(world, start, goal, allowUnsafe, maxOffset, minOffset, cts.Token)); - pathfindingTask.Wait(timeout); - if (!pathfindingTask.IsCompleted) - { - cts.Cancel(); - pathfindingTask.Wait(); - } - - return pathfindingTask.Result; + using CancellationTokenSource cts = new(timeout); + return CalculatePath(world, start, goal, allowUnsafe, maxOffset, minOffset, cts.Token); } /// @@ -713,4 +703,4 @@ public static bool CheckChunkLoading(World world, Location start, Location dest) return true; } } -} \ No newline at end of file +} diff --git a/MinecraftClient/McClient.cs b/MinecraftClient/McClient.cs index 4d23bb60ca..0ccaaf7599 100644 --- a/MinecraftClient/McClient.cs +++ b/MinecraftClient/McClient.cs @@ -4,6 +4,8 @@ using System.Net.Sockets; using System.Text; using System.Threading; +using System.Threading.Channels; +using System.Threading.Tasks; using Brigadier.NET; using Brigadier.NET.Exceptions; using MinecraftClient.ChatBots; @@ -42,10 +44,12 @@ public class McClient : IMinecraftComHandler private readonly Queue chatQueue = new(); private static DateTime nextMessageSendTime = DateTime.MinValue; - private readonly Queue threadTasks = new(); + private Queue threadTasks = new(); private readonly Lock threadTasksLock = new(); private readonly Lock recipeBookLock = new(); private readonly Lock achievementsLock = new(); + private readonly Lock consoleCommandProcessingLock = new(); + private readonly Lock networkAutoCompleteLock = new(); private readonly List bots = new(); private static readonly List botsOnHold = new(); @@ -223,7 +227,11 @@ public Dictionary GetTeams() IMinecraftCom handler = null!; SessionToken _sessionToken; CancellationTokenSource? cmdprompt = null; - Tuple? timeoutdetector = null; + private Channel? consoleCommandChannel; + private Task? consoleCommandProcessingTask; + private TaskCompletionSource? pendingNetworkAutoCompleteRequest; + private TaskCompletionSource? pendingCommandListInitialization; + Tuple? timeoutdetector = null; private int transferInProgress = 0; public ILogger Log; @@ -310,9 +318,10 @@ public McClient(SessionToken session, PlayerKeyPair? playerKeyPair, string serve handler = Protocol.ProtocolHandler.GetProtocolHandler(client, protocolversion, forgeInfo, this); Log.Info(Translations.mcc_version_supported); - timeoutdetector = new(new Thread(new ParameterizedThreadStart(TimeoutDetector)), new CancellationTokenSource()); - timeoutdetector.Item1.Name = "MCC Connection timeout detector"; - timeoutdetector.Item1.Start(timeoutdetector.Item2.Token); + CancellationTokenSource timeoutDetectorCancellationTokenSource = new(); + Task timeoutDetectorTask = TimeoutDetectorAsync(timeoutDetectorCancellationTokenSource.Token); + timeoutdetector = new(timeoutDetectorTask, timeoutDetectorCancellationTokenSource); + _ = ObserveTimeoutDetectorAsync(timeoutDetectorTask, timeoutDetectorCancellationTokenSource.Token); try { @@ -324,10 +333,7 @@ public McClient(SessionToken session, PlayerKeyPair? playerKeyPair, string serve Log.Info(string.Format(Translations.mcc_joined, Config.Main.Advanced.InternalCmdChar.ToLogString())); - cmdprompt = new CancellationTokenSource(); - ConsoleIO.Backend.BeginReadThread(); - ConsoleIO.Backend.MessageReceived += ConsoleReaderOnMessageReceived; - ConsoleIO.Backend.OnInputChange += ConsoleIO.AutocompleteHandler; + StartConsoleHandlers(); } else { @@ -363,15 +369,12 @@ public McClient(SessionToken session, PlayerKeyPair? playerKeyPair, string serve if (ReconnectionAttemptsLeft > 0) { Log.Info(string.Format(Translations.mcc_reconnect, ReconnectionAttemptsLeft)); - Thread.Sleep(5000); ReconnectionAttemptsLeft--; - Program.Restart(); + Program.Restart(5, announceDelay: false); } else if (InternalConfig.InteractiveMode) { - ConsoleIO.Backend.StopReadThread(); - ConsoleIO.Backend.MessageReceived -= ConsoleReaderOnMessageReceived; - ConsoleIO.Backend.OnInputChange -= ConsoleIO.AutocompleteHandler; + StopConsoleHandlers(); Program.HandleFailure(); } @@ -389,9 +392,7 @@ public McClient(SessionToken session, PlayerKeyPair? playerKeyPair, string serve // kick messages and Ignore_Kick_Message is false, or retry limit reached) if (InternalConfig.InteractiveMode) { - ConsoleIO.Backend.StopReadThread(); - ConsoleIO.Backend.MessageReceived -= ConsoleReaderOnMessageReceived; - ConsoleIO.Backend.OnInputChange -= ConsoleIO.AutocompleteHandler; + StopConsoleHandlers(); Program.HandleFailure(); } @@ -415,6 +416,7 @@ public void Transfer(string newHost, int newPort) try { Log.Info($"Initiating a transfer to: {newHost}:{newPort}"); + StopConsoleHandlers(); // Unload bots UnloadAllBots(); @@ -449,10 +451,7 @@ public void Transfer(string newHost, int newPort) UpdateKeepAlive(); Log.Info($"Successfully transferred connection and logged in to {newHost}:{newPort}."); - cmdprompt = new CancellationTokenSource(); - ConsoleIO.Backend.BeginReadThread(); - ConsoleIO.Backend.MessageReceived += ConsoleReaderOnMessageReceived; - ConsoleIO.Backend.OnInputChange += ConsoleIO.AutocompleteHandler; + StartConsoleHandlers(); } else { @@ -490,15 +489,12 @@ public void Transfer(string newHost, int newPort) if (ReconnectionAttemptsLeft > 0) { Log.Info($"Reconnecting... Attempts left: {ReconnectionAttemptsLeft}"); - Thread.Sleep(5000); ReconnectionAttemptsLeft--; - Program.Restart(); + Program.Restart(5, announceDelay: false); } else if (InternalConfig.InteractiveMode) { - ConsoleIO.Backend.StopReadThread(); - ConsoleIO.Backend.MessageReceived -= ConsoleReaderOnMessageReceived; - ConsoleIO.Backend.OnInputChange -= ConsoleIO.AutocompleteHandler; + StopConsoleHandlers(); Program.HandleFailure(); } @@ -703,15 +699,22 @@ public void OnUpdate() } } + Queue? pendingThreadTasks = null; lock (threadTasksLock) { - while (threadTasks.Count > 0) + if (threadTasks.Count > 0) { - Action taskToRun = threadTasks.Dequeue(); - taskToRun(); + pendingThreadTasks = threadTasks; + threadTasks = new(); } } + if (pendingThreadTasks is not null) + { + while (pendingThreadTasks.Count > 0) + pendingThreadTasks.Dequeue().ExecuteSynchronously(); + } + lock (DigLock) { if (RemainingDiggingTime > 0) @@ -734,29 +737,44 @@ public void OnUpdate() /// /// Periodically checks for server keepalives and consider that connection has been lost if the last received keepalive is too old. /// - private void TimeoutDetector(object? o) + private async Task TimeoutDetectorAsync(CancellationToken cancellationToken) { UpdateKeepAlive(); - do + using PeriodicTimer periodicTimer = new(TimeSpan.FromSeconds(15)); + try { - Thread.Sleep(TimeSpan.FromSeconds(15)); - - if (((CancellationToken)o!).IsCancellationRequested) - return; - - lock (lastKeepAliveLock) + while (await periodicTimer.WaitForNextTickAsync(cancellationToken)) { - if (lastKeepAlive.AddSeconds(Config.Main.Advanced.TcpTimeout) < DateTime.Now) + lock (lastKeepAliveLock) { - if (((CancellationToken)o!).IsCancellationRequested) - return; + if (lastKeepAlive.AddSeconds(Config.Main.Advanced.TcpTimeout) < DateTime.Now) + { + cancellationToken.ThrowIfCancellationRequested(); - OnConnectionLost(ChatBot.DisconnectReason.ConnectionLost, Translations.error_timeout); - return; + OnConnectionLost(ChatBot.DisconnectReason.ConnectionLost, Translations.error_timeout); + return; + } } } } - while (!((CancellationToken)o!).IsCancellationRequested); + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + } + } + + private async Task ObserveTimeoutDetectorAsync(Task timeoutDetectorTask, CancellationToken cancellationToken) + { + try + { + await timeoutDetectorTask; + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + } + catch (Exception e) + { + Log.Warn(e.ToString()); + } } /// @@ -770,6 +788,259 @@ private void UpdateKeepAlive() } } + private void StartConsoleHandlers() + { + if (ConsoleIO.Backend is null) + return; + + cmdprompt = new CancellationTokenSource(); + StartConsoleCommandProcessing(cmdprompt.Token); + ConsoleIO.Backend.BeginReadThread(); + ConsoleIO.Backend.MessageReceived += ConsoleReaderOnMessageReceived; + ConsoleIO.Backend.OnInputChange += ConsoleIO.AutocompleteHandler; + } + + private void StopConsoleHandlers() + { + if (ConsoleIO.Backend is not null) + { + ConsoleIO.Backend.StopReadThread(); + ConsoleIO.Backend.MessageReceived -= ConsoleReaderOnMessageReceived; + ConsoleIO.Backend.OnInputChange -= ConsoleIO.AutocompleteHandler; + } + + StopConsoleCommandProcessing(); + } + + private void StartConsoleCommandProcessing(CancellationToken cancellationToken) + { + lock (consoleCommandProcessingLock) + { + consoleCommandChannel = Channel.CreateUnbounded(new UnboundedChannelOptions() + { + SingleReader = true, + SingleWriter = false, + AllowSynchronousContinuations = false + }); + consoleCommandProcessingTask = ProcessConsoleMessagesAsync(consoleCommandChannel.Reader, cancellationToken); + _ = ObserveConsoleCommandProcessingAsync(consoleCommandProcessingTask, cancellationToken); + } + } + + private void StopConsoleCommandProcessing() + { + Channel? activeChannel; + + lock (consoleCommandProcessingLock) + { + activeChannel = consoleCommandChannel; + consoleCommandChannel = null; + } + + activeChannel?.Writer.TryComplete(); + + if (cmdprompt is not null) + { + cmdprompt.Cancel(); + cmdprompt = null; + } + + CancelPendingNetworkAutoComplete(); + CancelPendingCommandListInitialization(); + } + + private async Task ObserveConsoleCommandProcessingAsync(Task processingTask, CancellationToken cancellationToken) + { + try + { + await processingTask; + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + } + catch (Exception e) + { + Log.Warn(e.ToString()); + } + finally + { + lock (consoleCommandProcessingLock) + { + if (ReferenceEquals(consoleCommandProcessingTask, processingTask)) + consoleCommandProcessingTask = null; + } + } + } + + private async Task ProcessConsoleMessagesAsync(ChannelReader channelReader, CancellationToken cancellationToken) + { + await foreach (string message in channelReader.ReadAllAsync(cancellationToken)) + { + if (cancellationToken.IsCancellationRequested) + return; + + if (TryParseBasicIoAutocompleteRequest(message, out _)) + await HandleBasicIoAutocompleteRequestAsync(message, cancellationToken); + else + await InvokeOnMainThreadAsync(() => HandleCommandPromptText(message)); + } + } + + private async Task HandleBasicIoAutocompleteRequestAsync(string text, CancellationToken cancellationToken) + { + try + { + string[] command = text[1..].Split((char)0x00); + if (command.Length < 2 || !command[0].Equals("autocomplete", StringComparison.OrdinalIgnoreCase)) + return; + + await WaitForCommandListInitializationAsync(cancellationToken); + + Task requestTask = InvokeRequired + ? await InvokeOnMainThreadAsync(() => BeginNetworkAutoCompleteRequest(command[1])) + : BeginNetworkAutoCompleteRequest(command[1]); + + await requestTask.WaitAsync(cancellationToken); + + if (command.Length > 1) + ConsoleIO.WriteLine((char)0x00 + "autocomplete" + (char)0x00 + ConsoleIO.AutoCompleteResult); + else ConsoleIO.WriteLine((char)0x00 + "autocomplete" + (char)0x00); + } + catch (OperationCanceledException) when (cancellationToken.IsCancellationRequested) + { + } + catch (OperationCanceledException) + { + } + } + + private static bool TryParseBasicIoAutocompleteRequest(string text, out string behindCursor) + { + behindCursor = string.Empty; + + if (!ConsoleIO.BasicIO || string.IsNullOrEmpty(text) || text[0] != (char)0x00) + return false; + + string[] command = text[1..].Split((char)0x00); + if (command.Length < 2 || !command[0].Equals("autocomplete", StringComparison.OrdinalIgnoreCase)) + return false; + + behindCursor = command[1]; + return true; + } + + private Task BeginNetworkAutoCompleteRequest(string behindCursor) + { + if (string.IsNullOrEmpty(behindCursor)) + return Task.FromResult(Array.Empty()); + + TaskCompletionSource request = new(TaskCreationOptions.RunContinuationsAsynchronously); + lock (networkAutoCompleteLock) + { + pendingNetworkAutoCompleteRequest?.TrySetException(new OperationCanceledException()); + pendingNetworkAutoCompleteRequest = request; + } + + try + { + if (handler.AutoComplete(behindCursor) < 0) + { + CompletePendingNetworkAutoComplete(Array.Empty()); + } + } + catch (Exception e) + { + lock (networkAutoCompleteLock) + { + if (ReferenceEquals(pendingNetworkAutoCompleteRequest, request)) + pendingNetworkAutoCompleteRequest = null; + } + request.TrySetException(e); + } + + return request.Task; + } + + private void BeginCommandListInitialization() + { + lock (networkAutoCompleteLock) + { + pendingCommandListInitialization?.TrySetCanceled(); + pendingCommandListInitialization = new(TaskCreationOptions.RunContinuationsAsynchronously); + } + } + + private void CompletePendingNetworkAutoComplete(string[] result) + { + TaskCompletionSource? pendingRequest; + lock (networkAutoCompleteLock) + { + pendingRequest = pendingNetworkAutoCompleteRequest; + pendingNetworkAutoCompleteRequest = null; + } + + pendingRequest?.TrySetResult(result); + } + + private void CancelPendingNetworkAutoComplete() + { + TaskCompletionSource? pendingRequest; + lock (networkAutoCompleteLock) + { + pendingRequest = pendingNetworkAutoCompleteRequest; + pendingNetworkAutoCompleteRequest = null; + } + + pendingRequest?.TrySetCanceled(); + } + + private void CompletePendingCommandListInitialization() + { + TaskCompletionSource? pendingInitialization; + lock (networkAutoCompleteLock) + { + pendingInitialization = pendingCommandListInitialization; + pendingCommandListInitialization = null; + } + + pendingInitialization?.TrySetResult(true); + } + + private void CancelPendingCommandListInitialization() + { + TaskCompletionSource? pendingInitialization; + lock (networkAutoCompleteLock) + { + pendingInitialization = pendingCommandListInitialization; + pendingCommandListInitialization = null; + } + + pendingInitialization?.TrySetCanceled(); + } + + private async Task WaitForCommandListInitializationAsync(CancellationToken cancellationToken) + { + Task? initializationTask; + lock (networkAutoCompleteLock) + { + initializationTask = pendingCommandListInitialization?.Task; + } + + if (initializationTask is null) + return; + + try + { + await initializationTask.WaitAsync(TimeSpan.FromSeconds(1), cancellationToken); + } + catch (TimeoutException) + { + } + catch (OperationCanceledException) + { + } + } + /// /// Disconnect the client from the server (initiated from MCC) /// @@ -781,6 +1052,7 @@ public void Disconnect() botsOnHold.Clear(); botsOnHold.AddRange(bots); + StopConsoleHandlers(); if (handler is not null) { @@ -788,12 +1060,6 @@ public void Disconnect() handler.Dispose(); } - if (cmdprompt is not null) - { - cmdprompt.Cancel(); - cmdprompt = null; - } - if (timeoutdetector is not null) { timeoutdetector.Item2.Cancel(); @@ -820,8 +1086,7 @@ public void OnConnectionLost(ChatBot.DisconnectReason reason, string message) if (timeoutdetector is not null) { - if (timeoutdetector is not null && Thread.CurrentThread != timeoutdetector.Item1) - timeoutdetector.Item2.Cancel(); + timeoutdetector.Item2.Cancel(); timeoutdetector = null; } @@ -872,9 +1137,7 @@ public void OnConnectionLost(ChatBot.DisconnectReason reason, string message) if (!will_restart) { - ConsoleIO.Backend.StopReadThread(); - ConsoleIO.Backend.MessageReceived -= ConsoleReaderOnMessageReceived; - ConsoleIO.Backend.OnInputChange -= ConsoleIO.AutocompleteHandler; + StopConsoleHandlers(); Program.HandleFailure(null, false, reason); } } @@ -885,16 +1148,18 @@ public void OnConnectionLost(ChatBot.DisconnectReason reason, string message) private void ConsoleReaderOnMessageReceived(object? sender, string e) { + Channel? activeChannel; + lock (consoleCommandProcessingLock) + { + activeChannel = consoleCommandChannel; + } - if (client.Client is null) + if (activeChannel is null || client.Client is null) return; if (client.Client.Connected) { - new Thread(() => - { - InvokeOnMainThread(() => HandleCommandPromptText(e)); - }).Start(); + activeChannel.Writer.TryWrite(e); } else return; @@ -916,56 +1181,45 @@ private void HandleCommandPromptText(string text) { if (ConsoleIO.BasicIO && text.Length > 0 && text[0] == (char)0x00) { - //Process a request from the GUI - string[] command = text[1..].Split((char)0x00); - switch (command[0].ToLower()) - { - case "autocomplete": - int id = handler.AutoComplete(command[1]); - while (!ConsoleIO.AutoCompleteDone) { Thread.Sleep(100); } - if (command.Length > 1) { ConsoleIO.WriteLine((char)0x00 + "autocomplete" + (char)0x00 + ConsoleIO.AutoCompleteResult); } - else ConsoleIO.WriteLine((char)0x00 + "autocomplete" + (char)0x00); - break; - } + _ = HandleBasicIoAutocompleteRequestAsync(text, CancellationToken.None); + return; } - else - { - text = text.Trim(); - if (text.Length > 1 - && Config.Main.Advanced.InternalCmdChar == MainConfigHelper.MainConfig.AdvancedConfig.InternalCmdCharType.none - && text[0] == '/') - { - SendText(text); - } - else if (text.Length > 2 - && Config.Main.Advanced.InternalCmdChar != MainConfigHelper.MainConfig.AdvancedConfig.InternalCmdCharType.none - && text[0] == Config.Main.Advanced.InternalCmdChar.ToChar() - && text[1] == '/') - { - SendText(text[1..]); - } - else if (text.Length > 0) + text = text.Trim(); + + if (text.Length > 1 + && Config.Main.Advanced.InternalCmdChar == MainConfigHelper.MainConfig.AdvancedConfig.InternalCmdCharType.none + && text[0] == '/') + { + SendText(text); + } + else if (text.Length > 2 + && Config.Main.Advanced.InternalCmdChar != MainConfigHelper.MainConfig.AdvancedConfig.InternalCmdCharType.none + && text[0] == Config.Main.Advanced.InternalCmdChar.ToChar() + && text[1] == '/') + { + SendText(text[1..]); + } + else if (text.Length > 0) + { + if (Config.Main.Advanced.InternalCmdChar == MainConfigHelper.MainConfig.AdvancedConfig.InternalCmdCharType.none + || text[0] == Config.Main.Advanced.InternalCmdChar.ToChar()) { - if (Config.Main.Advanced.InternalCmdChar == MainConfigHelper.MainConfig.AdvancedConfig.InternalCmdCharType.none - || text[0] == Config.Main.Advanced.InternalCmdChar.ToChar()) + CmdResult result = new(); + string command = Config.Main.Advanced.InternalCmdChar.ToChar() == ' ' ? text : text[1..]; + if (!PerformInternalCommand(Config.AppVar.ExpandVars(command), ref result, Settings.Config.AppVar.GetVariables()) && Config.Main.Advanced.InternalCmdChar.ToChar() == '/') { - CmdResult result = new(); - string command = Config.Main.Advanced.InternalCmdChar.ToChar() == ' ' ? text : text[1..]; - if (!PerformInternalCommand(Config.AppVar.ExpandVars(command), ref result, Settings.Config.AppVar.GetVariables()) && Config.Main.Advanced.InternalCmdChar.ToChar() == '/') - { - SendText(text); - } - else if (result.status != CmdResult.Status.NotRun && (result.status != CmdResult.Status.Done || !string.IsNullOrWhiteSpace(result.result))) - { - Log.Info(result); - } + SendText(text); } - else + else if (result.status != CmdResult.Status.NotRun && (result.status != CmdResult.Status.Done || !string.IsNullOrWhiteSpace(result.result))) { - SendText(text); + Log.Info(result); } } + else + { + SendText(text); + } } } @@ -1099,19 +1353,7 @@ public void UnloadAllBots() /// Type of the return value public T InvokeOnMainThread(Func task) { - if (!InvokeRequired) - { - return task(); - } - else - { - TaskWithResult taskWithResult = new(task); - lock (threadTasksLock) - { - threadTasks.Enqueue(taskWithResult.ExecuteSynchronously); - } - return taskWithResult.WaitGetResult(); - } + return InvokeOnMainThreadAsync(task).GetAwaiter().GetResult(); } /// @@ -1126,6 +1368,37 @@ public void InvokeOnMainThread(Action task) InvokeOnMainThread(() => { task(); return true; }); } + private Task InvokeOnMainThreadAsync(Func task) + { + if (!InvokeRequired) + { + try + { + return Task.FromResult(task()); + } + catch (Exception e) + { + return Task.FromException(e); + } + } + + TaskWithResult taskWithResult = new(task); + lock (threadTasksLock) + { + threadTasks.Enqueue(taskWithResult); + } + return taskWithResult.AsTask(); + } + + private Task InvokeOnMainThreadAsync(Action task) + { + return InvokeOnMainThreadAsync(() => + { + task(); + return true; + }); + } + /// /// Clear all tasks /// @@ -1133,7 +1406,8 @@ public void ClearTasks() { lock (threadTasksLock) { - threadTasks.Clear(); + while (threadTasks.Count > 0) + threadTasks.Dequeue().Cancel(); } } @@ -1145,16 +1419,13 @@ public bool InvokeRequired { get { - int callingThreadId = Environment.CurrentManagedThreadId; - if (handler is not null) - { - return handler.GetNetMainThreadId() != callingThreadId; - } - else + if (handler is null) { // net read thread (main thread) not yet ready return false; } + + return !MainThreadExecutionScope.IsActive(this); } } @@ -3040,6 +3311,7 @@ public void OnGameJoined(bool isOnlineMode) DispatchBotEvent(bot => bot.AfterGameJoined()); + BeginCommandListInitialization(); ConsoleIO.InitCommandList(dispatcher); } @@ -4445,6 +4717,8 @@ public void OnBlockEntityData(Location location, Dictionary? nbt public void OnAutoCompleteDone(int transactionId, string[] result) { ConsoleIO.OnAutoCompleteDone(transactionId, result); + CompletePendingNetworkAutoComplete(result); + CompletePendingCommandListInitialization(); } public void SetCanSendMessage(bool canSendMessage) diff --git a/MinecraftClient/Mcp/IMccMcpCapabilities.cs b/MinecraftClient/Mcp/IMccMcpCapabilities.cs index 055fb56146..273a37cb8e 100644 --- a/MinecraftClient/Mcp/IMccMcpCapabilities.cs +++ b/MinecraftClient/Mcp/IMccMcpCapabilities.cs @@ -1,3 +1,5 @@ +using System.Threading.Tasks; + namespace MinecraftClient.Mcp; public interface IMccMcpCapabilities @@ -33,6 +35,7 @@ public interface IMccMcpCapabilities MccMcpResult SelectHotbarItem(string itemType, bool preferLowestSlot); MccMcpResult UseItemOnBlock(double x, double y, double z); MccMcpResult DigBlock(double x, double y, double z, double durationSeconds); + Task DigBlockAsync(double x, double y, double z, double durationSeconds); MccMcpResult PlaceBlock(int x, int y, int z, string face, string hand, bool lookAtBlock); MccMcpResult InteractEntity(int entityId, string interaction, string hand); MccMcpResult AttackEntity(int entityId); @@ -43,7 +46,9 @@ public interface IMccMcpCapabilities MccMcpResult FindNearestEntity(string? typeFilter, string? nameFilter, double radius, bool includePlayers); MccMcpResult CanReachPosition(double x, double y, double z, bool allowUnsafe, int maxOffset, int minOffset, int timeoutMs); MccMcpResult MoveTo(double x, double y, double z, bool allowUnsafe, bool allowDirectTeleport, int maxOffset, int minOffset, int timeoutMs); + Task MoveToAsync(double x, double y, double z, bool allowUnsafe, bool allowDirectTeleport, int maxOffset, int minOffset, int timeoutMs); MccMcpResult MoveToPlayer(string playerName, bool allowUnsafe, bool allowDirectTeleport, int maxOffset, int minOffset, int timeoutMs); + Task MoveToPlayerAsync(string playerName, bool allowUnsafe, bool allowDirectTeleport, int maxOffset, int minOffset, int timeoutMs); MccMcpResult LookAt(double x, double y, double z); MccMcpResult LookDirection(string direction); MccMcpResult LookAngles(float yaw, float pitch); @@ -51,16 +56,22 @@ public interface IMccMcpCapabilities MccMcpResult GetInventorySnapshot(int inventoryId); MccMcpResult SearchInventories(string query, int maxCount, bool exactMatch, bool includeContainers); MccMcpResult OpenContainerAt(int x, int y, int z, int timeoutMs, bool closeCurrent); + Task OpenContainerAtAsync(int x, int y, int z, int timeoutMs, bool closeCurrent); MccMcpResult CloseContainer(int inventoryId, int timeoutMs); + Task CloseContainerAsync(int inventoryId, int timeoutMs); MccMcpResult InventoryWindowAction(int inventoryId, int slotId, string actionType); MccMcpResult DropInventoryItem(string itemType, int count, int inventoryId, bool preferStack); + Task DropInventoryItemAsync(string itemType, int count, int inventoryId, bool preferStack); MccMcpResult DepositContainerItem(string itemType, int count, int inventoryId, bool preferLargestStack); + Task DepositContainerItemAsync(string itemType, int count, int inventoryId, bool preferLargestStack); MccMcpResult WithdrawContainerItem(string itemType, int count, int inventoryId, bool preferLargestStack); + Task WithdrawContainerItemAsync(string itemType, int count, int inventoryId, bool preferLargestStack); MccMcpResult QueryEntities(int maxCount); MccMcpResult ListEntities(int maxCount, string? typeFilter, double radius); MccMcpResult GetEntityInfo(int entityId, bool includeMetadata, bool includeEquipment, bool includeEffects); MccMcpResult FindSigns(string text, bool exactMatch, int radius, int maxCount, bool includeBackText); MccMcpResult ListItemEntities(string? itemType, double radius, int maxCount); MccMcpResult PickupItems(string itemType, double radius, int maxItems, bool allowUnsafe, int timeoutMs); + Task PickupItemsAsync(string itemType, double radius, int maxItems, bool allowUnsafe, int timeoutMs); MccMcpResult GetWorldBlockAt(int x, int y, int z); } diff --git a/MinecraftClient/Mcp/MccEmbeddedMcpHost.cs b/MinecraftClient/Mcp/MccEmbeddedMcpHost.cs index a40b718342..b6a9eb7021 100644 --- a/MinecraftClient/Mcp/MccEmbeddedMcpHost.cs +++ b/MinecraftClient/Mcp/MccEmbeddedMcpHost.cs @@ -1,4 +1,6 @@ using System; +using System.Threading; +using System.Threading.Tasks; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Http; @@ -12,7 +14,7 @@ public sealed class MccEmbeddedMcpHost { private readonly MccMcpConfig config; private readonly IMccMcpCapabilities capabilities; - private readonly object stateLock = new(); + private readonly SemaphoreSlim stateLock = new(1, 1); private WebApplication? app; public MccEmbeddedMcpHost(MccMcpConfig config, IMccMcpCapabilities capabilities) @@ -25,10 +27,7 @@ public bool IsRunning { get { - lock (stateLock) - { - return app is not null; - } + return app is not null; } } @@ -36,29 +35,37 @@ public bool IsRunning public bool Start(out string? error) { - lock (stateLock) + (bool success, string? startError) = StartAsync().GetAwaiter().GetResult(); + error = startError; + return success; + } + + public bool Stop(out string? error) + { + (bool success, string? stopError) = StopAsync().GetAwaiter().GetResult(); + error = stopError; + return success; + } + + public async Task<(bool Success, string? Error)> StartAsync(CancellationToken cancellationToken = default) + { + await stateLock.WaitAsync(cancellationToken); + try { - error = null; if (app is not null) - return true; + return (true, null); string route = NormalizeRoute(config.Transport.Route); string bindHost = string.IsNullOrWhiteSpace(config.Transport.BindHost) ? "127.0.0.1" : config.Transport.BindHost.Trim(); if (config.Transport.Port is < 1 or > 65535) - { - error = "invalid_port"; - return false; - } + return (false, "invalid_port"); string? requiredToken = null; if (config.Transport.RequireAuthToken) { requiredToken = Environment.GetEnvironmentVariable(config.Transport.AuthTokenEnvVar); if (string.IsNullOrWhiteSpace(requiredToken)) - { - error = "missing_auth_token"; - return false; - } + return (false, "missing_auth_token"); } WebApplicationBuilder builder = WebApplication.CreateBuilder(); @@ -96,33 +103,49 @@ public bool Start(out string? error) } builtApp.MapMcp(route); - builtApp.StartAsync().GetAwaiter().GetResult(); - app = builtApp; - return true; + + try + { + await builtApp.StartAsync(cancellationToken); + app = builtApp; + return (true, null); + } + catch + { + await builtApp.DisposeAsync(); + throw; + } + } + finally + { + stateLock.Release(); } } - public bool Stop(out string? error) + public async Task<(bool Success, string? Error)> StopAsync(CancellationToken cancellationToken = default) { - lock (stateLock) + await stateLock.WaitAsync(cancellationToken); + try { - error = null; if (app is null) - return true; + return (true, null); try { - app.StopAsync().GetAwaiter().GetResult(); - app.DisposeAsync().AsTask().GetAwaiter().GetResult(); + await app.StopAsync(cancellationToken); + await app.DisposeAsync(); app = null; - return true; + return (true, null); } catch { - error = "stop_failed"; - return false; + return (false, "stop_failed"); } } + finally + { + stateLock.Release(); + } } private static string NormalizeRoute(string route) diff --git a/MinecraftClient/Mcp/MccMcpCapabilities.cs b/MinecraftClient/Mcp/MccMcpCapabilities.cs index 60a03f767d..01521c5cde 100644 --- a/MinecraftClient/Mcp/MccMcpCapabilities.cs +++ b/MinecraftClient/Mcp/MccMcpCapabilities.cs @@ -71,6 +71,8 @@ private sealed class NearbyItemSnapshot public required double Distance { get; init; } } + private readonly record struct ContainerOpenState(int InventoryId, Container? Inventory); + private enum InventoryTransferDirection { Deposit, @@ -859,6 +861,11 @@ public MccMcpResult UseItemOnBlock(double x, double y, double z) } public MccMcpResult DigBlock(double x, double y, double z, double durationSeconds) + { + return DigBlockAsync(x, y, z, durationSeconds).GetAwaiter().GetResult(); + } + + public async Task DigBlockAsync(double x, double y, double z, double durationSeconds) { if (!IsCategoryEnabled(t => t.Movement)) return MccMcpResult.Fail("capability_disabled"); @@ -921,7 +928,10 @@ public MccMcpResult DigBlock(double x, double y, double z, double durationSecond if (!accepted) continue; - if (WaitForBlockChange(client, target, beforeBlock, GetDigVerifyWaitMs(attemptDuration), out afterBlock)) + (bool blockChanged, Block updatedBlock) = + await WaitForBlockChangeAsync(client, target, beforeBlock, GetDigVerifyWaitMs(attemptDuration)); + afterBlock = updatedBlock; + if (blockChanged) { changed = true; break; @@ -1274,6 +1284,11 @@ public MccMcpResult FindNearestEntity(string? typeFilter, string? nameFilter, do } public MccMcpResult MoveTo(double x, double y, double z, bool allowUnsafe, bool allowDirectTeleport, int maxOffset, int minOffset, int timeoutMs) + { + return MoveToAsync(x, y, z, allowUnsafe, allowDirectTeleport, maxOffset, minOffset, timeoutMs).GetAwaiter().GetResult(); + } + + public async Task MoveToAsync(double x, double y, double z, bool allowUnsafe, bool allowDirectTeleport, int maxOffset, int minOffset, int timeoutMs) { if (!IsCategoryEnabled(t => t.Movement)) return MccMcpResult.Fail("capability_disabled"); @@ -1302,9 +1317,12 @@ public MccMcpResult MoveTo(double x, double y, double z, bool allowUnsafe, bool int verifyWaitMs = GetArrivalWaitMs(timeoutMs); double tolerance = GetArrivalTolerance(maxOffset, minOffset); - Location? finalLocation = null; - bool arrived = pathFound && WaitForArrival(client, goal, verifyWaitMs, tolerance, out finalLocation); - finalLocation ??= client.InvokeOnMainThread(client.GetCurrentLocation); + Location finalLocation = client.InvokeOnMainThread(client.GetCurrentLocation); + bool arrived = false; + if (pathFound) + { + (arrived, finalLocation) = await WaitForArrivalAsync(client, goal, verifyWaitMs, tolerance); + } object resultData = new { pathFound, @@ -1313,9 +1331,9 @@ public MccMcpResult MoveTo(double x, double y, double z, bool allowUnsafe, bool verifyWaitMs, target = ToCoordinate(goal), startLocation = ToCoordinate(startLocation), - finalLocation = ToCoordinate(finalLocation.Value), - finalDistance = GetDistance(finalLocation.Value, goal), - distanceMoved = GetDistance(startLocation, finalLocation.Value), + finalLocation = ToCoordinate(finalLocation), + finalDistance = GetDistance(finalLocation, goal), + distanceMoved = GetDistance(startLocation, finalLocation), allowUnsafe, allowDirectTeleport, maxOffset, @@ -1329,10 +1347,15 @@ public MccMcpResult MoveTo(double x, double y, double z, bool allowUnsafe, bool } public MccMcpResult MoveToPlayer(string playerName, bool allowUnsafe, bool allowDirectTeleport, int maxOffset, int minOffset, int timeoutMs) + { + return MoveToPlayerAsync(playerName, allowUnsafe, allowDirectTeleport, maxOffset, minOffset, timeoutMs).GetAwaiter().GetResult(); + } + + public async Task MoveToPlayerAsync(string playerName, bool allowUnsafe, bool allowDirectTeleport, int maxOffset, int minOffset, int timeoutMs) { if (!IsCategoryEnabled(t => t.Movement)) return MccMcpResult.Fail("capability_disabled"); - return ToMcpResult(game.MoveToPlayer(playerName, allowUnsafe, allowDirectTeleport, maxOffset, minOffset, timeoutMs)); + return ToMcpResult(await game.MoveToPlayerAsync(playerName, allowUnsafe, allowDirectTeleport, maxOffset, minOffset, timeoutMs)); } public MccMcpResult LookAt(double x, double y, double z) @@ -1464,6 +1487,11 @@ public MccMcpResult ListInventories() } public MccMcpResult OpenContainerAt(int x, int y, int z, int timeoutMs, bool closeCurrent) + { + return OpenContainerAtAsync(x, y, z, timeoutMs, closeCurrent).GetAwaiter().GetResult(); + } + + public async Task OpenContainerAtAsync(int x, int y, int z, int timeoutMs, bool closeCurrent) { if (!IsCategoryEnabled(t => t.Inventory)) return MccMcpResult.Fail("capability_disabled"); @@ -1495,10 +1523,15 @@ public MccMcpResult OpenContainerAt(int x, int y, int z, int timeoutMs, bool clo }); } - return OpenContainerCore(client, location, state.block, state.activeContainerId, waitMs, closeCurrent); + return await OpenContainerCoreAsync(client, location, state.block, state.activeContainerId, waitMs, closeCurrent); } public MccMcpResult CloseContainer(int inventoryId, int timeoutMs) + { + return CloseContainerAsync(inventoryId, timeoutMs).GetAwaiter().GetResult(); + } + + public async Task CloseContainerAsync(int inventoryId, int timeoutMs) { if (!IsCategoryEnabled(t => t.Inventory)) return MccMcpResult.Fail("capability_disabled"); @@ -1527,7 +1560,7 @@ public MccMcpResult CloseContainer(int inventoryId, int timeoutMs) } bool closeAccepted = client.CloseInventory(resolvedInventoryId); - bool closed = closeAccepted && WaitForContainerClose(client, resolvedInventoryId, waitMs); + bool closed = closeAccepted && await WaitForContainerCloseAsync(client, resolvedInventoryId, waitMs); var resultData = new { success = closeAccepted && closed, @@ -1565,6 +1598,11 @@ public MccMcpResult InventoryWindowAction(int inventoryId, int slotId, string ac } public MccMcpResult DropInventoryItem(string itemType, int count, int inventoryId, bool preferStack) + { + return DropInventoryItemAsync(itemType, count, inventoryId, preferStack).GetAwaiter().GetResult(); + } + + public async Task DropInventoryItemAsync(string itemType, int count, int inventoryId, bool preferStack) { if (!IsCategoryEnabled(t => t.Inventory)) return MccMcpResult.Fail("capability_disabled"); @@ -1587,10 +1625,10 @@ public MccMcpResult DropInventoryItem(string itemType, int count, int inventoryI }); } - return client.InvokeOnMainThread(() => + return await Task.Run(async () => { - Dictionary inventories = client.GetInventories(); - if (!inventories.TryGetValue(inventoryId, out Container? inventory)) + Container? inventory = client.GetInventory(inventoryId); + if (inventory is null) return MccMcpResult.Fail("invalid_state"); int cursorCount = GetCursorItemCount(inventory, parsedItemType); @@ -1634,7 +1672,7 @@ public MccMcpResult DropInventoryItem(string itemType, int count, int inventoryI int dropFromSlot = Math.Min(remaining, currentItem.Count); touchedSlots.Add(entry.slot); - bool ok = TryDropInventorySlotItems(client, inventoryId, inventory, entry.slot, parsedItemType, dropFromSlot, out int droppedFromSlot); + (bool ok, int droppedFromSlot) = await TryDropInventorySlotItemsAsync(client, inventoryId, inventory, entry.slot, parsedItemType, dropFromSlot); if (!ok) { @@ -1700,12 +1738,22 @@ public MccMcpResult DropInventoryItem(string itemType, int count, int inventoryI public MccMcpResult DepositContainerItem(string itemType, int count, int inventoryId, bool preferLargestStack) { - return TransferContainerItem(itemType, count, inventoryId, preferLargestStack, InventoryTransferDirection.Deposit); + return DepositContainerItemAsync(itemType, count, inventoryId, preferLargestStack).GetAwaiter().GetResult(); + } + + public Task DepositContainerItemAsync(string itemType, int count, int inventoryId, bool preferLargestStack) + { + return TransferContainerItemAsync(itemType, count, inventoryId, preferLargestStack, InventoryTransferDirection.Deposit); } public MccMcpResult WithdrawContainerItem(string itemType, int count, int inventoryId, bool preferLargestStack) { - return TransferContainerItem(itemType, count, inventoryId, preferLargestStack, InventoryTransferDirection.Withdraw); + return WithdrawContainerItemAsync(itemType, count, inventoryId, preferLargestStack).GetAwaiter().GetResult(); + } + + public Task WithdrawContainerItemAsync(string itemType, int count, int inventoryId, bool preferLargestStack) + { + return TransferContainerItemAsync(itemType, count, inventoryId, preferLargestStack, InventoryTransferDirection.Withdraw); } public MccMcpResult QueryEntities(int maxCount) @@ -1824,10 +1872,15 @@ public MccMcpResult ListItemEntities(string? itemType, double radius, int maxCou } public MccMcpResult PickupItems(string itemType, double radius, int maxItems, bool allowUnsafe, int timeoutMs) + { + return PickupItemsAsync(itemType, radius, maxItems, allowUnsafe, timeoutMs).GetAwaiter().GetResult(); + } + + public async Task PickupItemsAsync(string itemType, double radius, int maxItems, bool allowUnsafe, int timeoutMs) { if (!IsCategoryEnabled(t => t.EntityWorld) || !IsCategoryEnabled(t => t.Movement)) return MccMcpResult.Fail("capability_disabled"); - return ToMcpResult(game.PickupItems(itemType, radius, maxItems, allowUnsafe, timeoutMs)); + return ToMcpResult(await game.PickupItemsAsync(itemType, radius, maxItems, allowUnsafe, timeoutMs)); } public MccMcpResult GetWorldBlockAt(int x, int y, int z) @@ -1858,7 +1911,7 @@ public MccMcpResult GetWorldBlockAt(int x, int y, int z) }); } - private static MccMcpResult OpenContainerCore(McClient client, Location location, Block block, int activeContainerId, int waitMs, bool closeCurrent) + private static async Task OpenContainerCoreAsync(McClient client, Location location, Block block, int activeContainerId, int waitMs, bool closeCurrent) { if (activeContainerId > 0) { @@ -1876,7 +1929,7 @@ private static MccMcpResult OpenContainerCore(McClient client, Location location } bool closeAccepted = client.CloseInventory(activeContainerId); - bool closed = closeAccepted && WaitForContainerClose(client, activeContainerId, waitMs); + bool closed = closeAccepted && await WaitForContainerCloseAsync(client, activeContainerId, waitMs); if (!closeAccepted || !closed) { return MccMcpResult.Fail("action_incomplete", data: new @@ -1894,7 +1947,11 @@ private static MccMcpResult OpenContainerCore(McClient client, Location location int openedInventoryId = 0; Container? openedInventory = null; bool openAccepted = client.InvokeOnMainThread(() => client.PlaceBlock(location, Direction.Down, Hand.MainHand, lookAtBlock: true)); - bool opened = openAccepted && WaitForContainerOpen(client, beforeIds, waitMs, out openedInventoryId, out openedInventory); + bool opened = openAccepted && await WaitForContainerOpenAsync(client, beforeIds, waitMs, result => + { + openedInventoryId = result.InventoryId; + openedInventory = result.Inventory; + }); var resultData = new { success = openAccepted && opened && openedInventory is not null, @@ -1923,6 +1980,11 @@ private static MccMcpResult OpenContainerCore(McClient client, Location location } private MccMcpResult TransferContainerItem(string itemType, int count, int inventoryId, bool preferLargestStack, InventoryTransferDirection direction) + { + return TransferContainerItemAsync(itemType, count, inventoryId, preferLargestStack, direction).GetAwaiter().GetResult(); + } + + private async Task TransferContainerItemAsync(string itemType, int count, int inventoryId, bool preferLargestStack, InventoryTransferDirection direction) { if (!IsCategoryEnabled(t => t.Inventory)) return MccMcpResult.Fail("capability_disabled"); @@ -2044,7 +2106,18 @@ private MccMcpResult TransferContainerItem(string itemType, int count, int inven if (direction == InventoryTransferDirection.Withdraw) { - if (!WaitForRangeCount(client, resolvedInventoryId, parsedItemType, sourceStart, sourceEnd, countAfterShift => countAfterShift < beforeSourceCount, DefaultInventoryActionWaitMs, out Container? afterShift, out int afterSourceCount)) + Container? afterShift = null; + int afterSourceCount = beforeSourceCount; + if (!await WaitForRangeCountAsync(client, resolvedInventoryId, parsedItemType, sourceStart, sourceEnd, countAfterShift => countAfterShift < beforeSourceCount, DefaultInventoryActionWaitMs, result => + { + afterShift = result.Inventory; + afterSourceCount = result.ItemCount; + }, + onInitialize: () => + { + afterShift = null; + afterSourceCount = beforeSourceCount; + })) { afterShift = client.InvokeOnMainThread(() => client.GetInventory(resolvedInventoryId)); afterSourceCount = afterShift is null ? beforeSourceCount : CountItemInRange(afterShift, parsedItemType, sourceStart, sourceEnd); @@ -2054,7 +2127,18 @@ private MccMcpResult TransferContainerItem(string itemType, int count, int inven } else { - if (!WaitForRangeCount(client, resolvedInventoryId, parsedItemType, targetStart, targetEnd, countAfterShift => countAfterShift > beforeTargetCount, DefaultInventoryActionWaitMs, out Container? afterShift, out int afterTargetCount)) + Container? afterShift = null; + int afterTargetCount = beforeTargetCount; + if (!await WaitForRangeCountAsync(client, resolvedInventoryId, parsedItemType, targetStart, targetEnd, countAfterShift => countAfterShift > beforeTargetCount, DefaultInventoryActionWaitMs, result => + { + afterShift = result.Inventory; + afterTargetCount = result.ItemCount; + }, + onInitialize: () => + { + afterShift = null; + afterTargetCount = beforeTargetCount; + })) { afterShift = client.InvokeOnMainThread(() => client.GetInventory(resolvedInventoryId)); afterTargetCount = afterShift is null ? beforeTargetCount : CountItemInRange(afterShift, parsedItemType, targetStart, targetEnd); @@ -2066,7 +2150,7 @@ private MccMcpResult TransferContainerItem(string itemType, int count, int inven if (direction == InventoryTransferDirection.Withdraw && movedCount > remaining) { int excessCount = movedCount - remaining; - MccMcpResult returnExcess = TransferContainerItem(parsedItemType.ToString(), excessCount, resolvedInventoryId, preferLargestStack, InventoryTransferDirection.Deposit); + MccMcpResult returnExcess = await TransferContainerItemAsync(parsedItemType.ToString(), excessCount, resolvedInventoryId, preferLargestStack, InventoryTransferDirection.Deposit); if (!returnExcess.Success) { return MccMcpResult.Fail("action_incomplete", data: new @@ -2087,7 +2171,7 @@ private MccMcpResult TransferContainerItem(string itemType, int count, int inven } else { - movedCount = TransferPartialFromSlot( + movedCount = await TransferPartialFromSlotAsync( client, resolvedInventoryId, slot, @@ -2205,10 +2289,8 @@ private static int GetActiveContainerId(McClient client) return client.GetInventories().Keys.Where(id => id > 0).DefaultIfEmpty(0).Max(); } - private static bool WaitForContainerOpen(McClient client, ISet beforeIds, int waitMs, out int inventoryId, out Container? inventory) + private static async Task WaitForContainerOpenAsync(McClient client, ISet beforeIds, int waitMs, Action onOpened) { - inventoryId = 0; - inventory = null; DateTime deadline = DateTime.UtcNow.AddMilliseconds(waitMs); while (true) { @@ -2221,19 +2303,18 @@ private static bool WaitForContainerOpen(McClient client, ISet beforeIds, i if (state.activeId > 0 && (!beforeIds.Contains(state.activeId) || beforeIds.Count == 0) && state.activeInventory is not null) { - inventoryId = state.activeId; - inventory = state.activeInventory; + onOpened(new ContainerOpenState(state.activeId, state.activeInventory)); return true; } if (DateTime.UtcNow >= deadline) return false; - Thread.Sleep(ArrivalPollIntervalMs); + await Task.Delay(ArrivalPollIntervalMs); } } - private static bool WaitForContainerClose(McClient client, int inventoryId, int waitMs) + private static async Task WaitForContainerCloseAsync(McClient client, int inventoryId, int waitMs) { DateTime deadline = DateTime.UtcNow.AddMilliseconds(waitMs); while (true) @@ -2245,7 +2326,7 @@ private static bool WaitForContainerClose(McClient client, int inventoryId, int if (DateTime.UtcNow >= deadline) return false; - Thread.Sleep(ArrivalPollIntervalMs); + await Task.Delay(ArrivalPollIntervalMs); } } @@ -2328,20 +2409,6 @@ private static int GetCursorItemCount(Container inventory, ItemType itemType) : 0; } - private static bool TryDropInventorySlotItems(McClient client, int inventoryId, Container inventory, int slotId, ItemType itemType, int dropCount, out int droppedCount) - { - droppedCount = 0; - if (!inventory.Items.TryGetValue(slotId, out Item? currentItem) || currentItem.Type != itemType || currentItem.Count <= 0) - return false; - - if (inventoryId == 0 && inventory.IsHotbar(slotId, out int hotbarSlot)) - { - return TryDropHotbarSlotItems(client, slotId, hotbarSlot, itemType, dropCount, currentItem.Count, out droppedCount); - } - - return TryDropWindowSlotItems(client, inventoryId, slotId, itemType, dropCount, currentItem.Count, out droppedCount); - } - private static int CountItemInRange(Container inventory, ItemType itemType, int startSlot, int endSlot) { return inventory.Items @@ -2363,7 +2430,7 @@ private static (int slot, int count)[] GetOrderedItemSlots(Container inventory, .ToArray(); } - private static int TransferPartialFromSlot(McClient client, int inventoryId, int sourceSlot, ItemType itemType, int requestedCount, int sourceStart, int sourceEnd, int targetStart, int targetEnd, List touchedTargetSlots) + private static async Task TransferPartialFromSlotAsync(McClient client, int inventoryId, int sourceSlot, ItemType itemType, int requestedCount, int sourceStart, int sourceEnd, int targetStart, int targetEnd, List touchedTargetSlots) { Container? inventory = client.InvokeOnMainThread(() => client.GetInventory(inventoryId)); if (inventory is null || !inventory.Items.TryGetValue(sourceSlot, out Item? sourceItem) || sourceItem.Count <= 0) @@ -2373,7 +2440,7 @@ private static int TransferPartialFromSlot(McClient client, int inventoryId, int if (!client.DoWindowAction(inventoryId, sourceSlot, WindowActionType.LeftClick)) return 0; - if (!WaitForCursorItem(client, itemType, DefaultInventoryActionWaitMs, out _)) + if (!await WaitForCursorItemAsync(client, itemType, DefaultInventoryActionWaitMs)) return 0; int moved = 0; @@ -2392,7 +2459,7 @@ private static int TransferPartialFromSlot(McClient client, int inventoryId, int if (step <= 0 || !PlaceItemsFromCursor(client, inventoryId, targetSlot, step)) break; - if (!WaitForPlacement(client, inventoryId, targetSlot, itemType, beforeTargetCount, beforeCursorCount, step)) + if (!await WaitForPlacementAsync(client, inventoryId, targetSlot, itemType, beforeTargetCount, beforeCursorCount, step)) break; touchedTargetSlots.Add(targetSlot); @@ -2409,7 +2476,7 @@ private static int TransferPartialFromSlot(McClient client, int inventoryId, int if (!client.DoWindowAction(inventoryId, returnSlot, WindowActionType.LeftClick)) return 0; - if (!WaitForCursorClear(client, DefaultInventoryActionWaitMs)) + if (!await WaitForCursorClearAsync(client, DefaultInventoryActionWaitMs)) return 0; } @@ -2495,27 +2562,27 @@ private static int GetSlotItemCount(Container inventory, int slot, ItemType item return inventory.Items.TryGetValue(slot, out Item? item) && item.Type == itemType ? item.Count : 0; } - private static bool TryDropHotbarSlotItems(McClient client, int slotId, int hotbarSlot, ItemType itemType, int dropCount, int availableInSlot, out int droppedCount) + private static async Task<(bool Success, int DroppedCount)> TryDropHotbarSlotItemsAsync(McClient client, int slotId, int hotbarSlot, ItemType itemType, int dropCount, int availableInSlot) { - droppedCount = 0; - byte previousSlot = client.GetCurrentSlot(); + int droppedCount = 0; + byte previousSlot = client.InvokeOnMainThread(client.GetCurrentSlot); bool restoreSlot = previousSlot != hotbarSlot; if (restoreSlot && !client.ChangeSlot((short)hotbarSlot)) - return false; + return (false, 0); try { if (dropCount >= availableInSlot) { if (!client.DropSelectedItem(dropEntireStack: true)) - return false; + return (false, droppedCount); - if (!WaitForSlotItemCount(client, 0, slotId, itemType, count => count == 0, DefaultInventoryActionWaitMs, out _, out _)) - return false; + if (!await WaitForSlotItemCountAsync(client, 0, slotId, itemType, count => count == 0, DefaultInventoryActionWaitMs)) + return (false, droppedCount); droppedCount = availableInSlot; - return true; + return (true, droppedCount); } int remaining = dropCount; @@ -2523,23 +2590,23 @@ private static bool TryDropHotbarSlotItems(McClient client, int slotId, int hotb { Container? currentInventory = client.GetInventory(0); if (currentInventory is null) - return false; + return (false, droppedCount); int beforeSlotCount = GetSlotItemCount(currentInventory, slotId, itemType); if (beforeSlotCount <= 0) break; if (!client.DropSelectedItem(dropEntireStack: false)) - return false; + return (false, droppedCount); - if (!WaitForSlotItemCount(client, 0, slotId, itemType, count => count <= beforeSlotCount - 1, DefaultInventoryActionWaitMs, out _, out _)) - return false; + if (!await WaitForSlotItemCountAsync(client, 0, slotId, itemType, count => count <= beforeSlotCount - 1, DefaultInventoryActionWaitMs)) + return (false, droppedCount); remaining--; droppedCount++; } - return remaining == 0; + return (remaining == 0, droppedCount); } finally { @@ -2548,20 +2615,20 @@ private static bool TryDropHotbarSlotItems(McClient client, int slotId, int hotb } } - private static bool TryDropWindowSlotItems(McClient client, int inventoryId, int slotId, ItemType itemType, int dropCount, int availableInSlot, out int droppedCount) + private static async Task<(bool Success, int DroppedCount)> TryDropWindowSlotItemsAsync(McClient client, int inventoryId, int slotId, ItemType itemType, int dropCount, int availableInSlot) { - droppedCount = 0; + int droppedCount = 0; if (dropCount >= availableInSlot) { if (!client.DoWindowAction(inventoryId, slotId, WindowActionType.DropItemStack)) - return false; + return (false, droppedCount); - if (!WaitForSlotItemCount(client, inventoryId, slotId, itemType, count => count == 0, DefaultInventoryActionWaitMs, out _, out _)) - return false; + if (!await WaitForSlotItemCountAsync(client, inventoryId, slotId, itemType, count => count == 0, DefaultInventoryActionWaitMs)) + return (false, droppedCount); droppedCount = availableInSlot; - return true; + return (true, droppedCount); } int remaining = dropCount; @@ -2569,42 +2636,54 @@ private static bool TryDropWindowSlotItems(McClient client, int inventoryId, int { Container? currentInventory = client.GetInventory(inventoryId); if (currentInventory is null) - return false; + return (false, droppedCount); int beforeSlotCount = GetSlotItemCount(currentInventory, slotId, itemType); if (beforeSlotCount <= 0) break; if (!client.DoWindowAction(inventoryId, slotId, WindowActionType.DropItem)) - return false; + return (false, droppedCount); - if (!WaitForSlotItemCount(client, inventoryId, slotId, itemType, count => count <= beforeSlotCount - 1, DefaultInventoryActionWaitMs, out _, out _)) - return false; + if (!await WaitForSlotItemCountAsync(client, inventoryId, slotId, itemType, count => count <= beforeSlotCount - 1, DefaultInventoryActionWaitMs)) + return (false, droppedCount); remaining--; droppedCount++; } - return remaining == 0; + return (remaining == 0, droppedCount); + } + + private static Task<(bool Success, int DroppedCount)> TryDropInventorySlotItemsAsync(McClient client, int inventoryId, Container inventory, int slotId, ItemType itemType, int dropCount) + { + if (!inventory.Items.TryGetValue(slotId, out Item? currentItem) || currentItem.Type != itemType || currentItem.Count <= 0) + return Task.FromResult((false, 0)); + + if (inventoryId == 0 && inventory.IsHotbar(slotId, out int hotbarSlot)) + return TryDropHotbarSlotItemsAsync(client, slotId, hotbarSlot, itemType, dropCount, currentItem.Count); + + return TryDropWindowSlotItemsAsync(client, inventoryId, slotId, itemType, dropCount, currentItem.Count); } - private static bool WaitForCursorItem(McClient client, ItemType itemType, int waitMs, out Item? cursorItem) + private readonly record struct InventoryCountState(Container? Inventory, int ItemCount); + + private static async Task WaitForCursorItemAsync(McClient client, ItemType itemType, int waitMs) { - cursorItem = null; DateTime deadline = DateTime.UtcNow.AddMilliseconds(waitMs); while (true) { - if (TryGetCursorItem(client, out cursorItem) && cursorItem is not null && cursorItem.Type == itemType) + if (TryGetCursorItem(client, out Item? cursorItem) && cursorItem is not null && cursorItem.Type == itemType) return true; if (DateTime.UtcNow >= deadline) return false; - Thread.Sleep(ArrivalPollIntervalMs); + await Task.Delay(ArrivalPollIntervalMs); } } - private static bool WaitForCursorClear(McClient client, int waitMs) + private static async Task WaitForCursorClearAsync(McClient client, int waitMs) { DateTime deadline = DateTime.UtcNow.AddMilliseconds(waitMs); while (true) @@ -2615,11 +2694,11 @@ private static bool WaitForCursorClear(McClient client, int waitMs) if (DateTime.UtcNow >= deadline) return false; - Thread.Sleep(ArrivalPollIntervalMs); + await Task.Delay(ArrivalPollIntervalMs); } } - private static bool WaitForPlacement(McClient client, int inventoryId, int targetSlot, ItemType itemType, int beforeTargetCount, int beforeCursorCount, int placedCount) + private static async Task WaitForPlacementAsync(McClient client, int inventoryId, int targetSlot, ItemType itemType, int beforeTargetCount, int beforeCursorCount, int placedCount) { DateTime deadline = DateTime.UtcNow.AddMilliseconds(DefaultInventoryActionWaitMs); while (true) @@ -2627,7 +2706,7 @@ private static bool WaitForPlacement(McClient client, int inventoryId, int targe bool targetUpdated = false; bool cursorUpdated = false; - Container? inventory = client.InvokeOnMainThread(() => client.GetInventory(inventoryId)); + Container? inventory = client.GetInventory(inventoryId); if (inventory is not null) { int currentTargetCount = GetSlotItemCount(inventory, targetSlot, itemType); @@ -2649,21 +2728,19 @@ private static bool WaitForPlacement(McClient client, int inventoryId, int targe if (DateTime.UtcNow >= deadline) return false; - Thread.Sleep(ArrivalPollIntervalMs); + await Task.Delay(ArrivalPollIntervalMs); } } - private static bool WaitForSlotItemCount(McClient client, int inventoryId, int slotId, ItemType itemType, Func predicate, int waitMs, out Container? inventory, out int itemCount) + private static async Task WaitForSlotItemCountAsync(McClient client, int inventoryId, int slotId, ItemType itemType, Func predicate, int waitMs) { - inventory = null; - itemCount = 0; DateTime deadline = DateTime.UtcNow.AddMilliseconds(waitMs); while (true) { - inventory = client.GetInventory(inventoryId); + Container? inventory = client.GetInventory(inventoryId); if (inventory is not null) { - itemCount = GetSlotItemCount(inventory, slotId, itemType); + int itemCount = GetSlotItemCount(inventory, slotId, itemType); if (predicate(itemCount)) return true; } @@ -2671,21 +2748,22 @@ private static bool WaitForSlotItemCount(McClient client, int inventoryId, int s if (DateTime.UtcNow >= deadline) return false; - Thread.Sleep(ArrivalPollIntervalMs); + await Task.Delay(ArrivalPollIntervalMs); } } - private static bool WaitForRangeCount(McClient client, int inventoryId, ItemType itemType, int startSlot, int endSlot, Func predicate, int waitMs, out Container? inventory, out int itemCount) + private static async Task WaitForRangeCountAsync(McClient client, int inventoryId, ItemType itemType, int startSlot, int endSlot, Func predicate, int waitMs, Action onObserved, Action onInitialize) { - inventory = null; - itemCount = 0; + onInitialize(); DateTime deadline = DateTime.UtcNow.AddMilliseconds(waitMs); while (true) { - inventory = client.InvokeOnMainThread(() => client.GetInventory(inventoryId)); + Container? inventory = client.GetInventory(inventoryId); + int itemCount = 0; if (inventory is not null) { itemCount = CountItemInRange(inventory, itemType, startSlot, endSlot); + onObserved(new InventoryCountState(inventory, itemCount)); if (predicate(itemCount)) return true; } @@ -2693,7 +2771,7 @@ private static bool WaitForRangeCount(McClient client, int inventoryId, ItemType if (DateTime.UtcNow >= deadline) return false; - Thread.Sleep(ArrivalPollIntervalMs); + await Task.Delay(ArrivalPollIntervalMs); } } @@ -2882,22 +2960,21 @@ private static string NormalizeToken(string value) return new string(buffer); } - private static bool WaitForArrival(McClient client, Location goal, int waitMs, double tolerance, out Location? finalLocation) + private static async Task<(bool Arrived, Location FinalLocation)> WaitForArrivalAsync(McClient client, Location goal, int waitMs, double tolerance) { - finalLocation = null; DateTime deadline = DateTime.UtcNow.AddMilliseconds(waitMs); + Location finalLocation = client.InvokeOnMainThread(client.GetCurrentLocation); while (true) { - Location location = client.InvokeOnMainThread(client.GetCurrentLocation); - finalLocation = location; - double distance = GetDistance(location, goal); + finalLocation = client.InvokeOnMainThread(client.GetCurrentLocation); + double distance = GetDistance(finalLocation, goal); if (distance <= tolerance) - return true; + return (true, finalLocation); if (DateTime.UtcNow >= deadline) - return false; + return (false, finalLocation); - Thread.Sleep(ArrivalPollIntervalMs); + await Task.Delay(ArrivalPollIntervalMs); } } @@ -2929,21 +3006,21 @@ private static int GetPathQueryTimeoutMs(int timeoutMs) return Math.Clamp(timeoutMs, MinPathQueryTimeoutMs, MaxPathQueryTimeoutMs); } - private static bool WaitForBlockChange(McClient client, Location target, Block beforeBlock, int waitMs, out Block afterBlock) + private static async Task<(bool Changed, Block AfterBlock)> WaitForBlockChangeAsync(McClient client, Location target, Block beforeBlock, int waitMs) { - afterBlock = beforeBlock; + Block afterBlock = beforeBlock; DateTime deadline = DateTime.UtcNow.AddMilliseconds(waitMs); while (true) { Block current = client.InvokeOnMainThread(() => client.GetWorld().GetBlock(target)); afterBlock = current; if (!AreEquivalentBlocks(current, beforeBlock)) - return true; + return (true, afterBlock); if (DateTime.UtcNow >= deadline) - return false; + return (false, afterBlock); - Thread.Sleep(ArrivalPollIntervalMs); + await Task.Delay(ArrivalPollIntervalMs); } } @@ -3062,22 +3139,6 @@ private static NearbyItemSnapshot[] BuildNearbyItemSnapshots(McClient client, It .ToArray(); } - private static bool WaitForEntityRemoval(McClient client, int entityId, int waitMs) - { - DateTime deadline = DateTime.UtcNow.AddMilliseconds(waitMs); - while (true) - { - bool exists = client.InvokeOnMainThread(() => client.GetEntities().ContainsKey(entityId)); - if (!exists) - return true; - - if (DateTime.UtcNow >= deadline) - return false; - - Thread.Sleep(ArrivalPollIntervalMs); - } - } - private static int GetInventoryItemCount(McClient client, ItemType itemType) { Container? inventory = client.GetInventory(0); diff --git a/MinecraftClient/Mcp/MccMcpToolSet.cs b/MinecraftClient/Mcp/MccMcpToolSet.cs index 84305e3480..0be91c01d6 100644 --- a/MinecraftClient/Mcp/MccMcpToolSet.cs +++ b/MinecraftClient/Mcp/MccMcpToolSet.cs @@ -1,4 +1,5 @@ using System.ComponentModel; +using System.Threading.Tasks; using ModelContextProtocol.Server; namespace MinecraftClient.Mcp; @@ -202,9 +203,9 @@ public object UseItemOnBlock(double x, double y, double z) } [McpServerTool(Name = "mcc_dig_block"), Description("Dig a block at target location.")] - public object DigBlock(double x, double y, double z, double durationSeconds = 0) + public async Task DigBlock(double x, double y, double z, double durationSeconds = 0) { - return capabilities.DigBlock(x, y, z, durationSeconds); + return await capabilities.DigBlockAsync(x, y, z, durationSeconds); } [McpServerTool(Name = "mcc_place_block"), Description("Place the currently held block/item at a target block location.")] @@ -262,15 +263,15 @@ public object CanReachPosition(double x, double y, double z, bool allowUnsafe = } [McpServerTool(Name = "mcc_move_to"), Description("Request movement/pathing to a world coordinate and verify arrival.")] - public object MoveTo(double x, double y, double z, bool allowUnsafe = false, bool allowDirectTeleport = false, int maxOffset = 0, int minOffset = 0, int timeoutMs = 0) + public async Task MoveTo(double x, double y, double z, bool allowUnsafe = false, bool allowDirectTeleport = false, int maxOffset = 0, int minOffset = 0, int timeoutMs = 0) { - return capabilities.MoveTo(x, y, z, allowUnsafe, allowDirectTeleport, maxOffset, minOffset, timeoutMs); + return await capabilities.MoveToAsync(x, y, z, allowUnsafe, allowDirectTeleport, maxOffset, minOffset, timeoutMs); } [McpServerTool(Name = "mcc_move_to_player"), Description("Locate a tracked player entity, request movement/pathing, and verify arrival.")] - public object MoveToPlayer(string playerName, bool allowUnsafe = false, bool allowDirectTeleport = false, int maxOffset = 0, int minOffset = 0, int timeoutMs = 0) + public async Task MoveToPlayer(string playerName, bool allowUnsafe = false, bool allowDirectTeleport = false, int maxOffset = 0, int minOffset = 0, int timeoutMs = 0) { - return capabilities.MoveToPlayer(playerName, allowUnsafe, allowDirectTeleport, maxOffset, minOffset, timeoutMs); + return await capabilities.MoveToPlayerAsync(playerName, allowUnsafe, allowDirectTeleport, maxOffset, minOffset, timeoutMs); } [McpServerTool(Name = "mcc_look_at"), Description("Rotate player view toward world coordinates.")] @@ -310,15 +311,15 @@ public object InventoriesList() } [McpServerTool(Name = "mcc_container_open_at"), Description("Open an interactable container block at world coordinates and wait for the container inventory to appear.")] - public object ContainerOpenAt(int x, int y, int z, int timeoutMs = 0, bool closeCurrent = true) + public async Task ContainerOpenAt(int x, int y, int z, int timeoutMs = 0, bool closeCurrent = true) { - return capabilities.OpenContainerAt(x, y, z, timeoutMs, closeCurrent); + return await capabilities.OpenContainerAtAsync(x, y, z, timeoutMs, closeCurrent); } [McpServerTool(Name = "mcc_container_close"), Description("Close an open non-player container. Use inventoryId=-1 to close the active container.")] - public object ContainerClose([Description("Container inventory ID, or -1 for the active non-player container.")] int inventoryId = -1, int timeoutMs = 0) + public async Task ContainerClose([Description("Container inventory ID, or -1 for the active non-player container.")] int inventoryId = -1, int timeoutMs = 0) { - return capabilities.CloseContainer(inventoryId, timeoutMs); + return await capabilities.CloseContainerAsync(inventoryId, timeoutMs); } [McpServerTool(Name = "mcc_inventory_window_action"), Description("Perform a window action on an inventory slot.")] @@ -328,33 +329,33 @@ public object InventoryWindowAction(int inventoryId, int slotId, [Description("W } [McpServerTool(Name = "mcc_inventory_drop_item"), Description("Drop an exact item count from an inventory by item type.")] - public object InventoryDropItem( + public async Task InventoryDropItem( [Description("Item type enum name (e.g. Diamond).")] string itemType, [Description("Exact number of items to drop.")] int count, [Description("Inventory ID. 0 is the player inventory.")] int inventoryId = 0, [Description("Prefer dropping from larger stacks first when true.")] bool preferStack = false) { - return capabilities.DropInventoryItem(itemType, count, inventoryId, preferStack); + return await capabilities.DropInventoryItemAsync(itemType, count, inventoryId, preferStack); } [McpServerTool(Name = "mcc_container_deposit_item"), Description("Move an exact item count from the player inventory into an open container and verify the transfer.")] - public object ContainerDepositItem( + public async Task ContainerDepositItem( [Description("Item type enum name (e.g. Diamond).")] string itemType, [Description("Exact number of items to move into the container.")] int count, [Description("Container inventory ID, or -1 for the active non-player container.")] int inventoryId = -1, [Description("Prefer larger source stacks first when true.")] bool preferLargestStack = true) { - return capabilities.DepositContainerItem(itemType, count, inventoryId, preferLargestStack); + return await capabilities.DepositContainerItemAsync(itemType, count, inventoryId, preferLargestStack); } [McpServerTool(Name = "mcc_container_withdraw_item"), Description("Move an exact item count from an open container into the player inventory and verify the transfer.")] - public object ContainerWithdrawItem( + public async Task ContainerWithdrawItem( [Description("Item type enum name (e.g. Diamond).")] string itemType, [Description("Exact number of items to move into the player inventory.")] int count, [Description("Container inventory ID, or -1 for the active non-player container.")] int inventoryId = -1, [Description("Prefer larger source stacks first when true.")] bool preferLargestStack = true) { - return capabilities.WithdrawContainerItem(itemType, count, inventoryId, preferLargestStack); + return await capabilities.WithdrawContainerItemAsync(itemType, count, inventoryId, preferLargestStack); } [McpServerTool(Name = "mcc_entities_query"), Description("Query tracked entities.")] @@ -388,9 +389,9 @@ public object ItemsList(string? itemType = null, double radius = 32, int maxCoun } [McpServerTool(Name = "mcc_items_pickup"), Description("Move to and pick up nearby dropped items of a given item type.")] - public object ItemsPickup(string itemType, double radius = 32, int maxItems = 20, bool allowUnsafe = false, int timeoutMs = 0) + public async Task ItemsPickup(string itemType, double radius = 32, int maxItems = 20, bool allowUnsafe = false, int timeoutMs = 0) { - return capabilities.PickupItems(itemType, radius, maxItems, allowUnsafe, timeoutMs); + return await capabilities.PickupItemsAsync(itemType, radius, maxItems, allowUnsafe, timeoutMs); } [McpServerTool(Name = "mcc_world_block_at"), Description("Get block information at world coordinates.")] diff --git a/MinecraftClient/Program.cs b/MinecraftClient/Program.cs index 58a42b9ba2..cfdb84fc8d 100644 --- a/MinecraftClient/Program.cs +++ b/MinecraftClient/Program.cs @@ -74,7 +74,7 @@ internal sealed class StartupState /// /// The main entry point of Minecraft Console Client /// - static void Main(string[] args) + static async Task Main(string[] args) { // [SENTRY] Initialize Sentry SDK only if the DSN is not empty if (SentryDSN != string.Empty) @@ -94,7 +94,7 @@ static void Main(string[] args) }; } - Task.Run(() => + _ = Task.Run(() => { // "ToLower" require "CultureInfo" to be initialized on first run, which can take a lot of time. _ = "a".ToLower(); @@ -208,7 +208,7 @@ static void Main(string[] args) // Wait for this issue to be fixed before enabling it: https://github.com/Consolonia/Consolonia/issues/602 // MaybePrintClassicModeTuiRecommendation(); - RunStartupSequence(args); + await RunStartupSequenceAsync(args); } /// @@ -381,7 +381,10 @@ internal static void HandleConfigLoadFailure() /// Called from Main() for classic/basic mode, or from TuiConsoleBackend on a /// background thread after the Avalonia UI loop has started. /// - internal static void RunStartupSequence(string[] args) + internal static void RunStartupSequence(string[] args) => + RunStartupSequenceAsync(args).GetAwaiter().GetResult(); + + internal static async Task RunStartupSequenceAsync(string[] args) { //Other command-line arguments if (args.Length >= 1) @@ -574,20 +577,20 @@ internal static void RunStartupSequence(string[] args) if (string.IsNullOrWhiteSpace(InternalConfig.Account.Password) && !skipPassword && (Config.Main.Advanced.SessionCache == CacheType.none || !SessionCache.Contains(ToLowerIfNeed(InternalConfig.Account.Login)))) { - RequestPassword(); + await RequestPasswordAsync(); } startupargs = args; - InitializeClient(); + await InitializeClientAsync(); } /// /// Reduest user to submit password. /// - private static void RequestPassword() + private static async Task RequestPasswordAsync() { ConsoleIO.WriteLine(ConsoleIO.BasicIO ? string.Format(Translations.mcc_password_basic_io, InternalConfig.Account.Login) + "\n" : Translations.mcc_password_hidden); - string? password = ConsoleIO.BasicIO ? Console.ReadLine() : ConsoleIO.ReadPassword(); + string? password = await ConsoleIO.ReadPasswordAsync(); if (string.IsNullOrWhiteSpace(password)) InternalConfig.Account.Password = "-"; else @@ -597,7 +600,7 @@ private static void RequestPassword() /// /// Start a new Client /// - private static void InitializeClient() + private static async Task InitializeClientAsync() { // Ensure that we use the provided Minecraft version if we can't connect automatically. // @@ -634,7 +637,9 @@ private static void InitializeClient() { try { - result = ProtocolHandler.MicrosoftLoginRefresh(session.RefreshToken, out session); + var refreshResult = await ProtocolHandler.MicrosoftLoginRefreshAsync(session.RefreshToken); + result = refreshResult.Result; + session = refreshResult.Session; } catch (Exception ex) { @@ -646,7 +651,7 @@ private static void InitializeClient() if (result != ProtocolHandler.LoginResult.Success && string.IsNullOrWhiteSpace(InternalConfig.Account.Password) && !(Config.Main.General.AccountType == LoginType.microsoft)) - RequestPassword(); + await RequestPasswordAsync(); } else ConsoleIO.WriteLineFormatted("§8" + string.Format(Translations.mcc_session_valid, session.PlayerName)); } @@ -654,14 +659,16 @@ private static void InitializeClient() if (result != ProtocolHandler.LoginResult.Success) { ConsoleIO.WriteLine(string.Format(Translations.mcc_connecting, Config.Main.General.AccountType == LoginType.mojang ? "Minecraft.net" : (Config.Main.General.AccountType == LoginType.microsoft ? "Microsoft" : Config.Main.General.AuthServer.Host))); - result = ProtocolHandler.GetLogin(InternalConfig.Account.Login, InternalConfig.Account.Password, Config.Main.General.AccountType, out session); + var loginResult = await ProtocolHandler.GetLoginAsync(InternalConfig.Account.Login, InternalConfig.Account.Password, Config.Main.General.AccountType); + result = loginResult.Result; + session = loginResult.Session; } if (result == ProtocolHandler.LoginResult.Success && Config.Main.Advanced.SessionCache != CacheType.none) SessionCache.Store(loginLower, session); if (result == ProtocolHandler.LoginResult.Success) - session.SessionPreCheckTask = Task.Factory.StartNew(() => session.SessionPreCheck(Config.Main.General.AccountType)); + session.SessionPreCheckTask = session.SessionPreCheckAsync(Config.Main.General.AccountType); } if (result == ProtocolHandler.LoginResult.Success) @@ -680,7 +687,7 @@ private static void InitializeClient() List availableWorlds = new(); if (Config.Main.Advanced.MinecraftRealms && !String.IsNullOrEmpty(session.ID)) - availableWorlds = ProtocolHandler.RealmsListWorlds(InternalConfig.Username, session.PlayerID, session.ID); + availableWorlds = await ProtocolHandler.RealmsListWorldsAsync(InternalConfig.Username, session.PlayerID, session.ID); if (InternalConfig.ServerIP == string.Empty) { @@ -700,7 +707,7 @@ private static void InitializeClient() worldId = availableWorlds[worldIndex]; if (availableWorlds.Contains(worldId)) { - string realmsAddress = ProtocolHandler.GetRealmsWorldServerAddress(worldId, InternalConfig.Username, session.PlayerID, session.ID); + string realmsAddress = await ProtocolHandler.GetRealmsWorldServerAddressAsync(worldId, InternalConfig.Username, session.PlayerID, session.ID); if (realmsAddress != "") { addressInput = realmsAddress; @@ -756,11 +763,15 @@ private static void InitializeClient() ConsoleIO.WriteLine(Translations.mcc_forge); else ConsoleIO.WriteLine(Translations.mcc_retrieve); - if (!ProtocolHandler.GetServerInfo(InternalConfig.ServerIP, InternalConfig.ServerPort, ref protocolversion, ref forgeInfo)) + var serverInfo = await ProtocolHandler.GetServerInfoAsync(InternalConfig.ServerIP, InternalConfig.ServerPort, protocolversion); + if (!serverInfo.Success) { HandleFailure(Translations.error_ping, true, ChatBot.DisconnectReason.ConnectionLost); return; } + + protocolversion = serverInfo.ProtocolVersion; + forgeInfo = serverInfo.ForgeInfo; } if ((Config.Main.General.AccountType == LoginType.microsoft || Config.Main.General.AccountType == LoginType.yggdrasil) @@ -890,30 +901,41 @@ public static void WriteBackSettings(bool enableBackup = true) /// /// Optional delay, in seconds, before restarting /// Optional, keep account and server settings - public static void Restart(int delaySeconds = 0, bool keepAccountAndServerSettings = false) + public static void Restart(int delaySeconds = 0, bool keepAccountAndServerSettings = false, bool announceDelay = true) { ConsoleIO.Backend?.StopReadThread(); - new Thread(new ThreadStart(delegate + StartLifecycleTask(RestartAsync(delaySeconds, keepAccountAndServerSettings, announceDelay)); + } + + private static async Task RestartAsync(int delaySeconds, bool keepAccountAndServerSettings, bool announceDelay) + { + if (client is not null) { client.Disconnect(); ConsoleIO.Reset(); } + if (offlinePrompt is not null) { - if (client is not null) { client.Disconnect(); ConsoleIO.Reset(); } - if (offlinePrompt is not null) - { - if (ConsoleIO.Backend is not null) - ConsoleIO.Backend.OnInputChange -= ConsoleIO.OfflineAutocompleteHandler; - offlinePrompt.Item2.Cancel(); offlinePrompt.Item1.Join(); offlinePrompt = null; ConsoleIO.Reset(); - } - if (delaySeconds > 0) - { + if (ConsoleIO.Backend is not null) + ConsoleIO.Backend.OnInputChange -= ConsoleIO.OfflineAutocompleteHandler; + offlinePrompt.Item2.Cancel(); + offlinePrompt.Item1.Join(); + offlinePrompt = null; + ConsoleIO.Reset(); + } + if (delaySeconds > 0) + { + if (announceDelay) ConsoleIO.WriteLine(string.Format(Translations.mcc_restart_delay, delaySeconds)); - Thread.Sleep(delaySeconds * 1000); - } - ConsoleIO.WriteLine(Translations.mcc_restart); - ReloadSettings(keepAccountAndServerSettings); - InitializeClient(); - })).Start(); + await Task.Delay(TimeSpan.FromSeconds(delaySeconds)); + } + ConsoleIO.WriteLine(Translations.mcc_restart); + ReloadSettings(keepAccountAndServerSettings); + await InitializeClientAsync(); } public static void DoExit(int exitcode = 0) + { + DoExitAsync(exitcode).GetAwaiter().GetResult(); + } + + private static Task DoExitAsync(int exitcode = 0) { WriteBackSettings(); ConsoleIO.WriteLineFormatted("§a" + string.Format(Translations.config_saving, settingsIniPath)); @@ -932,6 +954,7 @@ public static void DoExit(int exitcode = 0) if (Config.Main.Advanced.PlayerHeadAsIcon) { ConsoleIcon.RevertToMCCIcon(); } ConsoleIO.Backend?.Shutdown(); Environment.Exit(exitcode); + return Task.CompletedTask; } /// @@ -939,7 +962,26 @@ public static void DoExit(int exitcode = 0) /// public static void Exit(int exitcode = 0) { - new Thread(() => { DoExit(exitcode); }).Start(); + StartLifecycleTask(DoExitAsync(exitcode)); + } + + private static void StartLifecycleTask(Task lifecycleTask) + { + _ = ObserveLifecycleTaskAsync(lifecycleTask); + } + + private static async Task ObserveLifecycleTaskAsync(Task lifecycleTask) + { + try + { + await lifecycleTask; + } + catch (Exception ex) + { + SentrySdk.CaptureException(ex); + if (Settings.Config.Logging.DebugMessages) + ConsoleIO.WriteLineFormatted("§8" + ex); + } } /// diff --git a/MinecraftClient/Protocol/Handlers/DataTypes.cs b/MinecraftClient/Protocol/Handlers/DataTypes.cs index bdbdecea39..dab5b55eb5 100644 --- a/MinecraftClient/Protocol/Handlers/DataTypes.cs +++ b/MinecraftClient/Protocol/Handlers/DataTypes.cs @@ -2,10 +2,13 @@ using System.Collections.Generic; using System.Runtime.CompilerServices; using System.Text; +using System.Threading; +using System.Threading.Tasks; using MinecraftClient.Inventory; using MinecraftClient.Inventory.ItemPalettes; using MinecraftClient.Mapping; using MinecraftClient.Mapping.EntityPalettes; +using MinecraftClient.Protocol.PacketPipeline; using MinecraftClient.Protocol.Handlers.StructuredComponents; using MinecraftClient.Protocol.Handlers.StructuredComponents.Core; using MinecraftClient.Protocol.Message; @@ -42,6 +45,12 @@ public byte[] ReadData(int offset, Queue cache) return result; } + [MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)] + public byte[] ReadData(int offset, PacketReader reader) + { + return reader.ReadData(offset); + } + /// /// Read some data from a cache of bytes and remove it from the cache /// @@ -54,6 +63,12 @@ public void ReadDataReverse(Queue cache, Span dest) dest[i] = cache.Dequeue(); } + [MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)] + public void ReadDataReverse(PacketReader reader, Span dest) + { + reader.ReadDataReverse(dest); + } + /// /// Remove some data from the cache /// @@ -66,6 +81,12 @@ public void DropData(int offset, Queue cache) cache.Dequeue(); } + [MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)] + public void DropData(int offset, PacketReader reader) + { + reader.Skip(offset); + } + /// /// Read a string from a cache of bytes and remove it from the cache /// @@ -82,6 +103,13 @@ public string ReadNextString(Queue cache) else return ""; } + [MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)] + public string ReadNextString(PacketReader reader) + { + int length = ReadNextVarInt(reader); + return length > 0 ? Encoding.UTF8.GetString(ReadData(length, reader)) : ""; + } + /// /// Skip a string from a cache of bytes and remove it from the cache /// @@ -92,6 +120,12 @@ public void SkipNextString(Queue cache) DropData(length, cache); } + public void SkipNextString(PacketReader reader) + { + int length = ReadNextVarInt(reader); + DropData(length, reader); + } + /// /// Read a boolean from a cache of bytes and remove it from the cache /// @@ -102,6 +136,12 @@ public bool ReadNextBool(Queue cache) return ReadNextByte(cache) != 0x00; } + [MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)] + public bool ReadNextBool(PacketReader reader) + { + return ReadNextByte(reader) != 0x00; + } + /// /// Read a short integer from a cache of bytes and remove it from the cache /// @@ -115,6 +155,12 @@ public short ReadNextShort(Queue cache) return BitConverter.ToInt16(rawValue); } + [MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)] + public short ReadNextShort(PacketReader reader) + { + return reader.ReadInt16BigEndian(); + } + /// /// Read an integer from a cache of bytes and remove it from the cache /// @@ -128,6 +174,12 @@ public int ReadNextInt(Queue cache) return BitConverter.ToInt32(rawValue); } + [MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)] + public int ReadNextInt(PacketReader reader) + { + return reader.ReadInt32BigEndian(); + } + /// /// Read a long integer from a cache of bytes and remove it from the cache /// @@ -141,6 +193,12 @@ public long ReadNextLong(Queue cache) return BitConverter.ToInt64(rawValue); } + [MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)] + public long ReadNextLong(PacketReader reader) + { + return reader.ReadInt64BigEndian(); + } + /// /// Read an unsigned short integer from a cache of bytes and remove it from the cache /// @@ -154,6 +212,12 @@ public ushort ReadNextUShort(Queue cache) return BitConverter.ToUInt16(rawValue); } + [MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)] + public ushort ReadNextUShort(PacketReader reader) + { + return reader.ReadUInt16BigEndian(); + } + /// /// Read an unsigned long integer from a cache of bytes and remove it from the cache /// @@ -167,6 +231,12 @@ public ulong ReadNextULong(Queue cache) return BitConverter.ToUInt64(rawValue); } + [MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)] + public ulong ReadNextULong(PacketReader reader) + { + return unchecked((ulong)reader.ReadInt64BigEndian()); + } + /// /// Read a Location encoded as an ulong field and remove it from the cache /// @@ -197,6 +267,32 @@ public Location ReadNextLocation(Queue cache) return new Location(x, y, z); } + public Location ReadNextLocation(PacketReader reader) + { + ulong locEncoded = ReadNextULong(reader); + int x, y, z; + if (protocolversion >= Protocol18Handler.MC_1_14_Version) + { + x = (int)(locEncoded >> 38); + y = (int)(locEncoded & 0xFFF); + z = (int)(locEncoded << 26 >> 38); + } + else + { + x = (int)(locEncoded >> 38); + y = (int)((locEncoded >> 26) & 0xFFF); + z = (int)(locEncoded << 38 >> 38); + } + + if (x >= 0x02000000) + x -= 0x04000000; + if (y >= 0x00000800) + y -= 0x00001000; + if (z >= 0x02000000) + z -= 0x04000000; + return new Location(x, y, z); + } + /// /// Read several little endian unsigned short integers at once from a cache of bytes and remove them from the cache /// @@ -210,6 +306,15 @@ public ushort[] ReadNextUShortsLittleEndian(int amount, Queue cache) return result; } + public ushort[] ReadNextUShortsLittleEndian(int amount, PacketReader reader) + { + byte[] rawValues = ReadData(2 * amount, reader); + ushort[] result = new ushort[amount]; + for (int i = 0; i < amount; i++) + result[i] = BitConverter.ToUInt16(rawValues, i * 2); + return result; + } + /// /// Read a uuid from a cache of bytes and remove it from the cache /// @@ -226,6 +331,16 @@ public Guid ReadNextUUID(Queue cache) return guid; } + public Guid ReadNextUUID(PacketReader reader) + { + Span javaUUID = stackalloc byte[16]; + reader.ReadData(javaUUID); + Guid guid = new(javaUUID); + if (BitConverter.IsLittleEndian) + guid = guid.ToLittleEndian(); + return guid; + } + /// /// Read a byte array from a cache of bytes and remove it from the cache /// @@ -240,6 +355,15 @@ public byte[] ReadNextByteArray(Queue cache) return ReadData(len, cache); } + [MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)] + public byte[] ReadNextByteArray(PacketReader reader) + { + int len = protocolversion >= Protocol18Handler.MC_1_8_Version + ? ReadNextVarInt(reader) + : ReadNextShort(reader); + return ReadData(len, reader); + } + /// /// Read a byte array with given length from a cache of bytes and remove it from the cache /// @@ -252,6 +376,12 @@ public byte[] ReadNextByteArray(Queue cache, int length) return ReadData(length, cache); } + [MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)] + public byte[] ReadNextByteArray(PacketReader reader, int length) + { + return ReadData(length, reader); + } + /// /// Reads a length-prefixed array of unsigned long integers and removes it from the cache /// @@ -266,6 +396,15 @@ public ulong[] ReadNextULongArray(Queue cache) return result; } + public ulong[] ReadNextULongArray(PacketReader reader) + { + int len = ReadNextVarInt(reader); + ulong[] result = new ulong[len]; + for (int i = 0; i < len; i++) + result[i] = ReadNextULong(reader); + return result; + } + /// /// Read a double from a cache of bytes and remove it from the cache /// @@ -279,6 +418,12 @@ public double ReadNextDouble(Queue cache) return BitConverter.ToDouble(rawValue); } + [MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)] + public double ReadNextDouble(PacketReader reader) + { + return BitConverter.Int64BitsToDouble(ReadNextLong(reader)); + } + /// /// Read a float from a cache of bytes and remove it from the cache /// @@ -292,6 +437,12 @@ public float ReadNextFloat(Queue cache) return BitConverter.ToSingle(rawValue); } + [MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)] + public float ReadNextFloat(PacketReader reader) + { + return BitConverter.Int32BitsToSingle(ReadNextInt(reader)); + } + /// /// Read an integer from the network /// @@ -304,7 +455,28 @@ public int ReadNextVarIntRAW(SocketWrapper socket) byte b; while (true) { - b = socket.ReadDataRAW(1)[0]; + b = socket.ReadByteRAW(); + i |= (b & 0x7F) << j++ * 7; + if (j > 5) throw new OverflowException("VarInt too big"); + if ((b & 0x80) != 128) break; + } + + return i; + } + + /// + /// Read an integer from the network asynchronously. + /// + /// The integer + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public async Task ReadNextVarIntRAWAsync(SocketWrapper socket, CancellationToken cancellationToken) + { + int i = 0; + int j = 0; + byte b; + while (true) + { + b = await socket.ReadByteRAWAsync(cancellationToken); i |= (b & 0x7F) << j++ * 7; if (j > 5) throw new OverflowException("VarInt too big"); if ((b & 0x80) != 128) break; @@ -334,6 +506,22 @@ public int ReadNextVarInt(Queue cache) return i; } + [MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)] + public int ReadNextVarInt(PacketReader reader) + { + int i = 0; + int j = 0; + byte b; + do + { + b = reader.ReadByte(); + i |= (b & 0x7F) << j++ * 7; + if (j > 5) throw new OverflowException("VarInt too big"); + } while ((b & 0x80) == 128); + + return i; + } + /// /// Skip a VarInt from a cache of bytes with better performance /// @@ -346,6 +534,14 @@ public void SkipNextVarInt(Queue cache) break; } + [MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)] + public void SkipNextVarInt(PacketReader reader) + { + while (true) + if ((ReadNextByte(reader) & 0x80) != 128) + break; + } + /// /// Read an "extended short", which is actually an int of some kind, from the cache of bytes. /// This is only done with forge. It looks like it's a normal short, except that if the high @@ -366,6 +562,19 @@ public int ReadNextVarShort(Queue cache) return ((high & 0xFF) << 15) | low; } + public int ReadNextVarShort(PacketReader reader) + { + ushort low = ReadNextUShort(reader); + byte high = 0; + if ((low & 0x8000) != 0) + { + low &= 0x7FFF; + high = ReadNextByte(reader); + } + + return ((high & 0xFF) << 15) | low; + } + /// /// Read a long from a cache of bytes and remove it from the cache /// @@ -392,6 +601,27 @@ public long ReadNextVarLong(Queue cache) return result; } + public long ReadNextVarLong(PacketReader reader) + { + int numRead = 0; + long result = 0; + byte read; + do + { + read = ReadNextByte(reader); + long value = (read & 0x7F); + result |= (value << (7 * numRead)); + + numRead++; + if (numRead > 10) + { + throw new OverflowException("VarLong is too big"); + } + } while ((read & 0x80) != 0); + + return result; + } + /// /// Read a single byte from a cache of bytes and remove it from the cache /// @@ -403,6 +633,12 @@ public byte ReadNextByte(Queue cache) return result; } + [MethodImpl(MethodImplOptions.AggressiveInlining | MethodImplOptions.AggressiveOptimization)] + public byte ReadNextByte(PacketReader reader) + { + return reader.ReadByte(); + } + /// /// Read an uncompressed Named Binary Tag blob and remove it from the cache /// @@ -411,6 +647,11 @@ public Dictionary ReadNextNbt(Queue cache) return ReadNextNbt(cache, true); } + public Dictionary ReadNextNbt(PacketReader reader) + { + return ReadWithQueueFallback(reader, cache => ReadNextNbt(cache, true)); + } + /// /// Read an ItemStackTemplate (26.1+) from a cache of bytes. /// Unlike ItemStack, this uses item-first encoding: item_id, count, DataComponentPatch. @@ -442,6 +683,11 @@ public Item ReadNextItemStackTemplate(Queue cache, ItemPalette itemPalette return item; } + public Item ReadNextItemStackTemplate(PacketReader reader, ItemPalette itemPalette) + { + return ReadWithQueueFallback(reader, cache => ReadNextItemStackTemplate(cache, itemPalette)); + } + /// /// Read a single item slot from a cache of bytes and remove it from the cache /// @@ -545,6 +791,11 @@ public Item ReadNextItemStackTemplate(Queue cache, ItemPalette itemPalette } } + public Item? ReadNextItemSlot(PacketReader reader, ItemPalette itemPalette) + { + return ReadWithQueueFallback(reader, cache => ReadNextItemSlot(cache, itemPalette)); + } + private void ReadNextDetail(Queue cache) { var potionEffectId = ReadNextVarInt(cache); @@ -701,6 +952,11 @@ public Entity ReadNextEntity(Queue cache, EntityPalette entityPalette, boo return entity; } + public Entity ReadNextEntity(PacketReader reader, EntityPalette entityPalette, bool living) + { + return ReadWithQueueFallback(reader, cache => ReadNextEntity(cache, entityPalette, living)); + } + /// /// Read an uncompressed Named Binary Tag blob and remove it from the cache (internal) /// @@ -1054,6 +1310,12 @@ private object ReadNbtField(Queue cache, int fieldType) } } + public Dictionary ReadNextMetadata(PacketReader reader, ItemPalette itemPalette, + EntityMetadataPalette metadataPalette) + { + return ReadWithQueueFallback(reader, cache => ReadNextMetadata(cache, itemPalette, metadataPalette)); + } + private static bool HasLpVec3Continuation(int firstByte) => (firstByte & 4) == 4; private static double UnpackLpVec3(long packedAxis) @@ -1086,6 +1348,27 @@ private static double UnpackLpVec3(long packedAxis) ); } + public (double X, double Y, double Z) ReadNextLpVec3Values(PacketReader reader) + { + int first = ReadNextByte(reader); + if (first == 0) + return (0.0, 0.0, 0.0); + + int second = ReadNextByte(reader); + uint high = (uint)ReadNextInt(reader); + long packed = ((long)high << 16) | (long)(second << 8) | (uint)first; + + long scale = first & 3; + if (HasLpVec3Continuation(first)) + scale |= ((long)ReadNextVarInt(reader) & 0xFFFFFFFFL) << 2; + + return ( + UnpackLpVec3(packed >> 3) * scale, + UnpackLpVec3(packed >> 18) * scale, + UnpackLpVec3(packed >> 33) * scale + ); + } + /// /// Read an LpVec3 (low-precision vec3) from the cache (1.21.9+) and discard it. /// @@ -1094,6 +1377,11 @@ public void ReadNextLpVec3(Queue cache) ReadNextLpVec3Values(cache); } + public void ReadNextLpVec3(PacketReader reader) + { + ReadNextLpVec3Values(reader); + } + /// /// Consume bytes for a ResolvableProfile (1.21.9+). /// Wire: Either(GameProfile, Partial) + PlayerSkin.Patch @@ -1356,6 +1644,11 @@ public void ReadParticleData(Queue cache, ItemPalette itemPalette) } } + public void ReadParticleData(PacketReader reader, ItemPalette itemPalette) + { + ReadWithQueueFallback(reader, cache => ReadParticleData(cache, itemPalette)); + } + private void ReadDustParticle(Queue cache) { ReadNextFloat(cache); // Red @@ -1414,6 +1707,11 @@ public VillagerTrade ReadNextTrade(Queue cache, ItemPalette itemPalette) maximumNumberOfTradeUses, xp, specialPrice, priceMultiplier, demand); } + public VillagerTrade ReadNextTrade(PacketReader reader, ItemPalette itemPalette) + { + return ReadWithQueueFallback(reader, cache => ReadNextTrade(cache, itemPalette)); + } + public string ReadNextChat(Queue cache) { if (protocolversion >= Protocol18Handler.MC_1_20_4_Version) @@ -1431,6 +1729,11 @@ public string ReadNextChat(Queue cache) } } + public string ReadNextChat(PacketReader reader) + { + return ReadWithQueueFallback(reader, ReadNextChat); + } + /// /// Build an uncompressed Named Binary Tag blob for sending over the network /// @@ -2012,5 +2315,22 @@ public byte[] GetAcknowledgment(Message.LastSeenMessageList.Acknowledgment ack, return fields.ToArray(); } + + private static T ReadWithQueueFallback(PacketReader reader, Func, T> read) + { + byte[] remaining = reader.CopyRemaining(); + Queue cache = new(remaining); + T result = read(cache); + reader.Skip(remaining.Length - cache.Count); + return result; + } + + private static void ReadWithQueueFallback(PacketReader reader, Action> read) + { + byte[] remaining = reader.CopyRemaining(); + Queue cache = new(remaining); + read(cache); + reader.Skip(remaining.Length - cache.Count); + } } } diff --git a/MinecraftClient/Protocol/Handlers/Packet/s2c/DeclareCommands.cs b/MinecraftClient/Protocol/Handlers/Packet/s2c/DeclareCommands.cs index 57f1bb92da..467cb13bd0 100644 --- a/MinecraftClient/Protocol/Handlers/Packet/s2c/DeclareCommands.cs +++ b/MinecraftClient/Protocol/Handlers/Packet/s2c/DeclareCommands.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using MinecraftClient.Protocol.PacketPipeline; namespace MinecraftClient.Protocol.Handlers.packet.s2c { @@ -59,7 +60,7 @@ internal static class DeclareCommands public static bool IsCommandTreeAvailable => HasValidCommandTree(); - public static void Read(DataTypes dataTypes, Queue packetData, int protocolVersion) + public static void Read(DataTypes dataTypes, PacketReader packetData, int protocolVersion) { Reset(); ConsoleIO.OnDeclareMinecraftCommand(Array.Empty()); @@ -88,7 +89,7 @@ public static List> CollectSignArguments(string command) : []; } - private static void ReadCommandTree(DataTypes dataTypes, Queue packetData, int protocolVersion) + private static void ReadCommandTree(DataTypes dataTypes, PacketReader packetData, int protocolVersion) { int count = dataTypes.ReadNextVarInt(packetData); Nodes = new CommandNode[count]; @@ -117,7 +118,7 @@ private static void ReadCommandTree(DataTypes dataTypes, Queue packetData, private static CommandNode ReadArgumentNode( DataTypes dataTypes, - Queue packetData, + PacketReader packetData, int protocolVersion, byte flags, int[] children, @@ -135,7 +136,7 @@ private static CommandNode ReadArgumentNode( return new(flags, children, redirectNode, name, descriptor, suggestionsType, parserId); } - private static int[] ReadChildIndices(DataTypes dataTypes, Queue packetData) + private static int[] ReadChildIndices(DataTypes dataTypes, PacketReader packetData) { int childCount = dataTypes.ReadNextVarInt(packetData); int[] children = new int[childCount]; @@ -146,7 +147,7 @@ private static int[] ReadChildIndices(DataTypes dataTypes, Queue packetDat return children; } - private static CommandArgumentDescriptor ReadArgumentDescriptor(DataTypes dataTypes, Queue packetData, ArgumentTypeLayout layout) + private static CommandArgumentDescriptor ReadArgumentDescriptor(DataTypes dataTypes, PacketReader packetData, ArgumentTypeLayout layout) { switch (layout.PayloadKind) { @@ -189,7 +190,7 @@ private static CommandArgumentDescriptor ReadArgumentDescriptor(DataTypes dataTy } } - private static void ReadNumberBounds(DataTypes dataTypes, Queue packetData, Func, TValue> readValue) + private static void ReadNumberBounds(DataTypes dataTypes, PacketReader packetData, Func readValue) { byte flags = dataTypes.ReadNextByte(packetData); if ((flags & 0x01) != 0) diff --git a/MinecraftClient/Protocol/Handlers/Protocol16.cs b/MinecraftClient/Protocol/Handlers/Protocol16.cs index 6777200d03..a25951494b 100644 --- a/MinecraftClient/Protocol/Handlers/Protocol16.cs +++ b/MinecraftClient/Protocol/Handlers/Protocol16.cs @@ -2,11 +2,13 @@ using System.Collections.Generic; using System.Diagnostics; using System.Globalization; +using System.IO; using System.Linq; using System.Net.Sockets; using System.Security.Cryptography; using System.Text; using System.Threading; +using System.Threading.Tasks; using MinecraftClient.Crypto; using MinecraftClient.Inventory; using MinecraftClient.Mapping; @@ -29,7 +31,9 @@ class Protocol16Handler : IMinecraftCom readonly IMinecraftComHandler handler; private bool encrypted = false; private readonly int protocolversion; - private Tuple? netRead = null; + private Task? netReadTask; + private CancellationTokenSource? netReadCancellationTokenSource; + private int netReadThreadId = -1; Crypto.AesCfb8Stream? s; readonly TcpClient c; @@ -69,15 +73,15 @@ private Protocol16Handler(TcpClient Client) c = Client; } - private void Updater(object? o) + private async Task UpdaterAsync(CancellationToken cancelToken) { - var cancelToken = (CancellationToken)o!; - if (cancelToken.IsCancellationRequested) return; try { + netReadThreadId = Environment.CurrentManagedThreadId; + using IDisposable _ = MainThreadExecutionScope.Enter(handler); Stopwatch stopWatch = Stopwatch.StartNew(); long nextUpdateDue = 0; @@ -97,13 +101,15 @@ private void Updater(object? o) long sleepLength = nextUpdateDue - stopWatch.ElapsedMilliseconds; if (sleepLength > 1) - Thread.Sleep((int)Math.Min(sleepLength, ClientTickIntervalMilliseconds)); + await Task.Delay((int)Math.Min(sleepLength, ClientTickIntervalMilliseconds), cancelToken); } } catch (System.IO.IOException) { } catch (SocketException) { } catch (ObjectDisposedException) { } catch (OperationCanceledException) { } + catch (Exception) { } + finally { netReadThreadId = -1; } if (cancelToken.IsCancellationRequested) return; @@ -240,9 +246,9 @@ private bool ProcessPacket(byte id) private void StartUpdating() { - netRead = new(new Thread(new ParameterizedThreadStart(Updater)), new CancellationTokenSource()); - netRead.Item1.Name = "ProtocolPacketHandler"; - netRead.Item1.Start(netRead.Item2.Token); + CancellationTokenSource netReadCts = new(); + netReadCancellationTokenSource = netReadCts; + netReadTask = Task.Run(() => UpdaterAsync(netReadCts.Token), netReadCts.Token); } /// @@ -251,7 +257,7 @@ private void StartUpdating() /// Net read thread ID public int GetNetMainThreadId() { - return netRead is not null ? netRead.Item1.ManagedThreadId : -1; + return netReadThreadId; } public bool SendCookieResponse(string name, byte[]? data) @@ -268,9 +274,9 @@ public void Dispose() { try { - if (netRead is not null) + if (netReadCancellationTokenSource is not null) { - netRead.Item2.Cancel(); + netReadCancellationTokenSource.Cancel(); c.Close(); } } @@ -471,6 +477,15 @@ private void Receive(byte[] buffer, int start, int offset, SocketFlags f) } } + private async Task ReceiveAsync(byte[] buffer, int start, int offset, CancellationToken cancellationToken = default) + { + if (offset <= 0) + return; + + Stream stream = encrypted ? s! : c.GetStream(); + await stream.ReadExactlyAsync(buffer.AsMemory(start, offset), cancellationToken); + } + private void Send(byte[] buffer) { if (encrypted) @@ -479,7 +494,61 @@ private void Send(byte[] buffer) c.Client.Send(buffer); } - private bool Handshake(string uuid, string username, string sessionID, string host, int port, SessionToken session) + private async Task SendAsync(byte[] buffer, CancellationToken cancellationToken = default) + { + if (buffer.Length == 0) + return; + + Stream stream = encrypted ? s! : c.GetStream(); + await stream.WriteAsync(buffer.AsMemory(0, buffer.Length), cancellationToken); + await stream.FlushAsync(cancellationToken); + } + + private async Task ReadDataAsync(int offset, CancellationToken cancellationToken = default) + { + if (offset <= 0) + return []; + + byte[] cache = new byte[offset]; + await ReceiveAsync(cache, 0, offset, cancellationToken); + return cache; + } + + private async Task ReadNextStringAsync(CancellationToken cancellationToken = default) + { + ushort length = (ushort)await ReadNextShortAsync(cancellationToken); + if (length <= 0) + return ""; + + byte[] cache = new byte[length * 2]; + await ReceiveAsync(cache, 0, length * 2, cancellationToken); + return Encoding.BigEndianUnicode.GetString(cache); + } + + private async Task ReadNextByteArrayAsync(CancellationToken cancellationToken = default) + { + short len = await ReadNextShortAsync(cancellationToken); + byte[] data = new byte[len]; + await ReceiveAsync(data, 0, len, cancellationToken); + return data; + } + + private async Task ReadNextShortAsync(CancellationToken cancellationToken = default) + { + byte[] tmp = new byte[2]; + await ReceiveAsync(tmp, 0, 2, cancellationToken); + Array.Reverse(tmp); + return BitConverter.ToInt16(tmp, 0); + } + + private async Task ReadNextByteAsync(CancellationToken cancellationToken = default) + { + byte[] result = new byte[1]; + await ReceiveAsync(result, 0, 1, cancellationToken); + return result[0]; + } + + private async Task HandshakeAsync(string uuid, string username, string sessionID, string host, int port, SessionToken session, CancellationToken cancellationToken = default) { //array byte[] data = new byte[10 + (username.Length + host.Length) * 2]; @@ -513,27 +582,28 @@ private bool Handshake(string uuid, string username, string sessionID, string ho Array.Reverse(sh); sh.CopyTo(data, 6 + (username.Length * 2) + (host.Length * 2)); - Send(data); + await SendAsync(data, cancellationToken); byte[] pid = new byte[1]; - Receive(pid, 0, 1, SocketFlags.None); + await ReceiveAsync(pid, 0, 1, cancellationToken); while (pid[0] == 0xFA) //Skip some early plugin messages { - ProcessPacket(pid[0]); - Receive(pid, 0, 1, SocketFlags.None); + using (MainThreadExecutionScope.Enter(handler)) + ProcessPacket(pid[0]); + await ReceiveAsync(pid, 0, 1, cancellationToken); } if (pid[0] == 0xFD) { - string serverID = ReadNextString(); - byte[] PublicServerkey = ReadNextByteArray(); - byte[] token = ReadNextByteArray(); + string serverID = await ReadNextStringAsync(cancellationToken); + byte[] PublicServerkey = await ReadNextByteArrayAsync(cancellationToken); + byte[] token = await ReadNextByteArrayAsync(cancellationToken); if (serverID == "-") ConsoleIO.WriteLineFormatted("§8" + Translations.mcc_server_offline, acceptnewlines: true); else if (Settings.Config.Logging.DebugMessages) ConsoleIO.WriteLineFormatted("§8" + string.Format(Translations.mcc_handshake, serverID)); - return StartEncryption(uuid, username, sessionID, Config.Main.General.AccountType, token, serverID, PublicServerkey, session); + return await StartEncryptionAsync(uuid, username, sessionID, Config.Main.General.AccountType, token, serverID, PublicServerkey, session, cancellationToken); } else { @@ -542,7 +612,7 @@ private bool Handshake(string uuid, string username, string sessionID, string ho } } - private bool StartEncryption(string uuid, string username, string sessionID, LoginType type, byte[] token, string serverIDhash, byte[] serverPublicKey, SessionToken session) + private async Task StartEncryptionAsync(string uuid, string username, string sessionID, LoginType type, byte[] token, string serverIDhash, byte[] serverPublicKey, SessionToken session, CancellationToken cancellationToken = default) { RSACryptoServiceProvider RSAService = CryptoHandler.DecodeRSAPublicKey(serverPublicKey)!; byte[] secretKey = CryptoHandler.ClientAESPrivateKey ?? CryptoHandler.GenerateAESPrivateKey(); @@ -559,14 +629,13 @@ private bool StartEncryption(string uuid, string username, string sessionID, Log if (session.ServerPublicKey is not null && session.SessionPreCheckTask is not null && serverIDhash == session.ServerIDhash && Enumerable.SequenceEqual(serverPublicKey, session.ServerPublicKey)) { - session.SessionPreCheckTask.Wait(); - if (session.SessionPreCheckTask.Result) // PreCheck Successed + if (session.SessionPreCheckTask.IsCompletedSuccessfully && session.SessionPreCheckTask.Result) needCheckSession = false; } if (needCheckSession) { - if (ProtocolHandler.SessionCheck(uuid, sessionID, serverHash, type)) + if (await ProtocolHandler.SessionCheckAsync(uuid, sessionID, serverHash, type)) { session.ServerIDhash = serverIDhash; session.ServerPublicKey = serverPublicKey; @@ -599,14 +668,14 @@ private bool StartEncryption(string uuid, string username, string sessionID, Log token_enc.CopyTo(data, 5 + (short)key_enc.Length); //Send it back - Send(data); + await SendAsync(data, cancellationToken); //Getting the next packet byte[] pid = new byte[1]; - Receive(pid, 0, 1, SocketFlags.None); + await ReceiveAsync(pid, 0, 1, cancellationToken); if (pid[0] == 0xFC) { - ReadData(4); + await ReadDataAsync(4, cancellationToken); s = new AesCfb8Stream(c.GetStream(), secretKey); encrypted = true; return true; @@ -620,9 +689,14 @@ private bool StartEncryption(string uuid, string username, string sessionID, Log public bool Login(PlayerKeyPair? playerKeyPair, SessionToken session, bool isTransfer = false) { - if (Handshake(handler.GetUserUuidStr(), handler.GetUsername(), handler.GetSessionID(), handler.GetServerHost(), handler.GetServerPort(), session)) + return LoginAsync(playerKeyPair, session, isTransfer).GetAwaiter().GetResult(); + } + + private async Task LoginAsync(PlayerKeyPair? playerKeyPair, SessionToken session, bool isTransfer = false) + { + if (await HandshakeAsync(handler.GetUserUuidStr(), handler.GetUsername(), handler.GetSessionID(), handler.GetServerHost(), handler.GetServerPort(), session)) { - Send(new byte[] { 0xCD, 0 }); + await SendAsync([0xCD, 0]); try { byte[] pid = new byte[1]; @@ -630,21 +704,24 @@ public bool Login(PlayerKeyPair? playerKeyPair, SessionToken session, bool isTra { if (c.Connected) { - Receive(pid, 0, 1, SocketFlags.None); + await ReceiveAsync(pid, 0, 1); while (pid[0] >= 0xC0 && pid[0] != 0xFF) //Skip some early packets or plugin messages { - ProcessPacket(pid[0]); - Receive(pid, 0, 1, SocketFlags.None); + using (MainThreadExecutionScope.Enter(handler)) + ProcessPacket(pid[0]); + await ReceiveAsync(pid, 0, 1); } if (pid[0] == (byte)1) { - ReadData(4); ReadNextString(); ReadData(5); + await ReadDataAsync(4); + _ = await ReadNextStringAsync(); + await ReadDataAsync(5); StartUpdating(); return true; //The Server accepted the request } else if (pid[0] == (byte)0xFF) { - string reason = ReadNextString(); + string reason = await ReadNextStringAsync(); handler.OnConnectionLost(ChatBot.DisconnectReason.LoginRejected, reason); return false; } diff --git a/MinecraftClient/Protocol/Handlers/Protocol18.cs b/MinecraftClient/Protocol/Handlers/Protocol18.cs index 21d37c887e..bdff67a046 100644 --- a/MinecraftClient/Protocol/Handlers/Protocol18.cs +++ b/MinecraftClient/Protocol/Handlers/Protocol18.cs @@ -9,6 +9,7 @@ using System.Text; using System.Text.RegularExpressions; using System.Threading; +using System.Threading.Tasks; using MinecraftClient.Crypto; using MinecraftClient.Inventory; using MinecraftClient.Inventory.ItemPalettes; @@ -20,6 +21,7 @@ using MinecraftClient.Protocol.Handlers.packet.s2c; using MinecraftClient.Protocol.Handlers.PacketPalettes; using MinecraftClient.Protocol.Message; +using MinecraftClient.Protocol.PacketPipeline; using MinecraftClient.Protocol.ProfileKey; using MinecraftClient.Protocol.Session; using MinecraftClient.Proxy; @@ -90,7 +92,7 @@ class Protocol18Handler : IMinecraftCom private readonly int rawProtocolVersion; private int currentDimension; private bool isOnlineMode = false; - private readonly BlockingCollection>> packetQueue = new(); + private readonly BlockingCollection packetQueue = new(); private readonly Dictionary legacyAchievementProgress = new(StringComparer.Ordinal); private float LastYaw, LastPitch; private double lastSentX, lastSentY, lastSentZ; @@ -117,8 +119,11 @@ class Protocol18Handler : IMinecraftCom readonly PacketTypePalette packetPalette; readonly SocketWrapper socketWrapper; readonly DataTypes dataTypes; - Tuple? netMain = null; // main thread - Tuple? netReader = null; // reader thread + private Task? netMainTask; + private CancellationTokenSource? netMainCancellationTokenSource; + private int netMainThreadId = -1; + private Task? netReaderTask; + private CancellationTokenSource? netReaderCancellationTokenSource; readonly ILogger log; readonly RandomNumberGenerator randomGen; private bool legacyAchievementsInitialized; @@ -278,17 +283,17 @@ public Protocol18Handler(TcpClient Client, int protocolVersion, IMinecraftComHan } /// - /// Separate thread. Network reading loop. + /// Serialized packet/tick loop. /// - private void Updater(object? o) + private async Task UpdaterAsync(CancellationToken cancelToken) { - var cancelToken = (CancellationToken)o!; - if (cancelToken.IsCancellationRequested) return; try { + netMainThreadId = Environment.CurrentManagedThreadId; + using IDisposable _ = MainThreadExecutionScope.Enter(handler); Stopwatch stopWatch = Stopwatch.StartNew(); long nextUpdateDue = 0; while (!packetQueue.IsAddingCompleted) @@ -305,14 +310,13 @@ private void Updater(object? o) if (packetQueue.TryTake(out var packetInfo, 1)) { - var (packetId, packetData) = packetInfo; - HandlePacket(packetId, packetData); + HandlePacket(packetInfo.PacketId, packetInfo.CreateReader()); continue; } long sleepLength = nextUpdateDue - stopWatch.ElapsedMilliseconds; if (sleepLength > 1) - Thread.Sleep((int)Math.Min(sleepLength, ClientTickIntervalMilliseconds)); + await Task.Delay((int)Math.Min(sleepLength, ClientTickIntervalMilliseconds), cancelToken); } } catch (ObjectDisposedException) @@ -330,6 +334,13 @@ private void Updater(object? o) catch (System.IO.IOException) { } + catch (Exception) + { + } + finally + { + netMainThreadId = -1; + } if (cancelToken.IsCancellationRequested) return; @@ -340,20 +351,13 @@ private void Updater(object? o) /// /// Read and decompress packets. /// - internal void PacketReader(object? o) + internal async Task PacketReaderAsync(CancellationToken cancelToken) { - var cancelToken = (CancellationToken)o!; - while (socketWrapper.IsConnected() && !cancelToken.IsCancellationRequested) + while (!cancelToken.IsCancellationRequested) { try { - while (socketWrapper.HasDataAvailable()) - { - packetQueue.Add(ReadNextPacket(), cancelToken); - - if (cancelToken.IsCancellationRequested) - break; - } + packetQueue.Add(await ReadNextPacketAsync(cancelToken), cancelToken); } catch (OperationCanceledException) { @@ -375,11 +379,10 @@ internal void PacketReader(object? o) { break; } - - if (cancelToken.IsCancellationRequested) + catch (Exception) + { break; - - Thread.Sleep(10); + } } packetQueue.CompleteAdding(); @@ -390,29 +393,27 @@ internal void PacketReader(object? o) /// /// will contain packet ID /// will contain raw packet Data - internal Tuple> ReadNextPacket() + internal IncomingPacket ReadNextPacket() { - var size = dataTypes.ReadNextVarIntRAW(socketWrapper); //Packet size - Queue packetData = new(socketWrapper.ReadDataRAW(size)); //Packet contents + IncomingPacket packet = socketWrapper.GetNextPacket( + protocolVersion >= MC_1_8_Version ? compression_treshold : -1, + dataTypes); + if (handler.GetNetworkPacketCaptureEnabled()) + handler.OnNetworkPacket(packet.PacketId, packet.Payload.ToList(), currentState == CurrentState.Login, true); - //Handle packet decompression - if (protocolVersion >= MC_1_8_Version - && compression_treshold >= 0) - { - var sizeUncompressed = dataTypes.ReadNextVarInt(packetData); - if (sizeUncompressed != 0) // != 0 means compressed, let's decompress - { - var toDecompress = packetData.ToArray(); - var uncompressed = ZlibUtils.Decompress(toDecompress, sizeUncompressed); - packetData = new Queue(uncompressed); - } - } + return packet; + } - var packetId = dataTypes.ReadNextVarInt(packetData); // Packet ID + internal async Task ReadNextPacketAsync(CancellationToken cancellationToken) + { + IncomingPacket packet = await socketWrapper.GetNextPacketAsync( + protocolVersion >= MC_1_8_Version ? compression_treshold : -1, + dataTypes, + cancellationToken); if (handler.GetNetworkPacketCaptureEnabled()) - handler.OnNetworkPacket(packetId, packetData.ToList(), currentState == CurrentState.Login, true); + handler.OnNetworkPacket(packet.PacketId, packet.Payload.ToList(), currentState == CurrentState.Login, true); - return new(packetId, packetData); + return packet; } /// @@ -421,11 +422,9 @@ internal Tuple> ReadNextPacket() /// Packet ID /// Packet contents /// TRUE if the packet was processed, FALSE if ignored or unknown - internal bool HandlePacket(int packetId, Queue packetData) + internal bool HandlePacket(int packetId, PacketReader packetData) { - // This copy is necessary because by the time we get to the catch block, - // the packetData queue will have been processed and the data will be lost - var _copy = packetData.ToArray(); + byte[] _copy = packetData.GetRawData(); try { @@ -487,11 +486,11 @@ internal bool HandlePacket(int packetId, Queue packetData) break; case ConfigurationPacketTypesIn.KeepAlive: - SendPacket(ConfigurationPacketTypesOut.KeepAlive, packetData); + SendPacket(ConfigurationPacketTypesOut.KeepAlive, packetData.CopyRemaining()); break; case ConfigurationPacketTypesIn.Ping: - SendPacket(ConfigurationPacketTypesOut.Pong, packetData); + SendPacket(ConfigurationPacketTypesOut.Pong, packetData.CopyRemaining()); break; case ConfigurationPacketTypesIn.RegistryData: @@ -673,7 +672,7 @@ internal bool HandlePacket(int packetId, Queue packetData) return true; } - public void HandleResourcePackPacket(Queue packetData) + public void HandleResourcePackPacket(PacketReader packetData) { var uuid = Guid.Empty; @@ -721,17 +720,17 @@ public void HandleResourcePackPacket(Queue packetData) } } - private bool HandlePlayPackets(int packetId, Queue packetData) + private bool HandlePlayPackets(int packetId, PacketReader packetData) { switch (packetPalette.GetIncomingTypeById(packetId)) { case PacketTypesIn.KeepAlive: // Keep Alive (Play) - SendPacket(PacketTypesOut.KeepAlive, packetData); + SendPacket(PacketTypesOut.KeepAlive, packetData.CopyRemaining()); handler.OnServerKeepAlive(); break; case PacketTypesIn.Ping: - SendPacket(PacketTypesOut.Pong, packetData); + SendPacket(PacketTypesOut.Pong, packetData.CopyRemaining()); break; case PacketTypesIn.JoinGame: @@ -1601,7 +1600,7 @@ private bool HandlePlayPackets(int packetId, Queue packetData) pTerrain.ProcessChunkColumnData(chunkX, chunkZ, chunkMask, addBitmap, currentDimension == 0, chunksContinuous, currentDimension, - new Queue(decompressed)); + new PacketReader(decompressed)); Interlocked.Decrement(ref handler.GetWorld().chunkLoadNotCompleted); } else @@ -1983,7 +1982,7 @@ private bool HandlePlayPackets(int packetId, Queue packetData) hasSkyLight = dataTypes.ReadNextBool(packetData); var compressed = dataTypes.ReadData(compressedDataSize, packetData); var decompressed = ZlibUtils.Decompress(compressed); - chunkData = new Queue(decompressed); + chunkData = new PacketReader(decompressed); } else { @@ -2296,7 +2295,7 @@ private bool HandlePlayPackets(int packetId, Queue packetData) // Length is unneeded as the whole remaining packetData is the entire payload of the packet. if (protocolVersion < MC_1_8_Version) pForge.ReadNextVarShort(packetData); - handler.OnPluginChannelMessage(channel, packetData.ToArray()); + handler.OnPluginChannelMessage(channel, packetData.CopyRemaining()); return pForge.HandlePluginMessage(channel, packetData, ref currentDimension); case PacketTypesIn.Disconnect: handler.OnConnectionLost(ChatBot.DisconnectReason.InGameKick, @@ -3322,7 +3321,7 @@ private bool HandlePlayPackets(int packetId, Queue packetData) /// Read a Holder<SoundEvent> from packet data and return its key when inline. /// Returns null when the holder is a registry reference. /// - private string? ReadSoundEventHolderName(Queue packetData) + private string? ReadSoundEventHolderName(PacketReader packetData) { int soundHolderId = dataTypes.ReadNextVarInt(packetData); if (soundHolderId != 0) @@ -3338,7 +3337,7 @@ private bool HandlePlayPackets(int packetId, Queue packetData) /// /// Handle the Statistics packet for pre-1.12 legacy achievements. /// - private void HandleLegacyStatistics(Queue packetData) + private void HandleLegacyStatistics(PacketReader packetData) { int statCount = dataTypes.ReadNextVarInt(packetData); @@ -3369,7 +3368,7 @@ private void HandleLegacyStatistics(Queue packetData) /// /// Handle the Advancements packet (1.12+). /// - private void HandleAdvancements(Queue packetData) + private void HandleAdvancements(PacketReader packetData) { bool reset = dataTypes.ReadNextBool(packetData); @@ -3540,14 +3539,14 @@ private static bool ComputeAdvancementCompleted(List> requirements, /// /// Handle the SelectAdvancementTab packet. /// - private void HandleSelectAdvancementTab(Queue packetData) + private void HandleSelectAdvancementTab(PacketReader packetData) { bool hasTab = dataTypes.ReadNextBool(packetData); string? tabId = hasTab ? dataTypes.ReadNextString(packetData) : null; handler.OnSelectAdvancementTab(tabId); } - private void HandleUnlockRecipes(Queue packetData) + private void HandleUnlockRecipes(PacketReader packetData) { int action = dataTypes.ReadNextVarInt(packetData); if (!SkipRecipeBookSettings(packetData)) @@ -3575,7 +3574,7 @@ private void HandleUnlockRecipes(Queue packetData) } } - private void HandleRecipeBookAdd(Queue packetData) + private void HandleRecipeBookAdd(PacketReader packetData) { int entryCount = dataTypes.ReadNextVarInt(packetData); RecipeBookRecipeEntry[] recipeEntries = new RecipeBookRecipeEntry[entryCount]; @@ -3592,7 +3591,7 @@ private void HandleRecipeBookAdd(Queue packetData) handler.OnRecipeBookAdd(recipeEntries, replace); } - private string[] ReadRecipeBookRecipeIds(Queue packetData) + private string[] ReadRecipeBookRecipeIds(PacketReader packetData) { int recipeCount = dataTypes.ReadNextVarInt(packetData); string[] recipeIds = new string[recipeCount]; @@ -3603,7 +3602,7 @@ private string[] ReadRecipeBookRecipeIds(Queue packetData) return recipeIds; } - private string[] ReadRecipeBookDisplayIds(Queue packetData) + private string[] ReadRecipeBookDisplayIds(PacketReader packetData) { int recipeCount = dataTypes.ReadNextVarInt(packetData); string[] recipeIds = new string[recipeCount]; @@ -3614,7 +3613,7 @@ private string[] ReadRecipeBookDisplayIds(Queue packetData) return recipeIds; } - private RecipeBookRecipeEntry ReadRecipeBookDisplayEntry(Queue packetData) + private RecipeBookRecipeEntry ReadRecipeBookDisplayEntry(PacketReader packetData) { int displayId = dataTypes.ReadNextVarInt(packetData); string resultLabel = ReadRecipeDisplayResultLabel(packetData); @@ -3628,7 +3627,7 @@ private RecipeBookRecipeEntry ReadRecipeBookDisplayEntry(Queue packetData) return new RecipeBookRecipeEntry(commandId, displayText); } - private string ReadRecipeDisplayResultLabel(Queue packetData) + private string ReadRecipeDisplayResultLabel(PacketReader packetData) { int displayType = dataTypes.ReadNextVarInt(packetData); return displayType switch @@ -3642,7 +3641,7 @@ private string ReadRecipeDisplayResultLabel(Queue packetData) }; } - private string ReadShapelessRecipeDisplayResultLabel(Queue packetData) + private string ReadShapelessRecipeDisplayResultLabel(PacketReader packetData) { int ingredientCount = dataTypes.ReadNextVarInt(packetData); for (int i = 0; i < ingredientCount; i++) @@ -3653,7 +3652,7 @@ private string ReadShapelessRecipeDisplayResultLabel(Queue packetData) return result; } - private string ReadShapedRecipeDisplayResultLabel(Queue packetData) + private string ReadShapedRecipeDisplayResultLabel(PacketReader packetData) { _ = dataTypes.ReadNextVarInt(packetData); // width _ = dataTypes.ReadNextVarInt(packetData); // height @@ -3666,7 +3665,7 @@ private string ReadShapedRecipeDisplayResultLabel(Queue packetData) return result; } - private string ReadFurnaceRecipeDisplayResultLabel(Queue packetData) + private string ReadFurnaceRecipeDisplayResultLabel(PacketReader packetData) { _ = ReadSlotDisplayLabel(packetData); // ingredient _ = ReadSlotDisplayLabel(packetData); // fuel @@ -3677,7 +3676,7 @@ private string ReadFurnaceRecipeDisplayResultLabel(Queue packetData) return result; } - private string ReadStonecutterRecipeDisplayResultLabel(Queue packetData) + private string ReadStonecutterRecipeDisplayResultLabel(PacketReader packetData) { _ = ReadSlotDisplayLabel(packetData); // input string result = ReadSlotDisplayLabel(packetData); @@ -3685,7 +3684,7 @@ private string ReadStonecutterRecipeDisplayResultLabel(Queue packetData) return result; } - private string ReadSmithingRecipeDisplayResultLabel(Queue packetData) + private string ReadSmithingRecipeDisplayResultLabel(PacketReader packetData) { _ = ReadSlotDisplayLabel(packetData); // template _ = ReadSlotDisplayLabel(packetData); // base @@ -3695,7 +3694,7 @@ private string ReadSmithingRecipeDisplayResultLabel(Queue packetData) return result; } - private string ReadSlotDisplayLabel(Queue packetData) + private string ReadSlotDisplayLabel(PacketReader packetData) { int slotDisplayType = dataTypes.ReadNextVarInt(packetData); @@ -3738,7 +3737,7 @@ private string ReadSlotDisplayLabel(Queue packetData) /// /// Reads a with_any_potion slot display (26.1+): contains a nested SlotDisplay. /// - private string ReadWithAnyPotionSlotDisplayLabel(Queue packetData) + private string ReadWithAnyPotionSlotDisplayLabel(PacketReader packetData) { return ReadSlotDisplayLabel(packetData); } @@ -3746,7 +3745,7 @@ private string ReadWithAnyPotionSlotDisplayLabel(Queue packetData) /// /// Reads an only_with_component slot display (26.1+): contains a nested SlotDisplay and a DataComponentType VarInt ID. /// - private string ReadOnlyWithComponentSlotDisplayLabel(Queue packetData) + private string ReadOnlyWithComponentSlotDisplayLabel(PacketReader packetData) { string sourceLabel = ReadSlotDisplayLabel(packetData); _ = dataTypes.ReadNextVarInt(packetData); // DataComponentType registry id @@ -3756,14 +3755,14 @@ private string ReadOnlyWithComponentSlotDisplayLabel(Queue packetData) /// /// Reads a dyed slot display (26.1+): contains two nested SlotDisplays (dye + target). /// - private string ReadDyedSlotDisplayLabel(Queue packetData) + private string ReadDyedSlotDisplayLabel(PacketReader packetData) { _ = ReadSlotDisplayLabel(packetData); // dye string targetLabel = ReadSlotDisplayLabel(packetData); // target return targetLabel; } - private string ReadSmithingTrimSlotDisplayLabel(Queue packetData) + private string ReadSmithingTrimSlotDisplayLabel(PacketReader packetData) { string baseLabel = ReadSlotDisplayLabel(packetData); _ = ReadSlotDisplayLabel(packetData); // material @@ -3771,14 +3770,14 @@ private string ReadSmithingTrimSlotDisplayLabel(Queue packetData) return baseLabel; } - private string ReadWithRemainderSlotDisplayLabel(Queue packetData) + private string ReadWithRemainderSlotDisplayLabel(PacketReader packetData) { string inputLabel = ReadSlotDisplayLabel(packetData); _ = ReadSlotDisplayLabel(packetData); // remainder return inputLabel; } - private string ReadCompositeSlotDisplayLabel(Queue packetData) + private string ReadCompositeSlotDisplayLabel(PacketReader packetData) { int optionCount = dataTypes.ReadNextVarInt(packetData); string label = "Composite"; @@ -3797,12 +3796,12 @@ private string ReadCompositeSlotDisplayLabel(Queue packetData) /// Read an ItemStackTemplate (26.1+) which encodes fields in a different order /// than ItemStack: item_id (VarInt), count (VarInt), DataComponentPatch. /// - private string ReadItemStackTemplateLabel(Queue packetData) + private string ReadItemStackTemplateLabel(PacketReader packetData) { return dataTypes.ReadNextItemStackTemplate(packetData, itemPalette).GetTypeString(); } - private void SkipOptionalCraftingRequirements(Queue packetData) + private void SkipOptionalCraftingRequirements(PacketReader packetData) { if (!dataTypes.ReadNextBool(packetData)) return; @@ -3812,7 +3811,7 @@ private void SkipOptionalCraftingRequirements(Queue packetData) SkipItemHolderSet(packetData); } - private void SkipItemHolderSet(Queue packetData) + private void SkipItemHolderSet(PacketReader packetData) { int entryCount = dataTypes.ReadNextVarInt(packetData) - 1; if (entryCount == -1) @@ -3825,12 +3824,12 @@ private void SkipItemHolderSet(Queue packetData) _ = dataTypes.ReadNextVarInt(packetData); } - private bool SkipRecipeBookSettings(Queue packetData) + private bool SkipRecipeBookSettings(PacketReader packetData) { // MC 1.13 uses 4 booleans for the crafting/smelting recipe book states. // MC 1.14+ expands this to 8 booleans by adding blast furnace and smoker states. int boolCount = protocolVersion >= MC_1_14_Version ? 8 : 4; - if (packetData.Count < boolCount) + if (packetData.RemainingLength < boolCount) return false; for (int i = 0; i < boolCount; i++) @@ -3840,23 +3839,17 @@ private bool SkipRecipeBookSettings(Queue packetData) } /// - /// Start the updating thread. Should be called after login success. + /// Start the serialized packet/tick tasks. Should be called after login success. /// private void StartUpdating() { - Thread threadUpdater = new(new ParameterizedThreadStart(Updater)) - { - Name = "ProtocolPacketHandler" - }; - netMain = new Tuple(threadUpdater, new CancellationTokenSource()); - threadUpdater.Start(netMain.Item2.Token); + CancellationTokenSource netMainCts = new(); + netMainCancellationTokenSource = netMainCts; + netMainTask = Task.Run(() => UpdaterAsync(netMainCts.Token), netMainCts.Token); - Thread threadReader = new(new ParameterizedThreadStart(PacketReader)) - { - Name = "ProtocolPacketReader" - }; - netReader = new Tuple(threadReader, new CancellationTokenSource()); - threadReader.Start(netReader.Item2.Token); + CancellationTokenSource netReaderCts = new(); + netReaderCancellationTokenSource = netReaderCts; + netReaderTask = Task.Run(() => PacketReaderAsync(netReaderCts.Token), netReaderCts.Token); } /// @@ -3865,7 +3858,7 @@ private void StartUpdating() /// Net read thread ID public int GetNetMainThreadId() { - return netMain is not null ? netMain.Item1.ManagedThreadId : -1; + return netMainThreadId; } /// @@ -3875,14 +3868,14 @@ public void Dispose() { try { - if (netMain is not null) + if (netMainCancellationTokenSource is not null) { - netMain.Item2.Cancel(); + netMainCancellationTokenSource.Cancel(); } - if (netReader is not null) + if (netReaderCancellationTokenSource is not null) { - netReader.Item2.Cancel(); + netReaderCancellationTokenSource.Cancel(); socketWrapper.Disconnect(); } } @@ -3901,9 +3894,9 @@ private void SendPacket(PacketTypesOut packet, IEnumerable packetData) SendPacket(packetPalette.GetOutgoingIdByType(packet), packetData); } - private void ProcessChunkBlockEntityData(int chunkX, int chunkZ, Queue packetData) + private void ProcessChunkBlockEntityData(int chunkX, int chunkZ, PacketReader packetData) { - if (protocolVersion < MC_1_17_Version || packetData.Count == 0) + if (protocolVersion < MC_1_17_Version || packetData.RemainingLength == 0) return; int blockEntityCount = dataTypes.ReadNextVarInt(packetData); @@ -3979,6 +3972,11 @@ private void SendPacket(int packetId, IEnumerable packetData) /// /// True if login successful public bool Login(PlayerKeyPair? playerKeyPair, SessionToken session, bool isTransfer = false) + { + return LoginAsync(playerKeyPair, session, isTransfer).GetAwaiter().GetResult(); + } + + private async Task LoginAsync(PlayerKeyPair? playerKeyPair, SessionToken session, bool isTransfer = false) { int nextState = isTransfer && protocolVersion >= MC_1_20_6_Version ? 3 : 2; @@ -4057,7 +4055,9 @@ public bool Login(PlayerKeyPair? playerKeyPair, SessionToken session, bool isTra // 3. Encryption Request - 9. Login Acknowledged while (true) { - var (packetId, packetData) = ReadNextPacket(); + IncomingPacket packet = await ReadNextPacketAsync(CancellationToken.None); + int packetId = packet.PacketId; + PacketReader packetData = packet.CreateReader(); switch (packetId) { @@ -4080,7 +4080,7 @@ public bool Login(PlayerKeyPair? playerKeyPair, SessionToken session, bool isTra if (protocolVersion >= MC_1_20_6_Version) shouldAuthetnicate = dataTypes.ReadNextBool(packetData); - return StartEncryption(handler.GetUserUuidStr(), handler.GetSessionID(), + return await StartEncryptionAsync(handler.GetUserUuidStr(), handler.GetSessionID(), Config.Main.General.AccountType, token, serverId, serverPublicKey, playerKeyPair, session, shouldAuthetnicate); } @@ -4096,7 +4096,7 @@ public bool Login(PlayerKeyPair? playerKeyPair, SessionToken session, bool isTra if (protocolVersion >= MC_1_20_2_Version) SendPacket(0x03, new List()); - if (!pForge.CompleteForgeHandshake()) + if (!await pForge.CompleteForgeHandshakeAsync()) { log.Error($"§8{Translations.error_forge}"); return false; @@ -4106,7 +4106,8 @@ public bool Login(PlayerKeyPair? playerKeyPair, SessionToken session, bool isTra return true; //No need to check session or start encryption } default: - HandlePacket(packetId, packetData); + using (MainThreadExecutionScope.Enter(handler)) + HandlePacket(packetId, packetData); break; } } @@ -4116,7 +4117,7 @@ public bool Login(PlayerKeyPair? playerKeyPair, SessionToken session, bool isTra /// Start network encryption. Automatically called by Login() if the server requests encryption. /// /// True if encryption was successful - private bool StartEncryption(string uuid, string sessionID, LoginType type, byte[] token, string serverIDhash, + private async Task StartEncryptionAsync(string uuid, string sessionID, LoginType type, byte[] token, string serverIDhash, byte[] serverPublicKey, PlayerKeyPair? playerKeyPair, SessionToken session, bool shouldAuthetnicate) { var RSAService = CryptoHandler.DecodeRSAPublicKey(serverPublicKey)!; @@ -4133,8 +4134,7 @@ private bool StartEncryption(string uuid, string sessionID, LoginType type, byte && serverIDhash == session.ServerIDhash && serverPublicKey.SequenceEqual(session.ServerPublicKey)) { - session.SessionPreCheckTask.Wait(); - if (session.SessionPreCheckTask.Result) // PreCheck Success + if (session.SessionPreCheckTask.IsCompletedSuccessfully && session.SessionPreCheckTask.Result) needCheckSession = false; } @@ -4145,7 +4145,7 @@ private bool StartEncryption(string uuid, string sessionID, LoginType type, byte if (needCheckSession) { var serverHash = CryptoHandler.GetServerHash(serverIDhash, serverPublicKey, secretKey); - if (ProtocolHandler.SessionCheck(uuid, sessionID, serverHash, type)) + if (await ProtocolHandler.SessionCheckAsync(uuid, sessionID, serverHash, type)) { session.ServerIDhash = serverIDhash; session.ServerPublicKey = serverPublicKey; @@ -4195,7 +4195,9 @@ private bool StartEncryption(string uuid, string sessionID, LoginType type, byte int loopPrevention = ushort.MaxValue; while (true) { - var (packetId, packetData) = ReadNextPacket(); + IncomingPacket packet = await ReadNextPacketAsync(CancellationToken.None); + int packetId = packet.PacketId; + PacketReader packetData = packet.CreateReader(); if (packetId < 0 || loopPrevention-- < 0) // Failed to read packet or too many iterations (issue #1150) { handler.OnConnectionLost(ChatBot.DisconnectReason.ConnectionLost, @@ -4246,7 +4248,7 @@ private bool StartEncryption(string uuid, string sessionID, LoginType type, byte handler.OnLoginSuccess(uuidReceived, userName, playerProperty); - if (!pForge.CompleteForgeHandshake()) + if (!await pForge.CompleteForgeHandshakeAsync()) { log.Error($"§8{Translations.error_forge_encrypt}"); return false; @@ -4256,7 +4258,8 @@ private bool StartEncryption(string uuid, string sessionID, LoginType type, byte return true; } default: - HandlePacket(packetId, packetData); + using (MainThreadExecutionScope.Enter(handler)) + HandlePacket(packetId, packetData); break; } } @@ -4355,17 +4358,12 @@ public static bool DoPing(string host, int port, ref int protocolVersion, ref Fo var statusRequest = DataTypes.GetVarInt(0); socketWrapper.SendDataRAW(dataTypes.ConcatBytes(DataTypes.GetVarInt(statusRequest.Length), statusRequest)); - // Read Response length - var packetLength = dataTypes.ReadNextVarIntRAW(socketWrapper); - if (packetLength <= 0) - return false; - - // Read the Packet Id - var packetData = new Queue(socketWrapper.ReadDataRAW(packetLength)); - if (dataTypes.ReadNextVarInt(packetData) != 0x00) + IncomingPacket statusPacket = socketWrapper.GetNextPacket(-1, dataTypes); + if (statusPacket.PacketId != 0x00) return false; // Get the Json data + var packetData = statusPacket.CreateReader(); var result = dataTypes.ReadNextString(packetData); if (Config.Logging.DebugMessages) @@ -4441,15 +4439,12 @@ public static bool DoPing(string host, int port, ref int protocolVersion, ref Fo var pingRequest = dataTypes.ConcatBytes(DataTypes.GetVarInt(0x01), DataTypes.GetLong(pingPayload)); socketWrapper.SendDataRAW(dataTypes.ConcatBytes(DataTypes.GetVarInt(pingRequest.Length), pingRequest)); - packetLength = dataTypes.ReadNextVarIntRAW(socketWrapper); - if (packetLength > 0) + IncomingPacket pongPacket = socketWrapper.GetNextPacket(-1, dataTypes); + if (pongPacket.PacketId == 0x01) { - packetData = new Queue(socketWrapper.ReadDataRAW(packetLength)); - if (dataTypes.ReadNextVarInt(packetData) == 0x01) - { - long pongPayload = dataTypes.ReadNextLong(packetData); - pingMs = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds() - pingPayload; - } + var pongPacketData = pongPacket.CreateReader(); + long pongPayload = dataTypes.ReadNextLong(pongPacketData); + pingMs = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds() - pingPayload; } } catch diff --git a/MinecraftClient/Protocol/Handlers/Protocol18Forge.cs b/MinecraftClient/Protocol/Handlers/Protocol18Forge.cs index 7359d20060..3fef412233 100644 --- a/MinecraftClient/Protocol/Handlers/Protocol18Forge.cs +++ b/MinecraftClient/Protocol/Handlers/Protocol18Forge.cs @@ -3,8 +3,10 @@ using System.Linq; using System.Text; using System.Threading; +using System.Threading.Tasks; using MinecraftClient.Protocol.Handlers.Forge; using MinecraftClient.Protocol.Message; +using MinecraftClient.Protocol.PacketPipeline; using MinecraftClient.Scripting; namespace MinecraftClient.Protocol.Handlers @@ -21,6 +23,7 @@ class Protocol18Forge(ForgeInfo? forgeInfo, int protocolVersion, DataTypes dataT private readonly ForgeInfo? forgeInfo = forgeInfo; private FMLHandshakeClientState fmlHandshakeState = FMLHandshakeClientState.START; + private DateTime? pendingServerDataAckAt; private bool ForgeEnabled() { return forgeInfo is not null; } /// @@ -40,12 +43,19 @@ public string GetServerAddress(string serverAddress) /// /// Whether the handshake was successful. public bool CompleteForgeHandshake() + { + return CompleteForgeHandshakeAsync().GetAwaiter().GetResult(); + } + + public async Task CompleteForgeHandshakeAsync(CancellationToken cancellationToken = default) { if (ForgeEnabled() && forgeInfo!.Version == FMLVersion.FML) { while (fmlHandshakeState != FMLHandshakeClientState.DONE) { - (int packetID, Queue packetData) = protocol18.ReadNextPacket(); + IncomingPacket packet = await protocol18.ReadNextPacketAsync(cancellationToken); + int packetID = packet.PacketId; + PacketReader packetData = packet.CreateReader(); if (packetID == 0x40) // Disconnect { @@ -56,18 +66,37 @@ public bool CompleteForgeHandshake() { // Send back regular packet to the vanilla protocol handler protocol18.HandlePacket(packetID, packetData); + await FlushPendingHandshakeActionsAsync(cancellationToken); } } } return true; } + private async Task FlushPendingHandshakeActionsAsync(CancellationToken cancellationToken) + { + if (!pendingServerDataAckAt.HasValue) + return; + + TimeSpan delay = pendingServerDataAckAt.Value - DateTime.UtcNow; + if (delay > TimeSpan.Zero) + await Task.Delay(delay, cancellationToken); + + if (Settings.Config.Logging.DebugMessages) + ConsoleIO.WriteLineFormatted("§8" + Translations.forge_accept, acceptnewlines: true); + + SendForgeHandshakePacket(FMLHandshakeDiscriminator.HandshakeAck, + new byte[] { (byte)FMLHandshakeClientState.WAITINGSERVERDATA }); + + pendingServerDataAckAt = null; + } + /// /// Read Forge VarShort field /// /// Packet data to read from /// Length from packet data - public int ReadNextVarShort(Queue packetData) + public int ReadNextVarShort(PacketReader packetData) { if (ForgeEnabled()) { @@ -88,7 +117,7 @@ public int ReadNextVarShort(Queue packetData) /// Plugin message data /// Current world dimension /// TRUE if the plugin message was recognized and handled - public bool HandlePluginMessage(string channel, Queue packetData, ref int currentDimension) + public bool HandlePluginMessage(string channel, PacketReader packetData, ref int currentDimension) { if (ForgeEnabled() && forgeInfo!.Version == FMLVersion.FML && fmlHandshakeState != FMLHandshakeClientState.DONE) { @@ -145,16 +174,9 @@ public bool HandlePluginMessage(string channel, Queue packetData, ref int if (discriminator != FMLHandshakeDiscriminator.ModList) return false; - Thread.Sleep(2000); - - if (Settings.Config.Logging.DebugMessages) - ConsoleIO.WriteLineFormatted("§8" + Translations.forge_accept, acceptnewlines: true); // Tell the server that yes, we are OK with the mods it has // even though we don't actually care what mods it has. - - SendForgeHandshakePacket(FMLHandshakeDiscriminator.HandshakeAck, - new byte[] { (byte)FMLHandshakeClientState.WAITINGSERVERDATA }); - + pendingServerDataAckAt = DateTime.UtcNow.AddSeconds(2); fmlHandshakeState = FMLHandshakeClientState.WAITINGSERVERCOMPLETE; return false; case FMLHandshakeClientState.WAITINGSERVERCOMPLETE: @@ -224,7 +246,7 @@ public bool HandlePluginMessage(string channel, Queue packetData, ref int /// Plugin message data /// Response data to return to server /// TRUE/FALSE depending on whether the packet was understood or not - public bool HandleLoginPluginRequest(string channel, Queue packetData, ref List responseData) + public bool HandleLoginPluginRequest(string channel, PacketReader packetData, ref List responseData) { if (ForgeEnabled() && (forgeInfo!.Version == FMLVersion.FML2 || forgeInfo!.Version == FMLVersion.FML3) && channel == "fml:loginwrapper") { @@ -313,7 +335,7 @@ public bool HandleLoginPluginRequest(string channel, Queue packetData, ref // FML3 specific, List dataPackRegistries = new(); - if (forgeInfo!.Version == FMLVersion.FML3 && packetData.Count != 0) + if (forgeInfo!.Version == FMLVersion.FML3 && packetData.RemainingLength != 0) { int dataPackRegistryCount = dataTypes.ReadNextVarInt(packetData); for (int i = 0; i < dataPackRegistryCount; i++) diff --git a/MinecraftClient/Protocol/Handlers/Protocol18Terrain.cs b/MinecraftClient/Protocol/Handlers/Protocol18Terrain.cs index 9bd29e87c4..92ccac037b 100644 --- a/MinecraftClient/Protocol/Handlers/Protocol18Terrain.cs +++ b/MinecraftClient/Protocol/Handlers/Protocol18Terrain.cs @@ -6,6 +6,7 @@ //using System.Linq; //using System.Text; using MinecraftClient.Mapping; +using MinecraftClient.Protocol.PacketPipeline; namespace MinecraftClient.Protocol.Handlers { @@ -23,7 +24,7 @@ class Protocol18Terrain(int protocolVersion, DataTypes dataTypes, IMinecraftComH /// /// Cache for reading data [MethodImpl(MethodImplOptions.AggressiveOptimization)] - private Chunk? ReadBlockStatesField(Queue cache) + private Chunk? ReadBlockStatesField(PacketReader cache) { // read Block states (Type: Paletted Container) byte bitsPerEntry = dataTypes.ReadNextByte(cache); @@ -134,7 +135,7 @@ class Protocol18Terrain(int protocolVersion, DataTypes dataTypes, IMinecraftComH /// Cache for reading chunk data /// token to cancel the task [MethodImpl(MethodImplOptions.AggressiveOptimization)] - public void ProcessChunkColumnData(int chunkX, int chunkZ, ulong[]? verticalStripBitmask, Queue cache) + public void ProcessChunkColumnData(int chunkX, int chunkZ, ulong[]? verticalStripBitmask, PacketReader cache) { World world = handler.GetWorld(); @@ -236,7 +237,7 @@ public void ProcessChunkColumnData(int chunkX, int chunkZ, ulong[]? verticalStri /// Cache for reading chunk data /// token to cancel the task [MethodImpl(MethodImplOptions.AggressiveOptimization)] - public void ProcessChunkColumnData(int chunkX, int chunkZ, ushort chunkMask, ushort chunkMask2, bool hasSkyLight, bool chunksContinuous, int currentDimension, Queue cache) + public void ProcessChunkColumnData(int chunkX, int chunkZ, ushort chunkMask, ushort chunkMask2, bool hasSkyLight, bool chunksContinuous, int currentDimension, PacketReader cache) { World world = handler.GetWorld(); diff --git a/MinecraftClient/Protocol/Handlers/SocketWrapper.cs b/MinecraftClient/Protocol/Handlers/SocketWrapper.cs index a4f451b134..3260adebeb 100644 --- a/MinecraftClient/Protocol/Handlers/SocketWrapper.cs +++ b/MinecraftClient/Protocol/Handlers/SocketWrapper.cs @@ -1,6 +1,12 @@ using System; +using System.Collections.Generic; +using System.IO; +using System.IO.Compression; using System.Net.Sockets; +using System.Threading; +using System.Threading.Tasks; using MinecraftClient.Crypto; +using MinecraftClient.Protocol.PacketPipeline; namespace MinecraftClient.Protocol.Handlers { @@ -9,9 +15,14 @@ namespace MinecraftClient.Protocol.Handlers /// public class SocketWrapper { - readonly TcpClient c; - AesCfb8Stream? s; - bool encrypted = false; + private readonly TcpClient client; + private readonly Stream networkStream; + private readonly SemaphoreSlim sendSemaphore = new(1, 1); + private readonly byte[] singleByteBuffer = new byte[1]; + private AesCfb8Stream? encryptedStream; + private Stream readStream; + private Stream writeStream; + private bool encrypted = false; /// /// Initialize a new SocketWrapper @@ -19,7 +30,9 @@ public class SocketWrapper /// TcpClient connected to the server public SocketWrapper(TcpClient client) { - c = client; + this.client = client; + networkStream = client.GetStream(); + readStream = writeStream = networkStream; } /// @@ -29,7 +42,7 @@ public SocketWrapper(TcpClient client) /// Silently dropped connection can only be detected by attempting to read/write data public bool IsConnected() { - return c.Client is not null && c.Connected; + return client.Client is not null && client.Connected; } /// @@ -38,7 +51,7 @@ public bool IsConnected() /// TRUE if data is available to read public bool HasDataAvailable() { - return c.Client.Available > 0; + return client.Client.Available > 0; } /// @@ -49,23 +62,21 @@ public void SwitchToEncrypted(byte[] secretKey) { if (encrypted) throw new InvalidOperationException("Stream is already encrypted!?"); - s = new AesCfb8Stream(c.GetStream(), secretKey); + encryptedStream = new AesCfb8Stream(networkStream, secretKey); + readStream = writeStream = encryptedStream; encrypted = true; } - /// - /// Network reading method. Read bytes from the socket or encrypted socket. - /// - private void Receive(byte[] buffer, int start, int offset, SocketFlags f) + public byte ReadByteRAW() { - int read = 0; - while (read < offset) - { - if (encrypted) - read += s!.Read(buffer, start + read, offset - read); - else - read += c.Client.Receive(buffer, start + read, offset - read, f); - } + readStream.ReadExactly(singleByteBuffer); + return singleByteBuffer[0]; + } + + public async ValueTask ReadByteRAWAsync(CancellationToken cancellationToken) + { + await readStream.ReadExactlyAsync(singleByteBuffer.AsMemory(0, 1), cancellationToken); + return singleByteBuffer[0]; } /// @@ -77,13 +88,45 @@ public byte[] ReadDataRAW(int length) { if (length > 0) { - byte[] cache = new byte[length]; - Receive(cache, 0, length, SocketFlags.None); + byte[] cache = GC.AllocateUninitializedArray(length); + readStream.ReadExactly(cache); + return cache; + } + return Array.Empty(); + } + + public async Task ReadDataRAWAsync(int length, CancellationToken cancellationToken) + { + if (length > 0) + { + byte[] cache = GC.AllocateUninitializedArray(length); + await readStream.ReadExactlyAsync(cache.AsMemory(0, length), cancellationToken); return cache; } + return Array.Empty(); } + internal IncomingPacket GetNextPacket(int compressionThreshold, DataTypes dataTypes) + { + int packetLength = ReadNextVarIntRaw(); + using PacketReadStream packetStream = new(readStream, packetLength); + byte[] payload = ReadPacketPayload(packetStream, compressionThreshold); + var packetData = new PacketReader(payload); + int packetId = dataTypes.ReadNextVarInt(packetData); + return new(packetId, packetData.CopyRemaining()); + } + + internal async Task GetNextPacketAsync(int compressionThreshold, DataTypes dataTypes, CancellationToken cancellationToken) + { + int packetLength = await ReadNextVarIntRawAsync(cancellationToken); + await using PacketReadStream packetStream = new(readStream, packetLength); + byte[] payload = await ReadPacketPayloadAsync(packetStream, compressionThreshold, cancellationToken); + var packetData = new PacketReader(payload); + int packetId = dataTypes.ReadNextVarInt(packetData); + return new(packetId, packetData.CopyRemaining()); + } + /// /// Send raw data to the server. /// @@ -93,10 +136,33 @@ public void SendDataRAW(byte[] buffer) if (!IsConnected()) throw new SocketException((int)SocketError.NotConnected); - if (encrypted) - s!.Write(buffer, 0, buffer.Length); - else - c.Client.Send(buffer); + sendSemaphore.Wait(); + try + { + writeStream.Write(buffer, 0, buffer.Length); + writeStream.Flush(); + } + finally + { + sendSemaphore.Release(); + } + } + + public async Task SendDataRAWAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken) + { + if (!IsConnected()) + throw new SocketException((int)SocketError.NotConnected); + + await sendSemaphore.WaitAsync(cancellationToken); + try + { + await writeStream.WriteAsync(buffer, cancellationToken); + await writeStream.FlushAsync(cancellationToken); + } + finally + { + sendSemaphore.Release(); + } } /// @@ -106,12 +172,117 @@ public void Disconnect() { try { - c.Close(); + encryptedStream?.Dispose(); + client.Close(); } catch (SocketException) { } catch (System.IO.IOException) { } catch (NullReferenceException) { } catch (ObjectDisposedException) { } } + + private int ReadNextVarIntRaw() + { + int value = 0; + int position = 0; + + while (true) + { + byte current = ReadByteRAW(); + value |= (current & 0x7F) << position++ * 7; + if (position > 5) + throw new OverflowException("VarInt too big"); + if ((current & 0x80) != 0x80) + return value; + } + } + + private async Task ReadNextVarIntRawAsync(CancellationToken cancellationToken) + { + int value = 0; + int position = 0; + + while (true) + { + byte current = await ReadByteRAWAsync(cancellationToken); + value |= (current & 0x7F) << position++ * 7; + if (position > 5) + throw new OverflowException("VarInt too big"); + if ((current & 0x80) != 0x80) + return value; + } + } + + private static byte[] ReadPacketPayload(PacketReadStream packetStream, int compressionThreshold) + { + if (compressionThreshold >= 0) + { + int uncompressedLength = ReadNextVarIntRaw(packetStream); + if (uncompressedLength > 0) + { + using ZLibStream zlibStream = new(packetStream, CompressionMode.Decompress, leaveOpen: true); + byte[] payload = GC.AllocateUninitializedArray(uncompressedLength); + zlibStream.ReadExactly(payload); + return payload; + } + } + + return packetStream.ReadRemaining(); + } + + private static async Task ReadPacketPayloadAsync(PacketReadStream packetStream, int compressionThreshold, CancellationToken cancellationToken) + { + if (compressionThreshold >= 0) + { + int uncompressedLength = await ReadNextVarIntRawAsync(packetStream, cancellationToken); + if (uncompressedLength > 0) + { + await using ZLibStream zlibStream = new(packetStream, CompressionMode.Decompress, leaveOpen: true); + byte[] payload = GC.AllocateUninitializedArray(uncompressedLength); + await zlibStream.ReadExactlyAsync(payload.AsMemory(0, uncompressedLength), cancellationToken); + return payload; + } + } + + return await packetStream.ReadRemainingAsync(cancellationToken); + } + + private static int ReadNextVarIntRaw(Stream stream) + { + int value = 0; + int position = 0; + + while (true) + { + int current = stream.ReadByte(); + if (current < 0) + throw new IOException("Connection closed."); + + value |= (current & 0x7F) << position++ * 7; + if (position > 5) + throw new OverflowException("VarInt too big"); + if ((current & 0x80) != 0x80) + return value; + } + } + + private static async Task ReadNextVarIntRawAsync(Stream stream, CancellationToken cancellationToken) + { + byte[] buffer = new byte[1]; + int value = 0; + int position = 0; + + while (true) + { + await stream.ReadExactlyAsync(buffer.AsMemory(0, 1), cancellationToken); + byte current = buffer[0]; + + value |= (current & 0x7F) << position++ * 7; + if (position > 5) + throw new OverflowException("VarInt too big"); + if ((current & 0x80) != 0x80) + return value; + } + } } } diff --git a/MinecraftClient/Protocol/Message/ChatParser.cs b/MinecraftClient/Protocol/Message/ChatParser.cs index e0954ee2ca..b0dcbc8fd8 100644 --- a/MinecraftClient/Protocol/Message/ChatParser.cs +++ b/MinecraftClient/Protocol/Message/ChatParser.cs @@ -7,6 +7,7 @@ using System.Text; using System.Text.Json; using System.Text.RegularExpressions; +using System.Threading; using System.Threading.Tasks; using static MinecraftClient.Settings; @@ -231,6 +232,8 @@ private static string Color2tag(string colorname) /// Specify whether translation rules have been loaded /// private static bool RulesInitialized = false; + private static readonly Lock RulesInitializationLock = new(); + private static Task? RulesRefreshTask = null; /// /// Set of translation rules for formatting text @@ -243,23 +246,25 @@ private static string Color2tag(string colorname) /// public static void InitTranslations() { - if (!RulesInitialized) + lock (RulesInitializationLock) { - InitRules(); + if (RulesInitialized) + return; + RulesInitialized = true; + RulesRefreshTask = InitRulesAsync(); + _ = ObserveInitRulesAsync(RulesRefreshTask); } } /// - /// Internal rule initialization method. Looks for local rule file or download it from Mojang asset servers. + /// Internal rule initialization method. Looks for local rule file and refreshes it from Mojang asset servers if needed. /// - private static void InitRules() + private static async Task InitRulesAsync() { if (Config.Main.Advanced.Language == "en_us") { - TranslationRules = - JsonSerializer.Deserialize>( - (byte[])MinecraftAssets.ResourceManager.GetObject("en_us.json")!)!; + TranslationRules = LoadEmbeddedTranslationRules(); return; } @@ -269,21 +274,9 @@ private static void InitRules() string languageFilePath = "lang" + Path.DirectorySeparatorChar + Config.Main.Advanced.Language + ".json"; - // Load the external dictionary of translation rules or display an error message - if (File.Exists(languageFilePath)) - { - try - { - TranslationRules = - JsonSerializer.Deserialize>(File.OpenRead(languageFilePath))!; - } - catch (IOException) - { - } - catch (JsonException) - { - } - } + if (TryLoadTranslationRulesFromFile(languageFilePath, out Dictionary? translationRules)) + TranslationRules = translationRules; + else TranslationRules = LoadEmbeddedTranslationRules(); if (TranslationRules.TryGetValue("Version", out string? version) && version == Settings.TranslationsFile_Version) @@ -296,14 +289,12 @@ private static void InitRules() // Try downloading language file from Mojang's servers? ConsoleIO.WriteLineFormatted( "§8" + string.Format(Translations.chat_download, Config.Main.Advanced.Language)); - HttpClient httpClient = new(); + using HttpClient httpClient = new(); try { - Task fetch_index = httpClient.GetStringAsync(TranslationsFile_Website_Index); - fetch_index.Wait(); - Match match = Regex.Match(fetch_index.Result, + string fetchIndex = await httpClient.GetStringAsync(TranslationsFile_Website_Index); + Match match = Regex.Match(fetchIndex, $"minecraft/lang/{Config.Main.Advanced.Language}.json" + @""":\s\{""hash"":\s""([\d\w]{40})"""); - fetch_index.Dispose(); if (match.Success && match.Groups.Count == 2) { string hash = match.Groups[1].Value; @@ -312,22 +303,19 @@ private static void InitRules() ConsoleIO.WriteLineFormatted( string.Format(Translations.chat_request, translation_file_location)); - Task?> fetckFileTask = - httpClient.GetFromJsonAsync>(translation_file_location); - fetckFileTask.Wait(); - if (fetckFileTask.Result is not null && fetckFileTask.Result.Count > 0) + Dictionary? fetchedFile = + await httpClient.GetFromJsonAsync>(translation_file_location); + if (fetchedFile is not null && fetchedFile.Count > 0) { - TranslationRules = fetckFileTask.Result; + TranslationRules = fetchedFile; TranslationRules["Version"] = TranslationsFile_Version; - File.WriteAllText(languageFilePath, + await File.WriteAllTextAsync(languageFilePath, JsonSerializer.Serialize(TranslationRules, typeof(Dictionary)), Encoding.UTF8); ConsoleIO.WriteLineFormatted("§8" + string.Format(Translations.chat_done, languageFilePath)); return; } - - fetckFileTask.Dispose(); } else { @@ -350,15 +338,50 @@ private static void InitRules() if (Config.Logging.DebugMessages && !string.IsNullOrEmpty(e.StackTrace)) ConsoleIO.WriteLine(e.StackTrace); } - finally + TranslationRules = LoadEmbeddedTranslationRules(); + ConsoleIO.WriteLine(Translations.chat_use_default); + } + + private static async Task ObserveInitRulesAsync(Task initRulesTask) + { + try + { + await initRulesTask; + } + catch (Exception e) { - httpClient.Dispose(); + TranslationRules = LoadEmbeddedTranslationRules(); + if (Config.Logging.DebugMessages) + ConsoleIO.WriteLine(e.ToString()); } + } - TranslationRules = - JsonSerializer.Deserialize>( - (byte[])MinecraftAssets.ResourceManager.GetObject("en_us.json")!)!; - ConsoleIO.WriteLine(Translations.chat_use_default); + private static Dictionary LoadEmbeddedTranslationRules() + { + return JsonSerializer.Deserialize>( + (byte[])MinecraftAssets.ResourceManager.GetObject("en_us.json")!)!; + } + + private static bool TryLoadTranslationRulesFromFile(string languageFilePath, out Dictionary? translationRules) + { + translationRules = null; + if (!File.Exists(languageFilePath)) + return false; + + try + { + translationRules = + JsonSerializer.Deserialize>(File.OpenRead(languageFilePath))!; + return translationRules is not null; + } + catch (IOException) + { + return false; + } + catch (JsonException) + { + return false; + } } public static string? TranslateString(string rulename) @@ -379,10 +402,7 @@ private static void InitRules() private static string TranslateString(string rulename, List using_data) { if (!RulesInitialized) - { - InitRules(); - RulesInitialized = true; - } + InitTranslations(); if (TranslationRules.ContainsKey(rulename)) { @@ -617,4 +637,4 @@ private static string NbtToString(Dictionary nbt, string formatt return formatting + message + extraBuilder.ToString(); } } -} \ No newline at end of file +} diff --git a/MinecraftClient/Protocol/MicrosoftAuthentication.cs b/MinecraftClient/Protocol/MicrosoftAuthentication.cs index 4c84ce477e..1389fd8870 100644 --- a/MinecraftClient/Protocol/MicrosoftAuthentication.cs +++ b/MinecraftClient/Protocol/MicrosoftAuthentication.cs @@ -6,6 +6,7 @@ using System.Linq; using System.Runtime.InteropServices; using System.Threading; +using System.Threading.Tasks; namespace MinecraftClient.Protocol { @@ -37,7 +38,13 @@ public static LoginResponse RequestAccessToken(string code) { string postData = "client_id={0}&grant_type=authorization_code&redirect_uri=https%3A%2F%2Fmccteam.github.io%2Fredirect.html&code={1}"; postData = string.Format(postData, clientId, code); - return RequestToken(postData); + return RequestTokenAsync(postData).GetAwaiter().GetResult(); + } + + public static Task RequestAccessTokenAsync(string code) + { + string postData = "client_id={0}&grant_type=authorization_code&redirect_uri=https%3A%2F%2Fmccteam.github.io%2Fredirect.html&code={1}"; + return RequestTokenAsync(string.Format(postData, clientId, code)); } /// @@ -49,7 +56,13 @@ public static LoginResponse RefreshAccessToken(string refreshToken) { string postData = "client_id={0}&grant_type=refresh_token&redirect_uri=https%3A%2F%2Fmccteam.github.io%2Fredirect.html&refresh_token={1}"; postData = string.Format(postData, clientId, refreshToken); - return RequestToken(postData); + return RequestTokenAsync(postData).GetAwaiter().GetResult(); + } + + public static Task RefreshAccessTokenAsync(string refreshToken) + { + string postData = "client_id={0}&grant_type=refresh_token&redirect_uri=https%3A%2F%2Fmccteam.github.io%2Fredirect.html&refresh_token={1}"; + return RequestTokenAsync(string.Format(postData, clientId, refreshToken)); } /// @@ -58,6 +71,11 @@ public static LoginResponse RefreshAccessToken(string refreshToken) /// /// Device code response for user to complete authentication public static DeviceCodeResponse RequestDeviceCode() + { + return RequestDeviceCodeAsync().GetAwaiter().GetResult(); + } + + public static async Task RequestDeviceCodeAsync(CancellationToken cancellationToken = default) { string postData = string.Format("client_id={0}&scope=XboxLive.signin%20offline_access%20openid%20email", clientId); @@ -65,7 +83,7 @@ public static DeviceCodeResponse RequestDeviceCode() { UserAgent = "MCC/" + Program.Version }; - var response = request.Post("application/x-www-form-urlencoded", postData); + var response = await request.PostAsync("application/x-www-form-urlencoded", postData, cancellationToken); var jsonData = Json.ParseJson(response.Body); if (jsonData?["error"] is not null) @@ -93,6 +111,11 @@ public static DeviceCodeResponse RequestDeviceCode() /// Polling interval in seconds /// Login response with access token and refresh token public static LoginResponse PollDeviceCodeToken(string deviceCode, int expiresIn, int interval) + { + return PollDeviceCodeTokenAsync(deviceCode, expiresIn, interval).GetAwaiter().GetResult(); + } + + public static async Task PollDeviceCodeTokenAsync(string deviceCode, int expiresIn, int interval, CancellationToken cancellationToken = default) { // Per OAuth 2.0 device code spec, server may respond with "slow_down" requiring // the client to increase its polling interval by this amount @@ -107,13 +130,13 @@ public static LoginResponse PollDeviceCodeToken(string deviceCode, int expiresIn while (stopwatch.Elapsed.TotalSeconds < expiresIn) { - Thread.Sleep(pollInterval * 1000); + await Task.Delay(TimeSpan.FromSeconds(pollInterval), cancellationToken); var request = new ProxiedWebRequest(tokenUrl) { UserAgent = "MCC/" + Program.Version }; - var response = request.Post("application/x-www-form-urlencoded", postData); + var response = await request.PostAsync("application/x-www-form-urlencoded", postData, cancellationToken); var jsonData = Json.ParseJson(response.Body); if (jsonData?["error"] is not null) @@ -173,12 +196,17 @@ public static LoginResponse PollDeviceCodeToken(string deviceCode, int expiresIn /// Complete POST data for the request /// private static LoginResponse RequestToken(string postData) + { + return RequestTokenAsync(postData).GetAwaiter().GetResult(); + } + + private static async Task RequestTokenAsync(string postData, CancellationToken cancellationToken = default) { var request = new ProxiedWebRequest(tokenUrl) { UserAgent = "MCC/" + Program.Version }; - var response = request.Post("application/x-www-form-urlencoded", postData); + var response = await request.PostAsync("application/x-www-form-urlencoded", postData, cancellationToken); var jsonData = Json.ParseJson(response.Body); // Error handling @@ -271,6 +299,11 @@ static class XboxLive /// /// public static XblAuthenticateResponse XblAuthenticate(Microsoft.LoginResponse loginResponse) + { + return XblAuthenticateAsync(loginResponse).GetAwaiter().GetResult(); + } + + public static async Task XblAuthenticateAsync(Microsoft.LoginResponse loginResponse, CancellationToken cancellationToken = default) { var request = new ProxiedWebRequest(xbl) { @@ -291,7 +324,7 @@ public static XblAuthenticateResponse XblAuthenticate(Microsoft.LoginResponse lo + "\"RelyingParty\": \"http://auth.xboxlive.com\"," + "\"TokenType\": \"JWT\"" + "}"; - var response = request.Post("application/json", payload); + var response = await request.PostAsync("application/json", payload, cancellationToken); if (Settings.Config.Logging.DebugMessages) { ConsoleIO.WriteLine(response.ToString()); @@ -321,6 +354,11 @@ public static XblAuthenticateResponse XblAuthenticate(Microsoft.LoginResponse lo /// /// public static XSTSAuthenticateResponse XSTSAuthenticate(XblAuthenticateResponse xblResponse) + { + return XSTSAuthenticateAsync(xblResponse).GetAwaiter().GetResult(); + } + + public static async Task XSTSAuthenticateAsync(XblAuthenticateResponse xblResponse, CancellationToken cancellationToken = default) { var request = new ProxiedWebRequest(xsts) { @@ -339,7 +377,7 @@ public static XSTSAuthenticateResponse XSTSAuthenticate(XblAuthenticateResponse + "\"RelyingParty\": \"rp://api.minecraftservices.com/\"," + "\"TokenType\": \"JWT\"" + "}"; - var response = request.Post("application/json", payload); + var response = await request.PostAsync("application/json", payload, cancellationToken); if (Settings.Config.Logging.DebugMessages) { ConsoleIO.WriteLine(response.ToString()); @@ -404,6 +442,11 @@ static class MinecraftWithXbox /// /// public static string LoginWithXbox(string userHash, string xstsToken) + { + return LoginWithXboxAsync(userHash, xstsToken).GetAwaiter().GetResult(); + } + + public static async Task LoginWithXboxAsync(string userHash, string xstsToken, CancellationToken cancellationToken = default) { var request = new ProxiedWebRequest(loginWithXbox) { @@ -411,7 +454,7 @@ public static string LoginWithXbox(string userHash, string xstsToken) }; string payload = "{\"identityToken\": \"XBL3.0 x=" + userHash + ";" + xstsToken + "\"}"; - var response = request.Post("application/json", payload); + var response = await request.PostAsync("application/json", payload, cancellationToken); if (Settings.Config.Logging.DebugMessages) { @@ -430,10 +473,15 @@ public static string LoginWithXbox(string userHash, string xstsToken) /// /// True if the user own the game public static bool UserHasGame(string accessToken) + { + return UserHasGameAsync(accessToken).GetAwaiter().GetResult(); + } + + public static async Task UserHasGameAsync(string accessToken, CancellationToken cancellationToken = default) { var request = new ProxiedWebRequest(ownership); request.Headers.Add("Authorization", string.Format("Bearer {0}", accessToken)); - var response = request.Get(); + var response = await request.GetAsync(cancellationToken); if (Settings.Config.Logging.DebugMessages) { @@ -446,10 +494,15 @@ public static bool UserHasGame(string accessToken) } public static UserProfile GetUserProfile(string accessToken) + { + return GetUserProfileAsync(accessToken).GetAwaiter().GetResult(); + } + + public static async Task GetUserProfileAsync(string accessToken, CancellationToken cancellationToken = default) { var request = new ProxiedWebRequest(profile); request.Headers.Add("Authorization", string.Format("Bearer {0}", accessToken)); - var response = request.Get(); + var response = await request.GetAsync(cancellationToken); if (Settings.Config.Logging.DebugMessages) { diff --git a/MinecraftClient/Protocol/MojangAPI.cs b/MinecraftClient/Protocol/MojangAPI.cs index b05e6a7144..9e91c9ddbc 100644 --- a/MinecraftClient/Protocol/MojangAPI.cs +++ b/MinecraftClient/Protocol/MojangAPI.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Net.Http; +using System.Threading; using System.Threading.Tasks; /// !!! ATTENTION !!! @@ -103,14 +104,15 @@ private static ServiceStatus StringToServiceStatus(string s) /// /// Playername /// UUID as string - public static string NameToUuid(string name) + public static string NameToUuid(string name) => + NameToUuidAsync(name).GetAwaiter().GetResult(); + + public static async Task NameToUuidAsync(string name, CancellationToken cancellationToken = default) { try { - Task fetchTask = httpClient.GetStringAsync("https://api.mojang.com/users/profiles/minecraft/" + name); - fetchTask.Wait(); - string result = Json.ParseJson(fetchTask.Result)!["id"]!.GetStringValue(); - fetchTask.Dispose(); + string responseBody = await httpClient.GetStringAsync("https://api.mojang.com/users/profiles/minecraft/" + name, cancellationToken); + string result = Json.ParseJson(responseBody)!["id"]!.GetStringValue(); return result; } catch (Exception) { return string.Empty; } @@ -121,15 +123,16 @@ public static string NameToUuid(string name) /// /// UUID of a player /// Players UUID - public static string UuidToCurrentName(string uuid) + public static string UuidToCurrentName(string uuid) => + UuidToCurrentNameAsync(uuid).GetAwaiter().GetResult(); + + public static async Task UuidToCurrentNameAsync(string uuid, CancellationToken cancellationToken = default) { // Perform web request try { - Task fetchTask = httpClient.GetStringAsync("https://api.mojang.com/user/profiles/" + uuid + "/names"); - fetchTask.Wait(); - var nameChanges = Json.ParseJson(fetchTask.Result)!.AsArray(); - fetchTask.Dispose(); + string responseBody = await httpClient.GetStringAsync("https://api.mojang.com/user/profiles/" + uuid + "/names", cancellationToken); + var nameChanges = Json.ParseJson(responseBody)!.AsArray(); // Names are sorted from past to most recent. We need to get the last name in the list return nameChanges[^1]!["name"]!.GetStringValue(); @@ -142,7 +145,10 @@ public static string UuidToCurrentName(string uuid) /// /// UUID of a player /// Name history, as a dictionary - public static Dictionary UuidToNameHistory(string uuid) + public static Dictionary UuidToNameHistory(string uuid) => + UuidToNameHistoryAsync(uuid).GetAwaiter().GetResult(); + + public static async Task> UuidToNameHistoryAsync(string uuid, CancellationToken cancellationToken = default) { Dictionary tempDict = new(); System.Text.Json.Nodes.JsonArray jsonDataList; @@ -150,10 +156,8 @@ public static Dictionary UuidToNameHistory(string uuid) // Perform web request try { - Task fetchTask = httpClient.GetStringAsync("https://api.mojang.com/user/profiles/" + uuid + "/names"); - fetchTask.Wait(); - jsonDataList = Json.ParseJson(fetchTask.Result)!.AsArray(); - fetchTask.Dispose(); + string responseBody = await httpClient.GetStringAsync("https://api.mojang.com/user/profiles/" + uuid + "/names", cancellationToken); + jsonDataList = Json.ParseJson(responseBody)!.AsArray(); } catch (Exception) { return tempDict; } @@ -181,17 +185,18 @@ public static Dictionary UuidToNameHistory(string uuid) /// Get the Mojang API status /// /// Dictionary of the Mojang services - public static MojangServiceStatus GetMojangServiceStatus() + public static MojangServiceStatus GetMojangServiceStatus() => + GetMojangServiceStatusAsync().GetAwaiter().GetResult(); + + public static async Task GetMojangServiceStatusAsync(CancellationToken cancellationToken = default) { System.Text.Json.Nodes.JsonArray jsonDataList; // Perform web request try { - Task fetchTask = httpClient.GetStringAsync("https://status.mojang.com/check"); - fetchTask.Wait(); - jsonDataList = Json.ParseJson(fetchTask.Result)!.AsArray(); - fetchTask.Dispose(); + string responseBody = await httpClient.GetStringAsync("https://status.mojang.com/check", cancellationToken); + jsonDataList = Json.ParseJson(responseBody)!.AsArray(); } catch (Exception) { @@ -215,7 +220,10 @@ public static MojangServiceStatus GetMojangServiceStatus() /// /// UUID of a player /// Dictionary with a link to the skin and cape of a player. - public static SkinInfo GetSkinInfo(string uuid) + public static SkinInfo GetSkinInfo(string uuid) => + GetSkinInfoAsync(uuid).GetAwaiter().GetResult(); + + public static async Task GetSkinInfoAsync(string uuid, CancellationToken cancellationToken = default) { System.Text.Json.Nodes.JsonObject textureObj; string base64SkinInfo; @@ -224,11 +232,9 @@ public static SkinInfo GetSkinInfo(string uuid) // Perform web request try { - Task fetchTask = httpClient.GetStringAsync("https://sessionserver.mojang.com/session/minecraft/profile/" + uuid); - fetchTask.Wait(); + string responseBody = await httpClient.GetStringAsync("https://sessionserver.mojang.com/session/minecraft/profile/" + uuid, cancellationToken); // Obtain the Base64 encoded skin information from the API. Discard the rest, since it can be obtained easier through other requests. - base64SkinInfo = Json.ParseJson(fetchTask.Result)!["properties"]![0]!["value"]!.GetStringValue(); - fetchTask.Dispose(); + base64SkinInfo = Json.ParseJson(responseBody)!["properties"]![0]!["value"]!.GetStringValue(); } catch (Exception) { return new SkinInfo(); } diff --git a/MinecraftClient/Protocol/PacketPipeline/IncomingPacket.cs b/MinecraftClient/Protocol/PacketPipeline/IncomingPacket.cs new file mode 100644 index 0000000000..8d5ae34079 --- /dev/null +++ b/MinecraftClient/Protocol/PacketPipeline/IncomingPacket.cs @@ -0,0 +1,13 @@ +using System; + +namespace MinecraftClient.Protocol.PacketPipeline; + +internal readonly record struct IncomingPacket(int PacketId, byte[] Payload) +{ + public PacketReader CreateReader() + { + return new PacketReader(Payload); + } + + public ReadOnlySpan PayloadSpan => Payload; +} diff --git a/MinecraftClient/Protocol/PacketPipeline/PacketReadStream.cs b/MinecraftClient/Protocol/PacketPipeline/PacketReadStream.cs new file mode 100644 index 0000000000..167eb03b77 --- /dev/null +++ b/MinecraftClient/Protocol/PacketPipeline/PacketReadStream.cs @@ -0,0 +1,203 @@ +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace MinecraftClient.Protocol.PacketPipeline; + +internal sealed class PacketReadStream : Stream +{ + private const int DrainBufferSize = 4096; + + private readonly Stream baseStream; + private readonly byte[] singleByteBuffer = new byte[1]; + private int remainingLength; + + public PacketReadStream(Stream baseStream, int packetLength) + { + ArgumentNullException.ThrowIfNull(baseStream); + if (packetLength < 0) + throw new ArgumentOutOfRangeException(nameof(packetLength)); + + this.baseStream = baseStream; + remainingLength = packetLength; + } + + public int RemainingLength => remainingLength; + + public override bool CanRead => true; + public override bool CanSeek => false; + public override bool CanWrite => false; + public override long Length => throw new NotSupportedException(); + public override long Position + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } + + public override int Read(byte[] buffer, int offset, int count) + { + if (remainingLength == 0) + return 0; + + int readLength = Math.Min(count, remainingLength); + int read = baseStream.Read(buffer, offset, readLength); + remainingLength -= read; + return read; + } + + public override int Read(Span buffer) + { + if (remainingLength == 0) + return 0; + + int readLength = Math.Min(buffer.Length, remainingLength); + int read = baseStream.Read(buffer[..readLength]); + remainingLength -= read; + return read; + } + + public override int ReadByte() + { + if (remainingLength == 0) + return -1; + + int value = baseStream.ReadByte(); + if (value == -1) + return -1; + + remainingLength--; + return value; + } + + public override async ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + if (remainingLength == 0) + return 0; + + int readLength = Math.Min(buffer.Length, remainingLength); + int read = await baseStream.ReadAsync(buffer[..readLength], cancellationToken); + remainingLength -= read; + return read; + } + + public new async ValueTask ReadExactlyAsync(Memory buffer, CancellationToken cancellationToken = default) + { + if (buffer.Length > remainingLength) + throw new OverflowException("Reached the end of the packet."); + + await baseStream.ReadExactlyAsync(buffer, cancellationToken); + remainingLength -= buffer.Length; + } + + public new void ReadExactly(Span buffer) + { + if (buffer.Length > remainingLength) + throw new OverflowException("Reached the end of the packet."); + + baseStream.ReadExactly(buffer); + remainingLength -= buffer.Length; + } + + public byte[] ReadRemaining() + { + if (remainingLength == 0) + return []; + + byte[] buffer = GC.AllocateUninitializedArray(remainingLength); + ReadExactly(buffer); + return buffer; + } + + public async Task ReadRemainingAsync(CancellationToken cancellationToken = default) + { + if (remainingLength == 0) + return []; + + byte[] buffer = GC.AllocateUninitializedArray(remainingLength); + await ReadExactlyAsync(buffer, cancellationToken); + return buffer; + } + + public void DrainRemaining() + { + if (remainingLength == 0) + return; + + byte[] buffer = GC.AllocateUninitializedArray(Math.Min(DrainBufferSize, remainingLength)); + while (remainingLength > 0) + { + int read = baseStream.Read(buffer, 0, Math.Min(buffer.Length, remainingLength)); + if (read <= 0) + throw new EndOfStreamException("Connection closed while draining packet data."); + + remainingLength -= read; + } + } + + public async ValueTask DrainRemainingAsync(CancellationToken cancellationToken = default) + { + if (remainingLength == 0) + return; + + byte[] buffer = GC.AllocateUninitializedArray(Math.Min(DrainBufferSize, remainingLength)); + while (remainingLength > 0) + { + int read = await baseStream.ReadAsync(buffer.AsMemory(0, Math.Min(buffer.Length, remainingLength)), cancellationToken); + if (read <= 0) + throw new EndOfStreamException("Connection closed while draining packet data."); + + remainingLength -= read; + } + } + + public override void Flush() + { + throw new NotSupportedException(); + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException(); + } + + public override void SetLength(long value) + { + throw new NotSupportedException(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + throw new NotSupportedException(); + } + + protected override void Dispose(bool disposing) + { + if (disposing && remainingLength > 0) + { + try + { + DrainRemaining(); + } + catch (IOException) { } + catch (ObjectDisposedException) { } + } + + base.Dispose(disposing); + } + + public override async ValueTask DisposeAsync() + { + if (remainingLength > 0) + { + try + { + await DrainRemainingAsync(); + } + catch (IOException) { } + catch (ObjectDisposedException) { } + } + + await base.DisposeAsync(); + } +} diff --git a/MinecraftClient/Protocol/PacketPipeline/PacketReader.cs b/MinecraftClient/Protocol/PacketPipeline/PacketReader.cs new file mode 100644 index 0000000000..1f7fd6ea68 --- /dev/null +++ b/MinecraftClient/Protocol/PacketPipeline/PacketReader.cs @@ -0,0 +1,123 @@ +using System; +using System.Buffers.Binary; + +namespace MinecraftClient.Protocol.PacketPipeline; + +public sealed class PacketReader +{ + private readonly byte[] buffer; + private int position; + + public PacketReader(byte[] buffer) + { + ArgumentNullException.ThrowIfNull(buffer); + this.buffer = buffer; + } + + public int Position => position; + public int RemainingLength => buffer.Length - position; + public ReadOnlySpan RemainingSpan => buffer.AsSpan(position); + public ReadOnlySpan FullSpan => buffer; + + public byte[] GetRawData() => buffer; + + public byte ReadByte() + { + EnsureAvailable(1); + return buffer[position++]; + } + + public byte[] ReadData(int length) + { + if (length == 0) + return []; + + EnsureAvailable(length); + byte[] data = GC.AllocateUninitializedArray(length); + Buffer.BlockCopy(buffer, position, data, 0, length); + position += length; + return data; + } + + public void ReadData(Span destination) + { + if (destination.Length == 0) + return; + + EnsureAvailable(destination.Length); + buffer.AsSpan(position, destination.Length).CopyTo(destination); + position += destination.Length; + } + + public void ReadDataReverse(Span destination) + { + if (destination.Length == 0) + return; + + EnsureAvailable(destination.Length); + for (int i = destination.Length - 1; i >= 0; --i) + destination[i] = buffer[position++]; + } + + public void Skip(int length) + { + if (length < 0) + throw new ArgumentOutOfRangeException(nameof(length)); + + EnsureAvailable(length); + position += length; + } + + public ushort ReadUInt16BigEndian() + { + EnsureAvailable(sizeof(ushort)); + ushort value = BinaryPrimitives.ReadUInt16BigEndian(buffer.AsSpan(position, sizeof(ushort))); + position += sizeof(ushort); + return value; + } + + public short ReadInt16BigEndian() + { + EnsureAvailable(sizeof(short)); + short value = BinaryPrimitives.ReadInt16BigEndian(buffer.AsSpan(position, sizeof(short))); + position += sizeof(short); + return value; + } + + public int ReadInt32BigEndian() + { + EnsureAvailable(sizeof(int)); + int value = BinaryPrimitives.ReadInt32BigEndian(buffer.AsSpan(position, sizeof(int))); + position += sizeof(int); + return value; + } + + public long ReadInt64BigEndian() + { + EnsureAvailable(sizeof(long)); + long value = BinaryPrimitives.ReadInt64BigEndian(buffer.AsSpan(position, sizeof(long))); + position += sizeof(long); + return value; + } + + public byte[] CopyRemaining() + { + return ReadOnlySpanToArray(RemainingSpan); + } + + private void EnsureAvailable(int length) + { + if (RemainingLength < length) + throw new OverflowException("Reached the end of the packet."); + } + + private static byte[] ReadOnlySpanToArray(ReadOnlySpan span) + { + if (span.IsEmpty) + return []; + + byte[] copy = GC.AllocateUninitializedArray(span.Length); + span.CopyTo(copy); + return copy; + } +} diff --git a/MinecraftClient/Protocol/ProtocolHandler.cs b/MinecraftClient/Protocol/ProtocolHandler.cs index a8568f3863..1a9c5b2599 100644 --- a/MinecraftClient/Protocol/ProtocolHandler.cs +++ b/MinecraftClient/Protocol/ProtocolHandler.cs @@ -7,6 +7,8 @@ using System.Net.Sockets; using System.Text; using System.Text.RegularExpressions; +using System.Threading; +using System.Threading.Tasks; using DnsClient; using MinecraftClient.Protocol.Handlers; using MinecraftClient.Protocol.Handlers.Forge; @@ -91,11 +93,30 @@ public static bool MinecraftServiceLookup(ref string domain, ref ushort port) /// TRUE if ping was successful public static bool GetServerInfo(string serverIP, ushort serverPort, ref int protocolversion, ref ForgeInfo? forgeInfo) + { + (bool success, int resolvedProtocolVersion, ForgeInfo? resolvedForgeInfo) = + GetServerInfoAsync(serverIP, serverPort, protocolversion).GetAwaiter().GetResult(); + + if (!success) + return false; + + if (protocolversion != 0 && protocolversion != resolvedProtocolVersion) + ConsoleIO.WriteLineFormatted("§8" + Translations.error_version_different, acceptnewlines: true); + if (protocolversion == 0 && resolvedProtocolVersion <= 1) + ConsoleIO.WriteLineFormatted("§8" + Translations.error_no_version_report, acceptnewlines: true); + if (protocolversion == 0) + protocolversion = resolvedProtocolVersion; + + forgeInfo = resolvedForgeInfo; + return true; + } + + public static async Task<(bool Success, int ProtocolVersion, ForgeInfo? ForgeInfo)> GetServerInfoAsync(string serverIP, ushort serverPort, int protocolversion) { bool success = false; int protocolversionTmp = 0; ForgeInfo? forgeInfoTmp = null; - if (AutoTimeout.Perform(() => + if (await AutoTimeout.PerformAsync(() => { try { @@ -118,19 +139,15 @@ public static bool GetServerInfo(string serverIP, ushort serverPort, ref int pro ? 10 : 30))) { - if (protocolversion != 0 && protocolversion != protocolversionTmp) - ConsoleIO.WriteLineFormatted("§8" + Translations.error_version_different, acceptnewlines: true); - if (protocolversion == 0 && protocolversionTmp <= 1) - ConsoleIO.WriteLineFormatted("§8" + Translations.error_no_version_report, acceptnewlines: true); if (protocolversion == 0) protocolversion = protocolversionTmp; - forgeInfo = forgeInfoTmp; - return success; + + return (success, protocolversion, forgeInfoTmp); } else { ConsoleIO.WriteLineFormatted("§8" + Translations.error_connection_timeout, acceptnewlines: true); - return false; + return (false, protocolversion, forgeInfoTmp); } } @@ -590,24 +607,23 @@ public enum AccountType /// Returns the status of the login (Success, Failure, etc.) public static LoginResult GetLogin(string user, string pass, LoginType type, out SessionToken session) { - if (type == LoginType.microsoft) - { - if (Config.Main.General.Method == LoginMethod.mcc) - return MicrosoftMCCLogin(user, pass, out session); - else - return MicrosoftBrowserLogin(out session, user); - } - else if (type == LoginType.mojang) - { - return MojangLogin(user, pass, out session); - } - else if (type == LoginType.yggdrasil) + var login = GetLoginAsync(user, pass, type).GetAwaiter().GetResult(); + session = login.Session; + return login.Result; + } + + public static Task<(LoginResult Result, SessionToken Session)> GetLoginAsync(string user, string pass, LoginType type, CancellationToken cancellationToken = default) + { + return type switch { - return YggdrasiLogin(user, pass, out session); - } - else - throw new InvalidOperationException( - "Account type must be Mojang or Microsoft or valid authlib 3rd Servers!"); + LoginType.microsoft => Config.Main.General.Method == LoginMethod.mcc + ? MicrosoftMCCLoginAsync(user, pass, cancellationToken) + : MicrosoftBrowserLoginAsync(user, cancellationToken), + LoginType.mojang => MojangLoginAsync(user, pass, cancellationToken), + LoginType.yggdrasil => YggdrasiLoginAsync(user, pass, cancellationToken), + _ => throw new InvalidOperationException( + "Account type must be Mojang or Microsoft or valid authlib 3rd Servers!") + }; } /// @@ -619,20 +635,29 @@ public static LoginResult GetLogin(string user, string pass, LoginType type, out /// private static LoginResult MojangLogin(string user, string pass, out SessionToken session) { - session = new SessionToken() { ClientID = Guid.NewGuid().ToString().Replace("-", "") }; + var login = MojangLoginAsync(user, pass).GetAwaiter().GetResult(); + session = login.Session; + return login.Result; + } + + private static async Task<(LoginResult Result, SessionToken Session)> MojangLoginAsync(string user, string pass, CancellationToken cancellationToken = default) + { + SessionToken session = new() { ClientID = Guid.NewGuid().ToString().Replace("-", "") }; try { - string result = ""; + string result; string json_request = "{\"agent\": { \"name\": \"Minecraft\", \"version\": 1 }, \"username\": \"" + JsonEncode(user) + "\", \"password\": \"" + JsonEncode(pass) + "\", \"clientToken\": \"" + JsonEncode(session.ClientID) + "\" }"; - int code = DoHTTPSPost("authserver.mojang.com", 443, "/authenticate", json_request, ref result); + var response = await DoHTTPSPostAsync("authserver.mojang.com", 443, "/authenticate", json_request, cancellationToken); + int code = response.StatusCode; + result = response.Result; if (code == 200) { if (result.Contains("availableProfiles\":[]}")) { - return LoginResult.NotPremium; + return (LoginResult.NotPremium, session); } else { @@ -645,27 +670,27 @@ private static LoginResult MojangLogin(string user, string pass, out SessionToke session.PlayerID = loginResponse["selectedProfile"]!["id"]!.GetStringValue(); session.PlayerName = loginResponse["selectedProfile"]!["name"]! .GetStringValue(); - return LoginResult.Success; + return (LoginResult.Success, session); } - else return LoginResult.InvalidResponse; + else return (LoginResult.InvalidResponse, session); } } else if (code == 403) { if (result.Contains("UserMigratedException")) { - return LoginResult.AccountMigrated; + return (LoginResult.AccountMigrated, session); } - else return LoginResult.WrongPassword; + else return (LoginResult.WrongPassword, session); } else if (code == 503) { - return LoginResult.ServiceUnavailable; + return (LoginResult.ServiceUnavailable, session); } else { ConsoleIO.WriteLineFormatted("§8" + string.Format(Translations.error_http_code, code)); - return LoginResult.OtherError; + return (LoginResult.OtherError, session); } } catch (System.Security.Authentication.AuthenticationException e) @@ -675,7 +700,7 @@ private static LoginResult MojangLogin(string user, string pass, out SessionToke ConsoleIO.WriteLineFormatted("§8" + e.ToString()); } - return LoginResult.SSLError; + return (LoginResult.SSLError, session); } catch (System.IO.IOException e) { @@ -686,9 +711,9 @@ private static LoginResult MojangLogin(string user, string pass, out SessionToke if (e.Message.Contains("authentication")) { - return LoginResult.SSLError; + return (LoginResult.SSLError, session); } - else return LoginResult.OtherError; + else return (LoginResult.OtherError, session); } catch (Exception e) { @@ -697,28 +722,41 @@ private static LoginResult MojangLogin(string user, string pass, out SessionToke ConsoleIO.WriteLineFormatted("§8" + e.ToString()); } - return LoginResult.OtherError; + return (LoginResult.OtherError, session); } } private static LoginResult YggdrasiLogin(string user, string pass, out SessionToken session) { - session = new SessionToken() { ClientID = Guid.NewGuid().ToString().Replace("-", "") }; + var login = YggdrasiLoginAsync(user, pass).GetAwaiter().GetResult(); + session = login.Session; + return login.Result; + } + + private static async Task<(LoginResult Result, SessionToken Session)> YggdrasiLoginAsync(string user, string pass, CancellationToken cancellationToken = default) + { + SessionToken session = new() { ClientID = Guid.NewGuid().ToString().Replace("-", "") }; try { - string result = ""; + string result; string json_request = "{\"agent\": { \"name\": \"Minecraft\", \"version\": 1 }, \"username\": \"" + JsonEncode(user) + "\", \"password\": \"" + JsonEncode(pass) + "\", \"clientToken\": \"" + JsonEncode(session.ClientID) + "\" }"; - int code = DoHTTPSPost(Config.Main.General.AuthServer.Host, Config.Main.General.AuthServer.Port, - Config.Main.General.AuthServer.AuthlibInjectorAPIPath + "/authserver/authenticate", json_request, - Config.Main.General.AuthServer.UseHttps, ref result); + var response = await DoHTTPSPostAsync( + Config.Main.General.AuthServer.Host, + Config.Main.General.AuthServer.Port, + Config.Main.General.AuthServer.AuthlibInjectorAPIPath + "/authserver/authenticate", + json_request, + Config.Main.General.AuthServer.UseHttps, + cancellationToken); + int code = response.StatusCode; + result = response.Result; if (code == 200) { if (result.Contains("availableProfiles\":[]}")) { - return LoginResult.NotPremium; + return (LoginResult.NotPremium, session); } else { @@ -733,7 +771,7 @@ private static LoginResult YggdrasiLogin(string user, string pass, out SessionTo .GetStringValue(); session.PlayerName = loginResponse["selectedProfile"]!["name"]! .GetStringValue(); - return LoginResult.Success; + return (LoginResult.Success, session); } else { @@ -769,33 +807,33 @@ private static LoginResult YggdrasiLogin(string user, string pass, out SessionTo session.PlayerID = selectedProfile["id"]!.GetStringValue(); session.PlayerName = selectedProfile["name"]!.GetStringValue(); SessionToken currentsession = session; - return GetNewYggdrasilToken(currentsession, out session); + return await GetNewYggdrasilTokenAsync(currentsession, cancellationToken); } else { - return LoginResult.WrongSelection; + return (LoginResult.WrongSelection, session); } } } - else return LoginResult.InvalidResponse; + else return (LoginResult.InvalidResponse, session); } } else if (code == 403) { if (result.Contains("UserMigratedException")) { - return LoginResult.AccountMigrated; + return (LoginResult.AccountMigrated, session); } - else return LoginResult.WrongPassword; + else return (LoginResult.WrongPassword, session); } else if (code == 503) { - return LoginResult.ServiceUnavailable; + return (LoginResult.ServiceUnavailable, session); } else { ConsoleIO.WriteLineFormatted("§8" + string.Format(Translations.error_http_code, code)); - return LoginResult.OtherError; + return (LoginResult.OtherError, session); } } catch (System.Security.Authentication.AuthenticationException e) @@ -805,7 +843,7 @@ private static LoginResult YggdrasiLogin(string user, string pass, out SessionTo ConsoleIO.WriteLineFormatted("§8" + e.ToString()); } - return LoginResult.SSLError; + return (LoginResult.SSLError, session); } catch (System.IO.IOException e) { @@ -816,9 +854,9 @@ private static LoginResult YggdrasiLogin(string user, string pass, out SessionTo if (e.Message.Contains("authentication")) { - return LoginResult.SSLError; + return (LoginResult.SSLError, session); } - else return LoginResult.OtherError; + else return (LoginResult.OtherError, session); } catch (Exception e) { @@ -827,7 +865,7 @@ private static LoginResult YggdrasiLogin(string user, string pass, out SessionTo ConsoleIO.WriteLineFormatted("§8" + e.ToString()); } - return LoginResult.OtherError; + return (LoginResult.OtherError, session); } } @@ -840,10 +878,17 @@ private static LoginResult YggdrasiLogin(string user, string pass, out SessionTo /// /// private static LoginResult MicrosoftMCCLogin(string email, string password, out SessionToken session) + { + var login = MicrosoftMCCLoginAsync(email, password).GetAwaiter().GetResult(); + session = login.Session; + return login.Result; + } + + private static async Task<(LoginResult Result, SessionToken Session)> MicrosoftMCCLoginAsync(string email, string password, CancellationToken cancellationToken = default) { try { - var deviceCode = Microsoft.RequestDeviceCode(); + var deviceCode = await Microsoft.RequestDeviceCodeAsync(cancellationToken); ConsoleIO.WriteLineFormatted(string.Format(Translations.mcc_device_code_prompt, deviceCode.VerificationUri, deviceCode.UserCode)); @@ -852,19 +897,19 @@ private static LoginResult MicrosoftMCCLogin(string email, string password, out ConsoleIO.WriteLineFormatted(Translations.mcc_device_code_waiting); - var msaResponse = Microsoft.PollDeviceCodeToken(deviceCode.DeviceCode, deviceCode.ExpiresIn, deviceCode.Interval); - return MicrosoftLogin(msaResponse, out session); + var msaResponse = await Microsoft.PollDeviceCodeTokenAsync(deviceCode.DeviceCode, deviceCode.ExpiresIn, deviceCode.Interval, cancellationToken); + return await MicrosoftLoginAsync(msaResponse, cancellationToken); } catch (Exception e) { - session = new SessionToken() { ClientID = Guid.NewGuid().ToString().Replace("-", "") }; + SessionToken session = new() { ClientID = Guid.NewGuid().ToString().Replace("-", "") }; ConsoleIO.WriteLineFormatted("§cMicrosoft authenticate failed: " + e.Message); if (Settings.Config.Logging.DebugMessages) { ConsoleIO.WriteLineFormatted("§c" + e.StackTrace); } - return LoginResult.OtherError; + return (LoginResult.OtherError, session); } } @@ -879,6 +924,13 @@ private static LoginResult MicrosoftMCCLogin(string email, string password, out /// /// public static LoginResult MicrosoftBrowserLogin(out SessionToken session, string loginHint = "") + { + var login = MicrosoftBrowserLoginAsync(loginHint).GetAwaiter().GetResult(); + session = login.Session; + return login.Result; + } + + public static async Task<(LoginResult Result, SessionToken Session)> MicrosoftBrowserLoginAsync(string loginHint = "", CancellationToken cancellationToken = default) { if (string.IsNullOrEmpty(loginHint)) Microsoft.OpenBrowser(Microsoft.SignInUrl); @@ -891,40 +943,54 @@ public static LoginResult MicrosoftBrowserLogin(out SessionToken session, string string code = ConsoleIO.ReadLine(); ConsoleIO.WriteLine(string.Format(Translations.mcc_connecting, "Microsoft")); - var msaResponse = Microsoft.RequestAccessToken(code); - return MicrosoftLogin(msaResponse, out session); + var msaResponse = await Microsoft.RequestAccessTokenAsync(code); + return await MicrosoftLoginAsync(msaResponse, cancellationToken); } public static LoginResult MicrosoftLoginRefresh(string refreshToken, out SessionToken session) { - var msaResponse = Microsoft.RefreshAccessToken(refreshToken); - return MicrosoftLogin(msaResponse, out session); + var login = MicrosoftLoginRefreshAsync(refreshToken).GetAwaiter().GetResult(); + session = login.Session; + return login.Result; + } + + public static async Task<(LoginResult Result, SessionToken Session)> MicrosoftLoginRefreshAsync(string refreshToken, CancellationToken cancellationToken = default) + { + var msaResponse = await Microsoft.RefreshAccessTokenAsync(refreshToken); + return await MicrosoftLoginAsync(msaResponse, cancellationToken); } private static LoginResult MicrosoftLogin(Microsoft.LoginResponse msaResponse, out SessionToken session) { - session = new SessionToken() { ClientID = Guid.NewGuid().ToString().Replace("-", "") }; + var login = MicrosoftLoginAsync(msaResponse).GetAwaiter().GetResult(); + session = login.Session; + return login.Result; + } + + private static async Task<(LoginResult Result, SessionToken Session)> MicrosoftLoginAsync(Microsoft.LoginResponse msaResponse, CancellationToken cancellationToken = default) + { + SessionToken session = new() { ClientID = Guid.NewGuid().ToString().Replace("-", "") }; try { - var xblResponse = XboxLive.XblAuthenticate(msaResponse); - var xsts = XboxLive.XSTSAuthenticate(xblResponse); // Might throw even password correct + var xblResponse = await XboxLive.XblAuthenticateAsync(msaResponse, cancellationToken); + var xsts = await XboxLive.XSTSAuthenticateAsync(xblResponse, cancellationToken); // Might throw even password correct - string accessToken = MinecraftWithXbox.LoginWithXbox(xsts.UserHash, xsts.Token); - bool hasGame = MinecraftWithXbox.UserHasGame(accessToken); + string accessToken = await MinecraftWithXbox.LoginWithXboxAsync(xsts.UserHash, xsts.Token, cancellationToken); + bool hasGame = await MinecraftWithXbox.UserHasGameAsync(accessToken, cancellationToken); if (hasGame) { - var profile = MinecraftWithXbox.GetUserProfile(accessToken); + var profile = await MinecraftWithXbox.GetUserProfileAsync(accessToken, cancellationToken); session.PlayerName = profile.UserName; session.PlayerID = profile.UUID; session.ID = accessToken; session.RefreshToken = msaResponse.RefreshToken; InternalConfig.Account.Login = msaResponse.Email; - return LoginResult.Success; + return (LoginResult.Success, session); } else { - return LoginResult.NotPremium; + return (LoginResult.NotPremium, session); } } catch (Exception e) @@ -935,7 +1001,7 @@ private static LoginResult MicrosoftLogin(Microsoft.LoginResponse msaResponse, o ConsoleIO.WriteLineFormatted("§c" + e.StackTrace); } - return LoginResult.WrongPassword; // Might not always be wrong password + return (LoginResult.WrongPassword, session); // Might not always be wrong password } } @@ -1075,6 +1141,52 @@ public static LoginResult GetNewYggdrasilToken(SessionToken currentsession, out } } + public static async Task<(LoginResult Result, SessionToken Session)> GetNewYggdrasilTokenAsync(SessionToken currentsession, CancellationToken cancellationToken = default) + { + SessionToken session = new(); + try + { + string json_request = "{ \"accessToken\": \"" + JsonEncode(currentsession.ID) + + "\", \"clientToken\": \"" + JsonEncode(currentsession.ClientID) + + "\", \"selectedProfile\": { \"id\": \"" + JsonEncode(currentsession.PlayerID) + + "\", \"name\": \"" + JsonEncode(currentsession.PlayerName) + "\" } }"; + var response = await DoHTTPSPostAsync( + Config.Main.General.AuthServer.Host, + Config.Main.General.AuthServer.Port, + Config.Main.General.AuthServer.AuthlibInjectorAPIPath + "/authserver/refresh", + json_request, + Config.Main.General.AuthServer.UseHttps, + cancellationToken); + string result = response.Result; + int code = response.StatusCode; + if (code == 200) + { + var loginResponse = Json.ParseJson(result); + if (loginResponse?["accessToken"] is not null + && loginResponse["selectedProfile"]?["id"] is not null + && loginResponse["selectedProfile"]?["name"] is not null) + { + session.ID = loginResponse["accessToken"]!.GetStringValue(); + session.PlayerID = loginResponse["selectedProfile"]!["id"]!.GetStringValue(); + session.PlayerName = loginResponse["selectedProfile"]!["name"]!.GetStringValue(); + return (LoginResult.Success, session); + } + + return (LoginResult.InvalidResponse, session); + } + + if (code == 403 && result.Contains("InvalidToken")) + return (LoginResult.InvalidToken, session); + + ConsoleIO.WriteLineFormatted("§8" + string.Format(Translations.error_auth, code)); + return (LoginResult.OtherError, session); + } + catch + { + return (LoginResult.OtherError, session); + } + } + /// /// Check session using Mojang's Yggdrasil authentication scheme. Allows to join an online-mode server /// @@ -1108,6 +1220,43 @@ public static bool SessionCheck(string uuid, string accesstoken, string serverha } } + public static async Task SessionCheckAsync(string uuid, string accesstoken, string serverhash, LoginType type) + { + try + { + string jsonRequest = "{\"accessToken\":\"" + accesstoken + "\",\"selectedProfile\":\"" + uuid + + "\",\"serverId\":\"" + serverhash + "\"}"; + string host = type == LoginType.yggdrasil + ? Config.Main.General.AuthServer.Host + : "sessionserver.mojang.com"; + int port = type == LoginType.yggdrasil ? Config.Main.General.AuthServer.Port : 443; + string endpoint = type == LoginType.yggdrasil + ? Config.Main.General.AuthServer.AuthlibInjectorAPIPath + "/sessionserver/session/minecraft/join" + : "/session/minecraft/join"; + + bool useHttps = type == LoginType.yggdrasil ? Config.Main.General.AuthServer.UseHttps : true; + var response = await DoHTTPSRequestAsync( + HttpMethod.Post, + host, + port, + endpoint, + new Dictionary + { + { "Accept", "application/json" }, + { "Content-Type", "application/json" } + }, + jsonRequest, + useHttps, + CancellationToken.None); + + return response.StatusCode >= 200 && response.StatusCode < 300; + } + catch + { + return false; + } + } + /// /// Retrieve available Realms worlds of a player and display them /// @@ -1115,15 +1264,18 @@ public static bool SessionCheck(string uuid, string accesstoken, string serverha /// Player UUID /// Access token /// List of ID of available Realms worlds - public static List RealmsListWorlds(string username, string uuid, string accesstoken) + public static List RealmsListWorlds(string username, string uuid, string accesstoken) => + RealmsListWorldsAsync(username, uuid, accesstoken).GetAwaiter().GetResult(); + + public static async Task> RealmsListWorldsAsync(string username, string uuid, string accesstoken, CancellationToken cancellationToken = default) { List realmsWorldsResult = new(); // Store world ID try { - string result = ""; + string result; string cookies = String.Format("sid=token:{0}:{1};user={2};version={3}", accesstoken, uuid, username, Program.MCHighestVersion); - DoHTTPSGet("pc.realms.minecraft.net", 443, "/worlds", cookies, ref result); + (_, result) = await DoHTTPSGetAsync("pc.realms.minecraft.net", 443, "/worlds", cookies, cancellationToken); var realmsWorlds = Json.ParseJson(result); if (realmsWorlds?["servers"] is System.Text.Json.Nodes.JsonArray serversArray && serversArray.Count > 0) @@ -1179,15 +1331,21 @@ public static List RealmsListWorlds(string username, string uuid, string /// Access token /// Server address (host:port) or empty string if failure public static string GetRealmsWorldServerAddress(string worldId, string username, string uuid, - string accesstoken) + string accesstoken) => + GetRealmsWorldServerAddressAsync(worldId, username, uuid, accesstoken).GetAwaiter().GetResult(); + + public static async Task GetRealmsWorldServerAddressAsync(string worldId, string username, string uuid, + string accesstoken, CancellationToken cancellationToken = default) { try { - string result = ""; + string result; string cookies = String.Format("sid=token:{0}:{1};user={2};version={3}", accesstoken, uuid, username, Program.MCHighestVersion); - int statusCode = DoHTTPSGet("pc.realms.minecraft.net", 443, "/worlds/v1/" + worldId + "/join/pc", - cookies, ref result); + var response = await DoHTTPSGetAsync("pc.realms.minecraft.net", 443, "/worlds/v1/" + worldId + "/join/pc", + cookies, cancellationToken); + int statusCode = response.StatusCode; + result = response.Result; if (statusCode == 200) { var serverAddress = Json.ParseJson(result); @@ -1227,6 +1385,13 @@ public static string GetRealmsWorldServerAddress(string worldId, string username /// Request result /// HTTP Status code private static int DoHTTPSGet(string host, int port, string path, string cookies, ref string result) + { + var response = DoHTTPSGetAsync(host, port, path, cookies).GetAwaiter().GetResult(); + result = response.Result; + return response.StatusCode; + } + + private static Task<(int StatusCode, string Result)> DoHTTPSGetAsync(string host, int port, string path, string cookies, CancellationToken cancellationToken = default) { Dictionary headers = new() { @@ -1235,7 +1400,7 @@ private static int DoHTTPSGet(string host, int port, string path, string cookies { "Pragma", "no-cache" }, { "User-Agent", "Java/1.6.0_27" } }; - return DoHTTPSRequest(HttpMethod.Get, host, port, path, headers, null, useHttps: true, ref result); + return DoHTTPSRequestAsync(HttpMethod.Get, host, port, path, headers, null, useHttps: true, cancellationToken); } /// @@ -1261,13 +1426,23 @@ private static int DoHTTPSPost(string host, int port, string path, string body, /// Request result /// HTTP Status code private static int DoHTTPSPost(string host, int port, string path, string body, bool useHttps, ref string result) + { + var response = DoHTTPSPostAsync(host, port, path, body, useHttps).GetAwaiter().GetResult(); + result = response.Result; + return response.StatusCode; + } + + private static Task<(int StatusCode, string Result)> DoHTTPSPostAsync(string host, int port, string path, string body, CancellationToken cancellationToken = default) => + DoHTTPSPostAsync(host, port, path, body, useHttps: true, cancellationToken); + + private static Task<(int StatusCode, string Result)> DoHTTPSPostAsync(string host, int port, string path, string body, bool useHttps, CancellationToken cancellationToken = default) { Dictionary headers = new() { { "User-Agent", "MCC/" + Program.Version }, { "Content-Type", "application/json" } }; - return DoHTTPSRequest(HttpMethod.Post, host, port, path, headers, body, useHttps, ref result); + return DoHTTPSRequestAsync(HttpMethod.Post, host, port, path, headers, body, useHttps, cancellationToken); } /// @@ -1284,69 +1459,62 @@ private static int DoHTTPSPost(string host, int port, string path, string body, /// HTTP Status code private static int DoHTTPSRequest(HttpMethod method, string host, int port, string path, Dictionary headers, string? body, bool useHttps, ref string result) { - string? postResult = null; - int statusCode = 520; - Exception? exception = null; - AutoTimeout.Perform(() => + var response = DoHTTPSRequestAsync(method, host, port, path, headers, body, useHttps, CancellationToken.None) + .GetAwaiter() + .GetResult(); + result = response.Result; + return response.StatusCode; + } + + private static async Task<(int StatusCode, string Result)> DoHTTPSRequestAsync(HttpMethod method, string host, int port, string path, Dictionary headers, string? body, bool useHttps, CancellationToken cancellationToken) + { + if (Settings.Config.Logging.DebugMessages) + ConsoleIO.WriteLineFormatted("§8" + string.Format(Translations.debug_request, host)); + + using SocketsHttpHandler handler = new(); + handler.ConnectCallback = async (ctx, ct) => { - try - { - if (Settings.Config.Logging.DebugMessages) - ConsoleIO.WriteLineFormatted("§8" + string.Format(Translations.debug_request, host)); + TcpClient client = ProxyHandler.NewTcpClient(host, port, true); + return client.GetStream(); + }; - using SocketsHttpHandler handler = new SocketsHttpHandler(); - handler.ConnectCallback = async (ctx, ct) => - { - TcpClient client = ProxyHandler.NewTcpClient(host, port, true); - return client.GetStream(); - }; + using HttpClient client = new(handler); - using HttpClient client = new HttpClient(handler); + string scheme = useHttps ? "https" : "http"; + using HttpRequestMessage request = new(method, scheme + "://" + host + ":" + port + path); - string scheme = useHttps ? "https" : "http"; - var request = new HttpRequestMessage(method, scheme + "://" + host + ":" + port + path); + string contentType = "text/plain"; + foreach (var header in headers) + { + request.Headers.TryAddWithoutValidation(header.Key, header.Value); + if (header.Key.Equals("Content-Type", StringComparison.OrdinalIgnoreCase)) + contentType = header.Value; + } - var contentType = "text/plain"; - foreach (var header in headers) - { - request.Headers.TryAddWithoutValidation(header.Key, header.Value); - if (header.Key.Equals("Content-Type", StringComparison.OrdinalIgnoreCase)) - contentType = header.Value; - } + if (body is not null) + request.Content = new StringContent(body, Encoding.UTF8, contentType); - if (body is not null) - request.Content = new StringContent(body, Encoding.UTF8, contentType); + if (Settings.Config.Logging.DebugMessages) + ConsoleIO.WriteLineFormatted("§8> " + request); - if (Settings.Config.Logging.DebugMessages) - ConsoleIO.WriteLineFormatted("§8> " + request); + using CancellationTokenSource timeoutCancellationTokenSource = + CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + timeoutCancellationTokenSource.CancelAfter(TimeSpan.FromSeconds(30)); - HttpResponseMessage response = client.SendAsync(request).GetAwaiter().GetResult(); - statusCode = (int)response.StatusCode; + using HttpResponseMessage response = await client.SendAsync(request, timeoutCancellationTokenSource.Token); + int statusCode = (int)response.StatusCode; + string responseBody = statusCode == 204 + ? "No Content" + : await response.Content.ReadAsStringAsync(timeoutCancellationTokenSource.Token); - postResult = statusCode == 204 - ? "No Content" - : response.Content.ReadAsStringAsync().GetAwaiter().GetResult(); + if (Settings.Config.Logging.DebugMessages) + { + ConsoleIO.WriteLine(""); + foreach (string line in responseBody.Split('\n')) + ConsoleIO.WriteLineFormatted("§8< " + line); + } - if (Settings.Config.Logging.DebugMessages) - { - ConsoleIO.WriteLine(""); - foreach (string line in postResult.Split('\n')) - ConsoleIO.WriteLineFormatted("§8< " + line); - } - } - catch (Exception e) - { - if (e is not System.Threading.ThreadAbortException) - { - exception = e; - } - } - }, TimeSpan.FromSeconds(30)); - if (postResult is not null) - result = postResult; - if (exception is not null) - throw exception; - return statusCode; + return (statusCode, responseBody); } /// @@ -1389,4 +1557,4 @@ public static DateTime UnixTimeStampToDateTime(long unixTimeStamp) return dateTime; } } -} \ No newline at end of file +} diff --git a/MinecraftClient/Protocol/ProxiedWebRequest.cs b/MinecraftClient/Protocol/ProxiedWebRequest.cs index 220ac3191a..7911bc329e 100644 --- a/MinecraftClient/Protocol/ProxiedWebRequest.cs +++ b/MinecraftClient/Protocol/ProxiedWebRequest.cs @@ -3,6 +3,8 @@ using System.Net; using System.Net.Http; using System.Text; +using System.Threading; +using System.Threading.Tasks; using MinecraftClient.Proxy; namespace MinecraftClient.Protocol @@ -72,6 +74,12 @@ private void SetupBasicHeaders() /// public Response Get() => Send(HttpMethod.Get); + /// + /// Perform GET request asynchronously. Proxy is handled automatically. + /// + public Task GetAsync(CancellationToken cancellationToken = default) => + SendAsync(HttpMethod.Get, cancellationToken: cancellationToken); + /// /// Perform POST request. Proxy is handled automatically. /// @@ -79,6 +87,14 @@ private void SetupBasicHeaders() /// Request body public Response Post(string contentType, string body) => Send(HttpMethod.Post, contentType, body); + /// + /// Perform POST request asynchronously. Proxy is handled automatically. + /// + /// The content type of request body + /// Request body + public Task PostAsync(string contentType, string body, CancellationToken cancellationToken = default) => + SendAsync(HttpMethod.Post, contentType, body, cancellationToken); + /// /// Send an HTTP request. Proxy is configured automatically from Settings. /// @@ -144,6 +160,66 @@ private Response Send(HttpMethod method, string? contentType = null, string? bod } } + /// + /// Send an HTTP request asynchronously. Proxy is configured automatically from Settings. + /// + private async Task SendAsync(HttpMethod method, string? contentType = null, string? body = null, CancellationToken cancellationToken = default) + { + using var handler = CreateHandler(); + using var client = new HttpClient(handler); + + using var request = new HttpRequestMessage(method, _uri); + + foreach (string key in Headers) + { + if (key.Equals("Content-Type", StringComparison.OrdinalIgnoreCase) || + key.Equals("Content-Length", StringComparison.OrdinalIgnoreCase) || + key.Equals("Host", StringComparison.OrdinalIgnoreCase)) + continue; + + request.Headers.TryAddWithoutValidation(key, Headers[key]); + } + + if (body is not null) + request.Content = new StringContent(body, Encoding.UTF8, contentType ?? "text/plain"); + + if (Debug) + { + ConsoleIO.WriteLine($"< {method} {_uri}"); + foreach (string key in Headers) + ConsoleIO.WriteLine($"< {key}: {Headers[key]}"); + } + + try + { + using var httpResponse = await client.SendAsync(request, cancellationToken); + string responseBody = await httpResponse.Content.ReadAsStringAsync(cancellationToken); + + var responseHeaders = new NameValueCollection(); + foreach (var header in httpResponse.Headers) + foreach (var val in header.Value) + responseHeaders.Add(header.Key.ToLowerInvariant(), val); + foreach (var header in httpResponse.Content.Headers) + foreach (var val in header.Value) + responseHeaders.Add(header.Key.ToLowerInvariant(), val); + + var cookies = new NameValueCollection(); + foreach (Cookie cookie in handler.CookieContainer.GetCookies(_uri)) + { + if (!cookie.Expired) + cookies.Add(cookie.Name, cookie.Value); + } + + return new Response((int)httpResponse.StatusCode, responseBody, responseHeaders, cookies); + } + catch (HttpRequestException ex) + { + if (Debug) + ConsoleIO.WriteLine("HTTP error: " + ex.Message); + return Response.Empty(); + } + } + /// /// Create a SocketsHttpHandler with proxy support from ProxyHandler settings. /// @@ -231,4 +307,4 @@ public override string ToString() } } } -} \ No newline at end of file +} diff --git a/MinecraftClient/Protocol/Session/SessionToken.cs b/MinecraftClient/Protocol/Session/SessionToken.cs index 1364012bc8..244c82a403 100644 --- a/MinecraftClient/Protocol/Session/SessionToken.cs +++ b/MinecraftClient/Protocol/Session/SessionToken.cs @@ -54,6 +54,16 @@ public bool SessionPreCheck(LoginType type) return false; } + public async Task SessionPreCheckAsync(LoginType type) + { + if (ID == string.Empty || PlayerID == String.Empty || ServerPublicKey is null) + return false; + + Crypto.CryptoHandler.ClientAESPrivateKey ??= Crypto.CryptoHandler.GenerateAESPrivateKey(); + string serverHash = Crypto.CryptoHandler.GetServerHash(ServerIDhash, ServerPublicKey, Crypto.CryptoHandler.ClientAESPrivateKey); + return await ProtocolHandler.SessionCheckAsync(PlayerID, ID, serverHash, type); + } + public override string ToString() { return String.Join(",", ID, PlayerName, PlayerID, ClientID, RefreshToken, ServerIDhash, diff --git a/MinecraftClient/Scripting/MccGameApi.cs b/MinecraftClient/Scripting/MccGameApi.cs index 391ece5024..4bd18f1bf5 100644 --- a/MinecraftClient/Scripting/MccGameApi.cs +++ b/MinecraftClient/Scripting/MccGameApi.cs @@ -704,9 +704,76 @@ public MccGameResult MoveToPlayer(string playerName, bool /// /// Run on a worker thread so ChatBot callbacks can poll the result without blocking MCC updates. /// - public Task> MoveToPlayerAsync(string playerName, bool allowUnsafe = false, bool allowDirectTeleport = false, int maxOffset = 0, int minOffset = 0, int timeoutMs = 0) + public async Task> MoveToPlayerAsync(string playerName, bool allowUnsafe = false, bool allowDirectTeleport = false, int maxOffset = 0, int minOffset = 0, int timeoutMs = 0) { - return Task.Run(() => MoveToPlayer(playerName, allowUnsafe, allowDirectTeleport, maxOffset, minOffset, timeoutMs)); + if (string.IsNullOrWhiteSpace(playerName)) + return MccGameResult.Fail("invalid_args"); + + if (!AreValidPathOffsets(maxOffset, minOffset) || timeoutMs < 0) + return MccGameResult.Fail("invalid_args"); + + McClient? client = clientProvider(); + if (client is null) + return NotConnected(); + + if (!client.GetTerrainEnabled() || !client.GetEntityHandlingEnabled()) + return MccGameResult.Fail("feature_disabled"); + + string nameFilter = playerName.Trim(); + NearbyPlayerSnapshot? target = client.InvokeOnMainThread(() => + { + return BuildTrackedPlayerSnapshots(client, includeSelf: false) + .Where(player => PlayerNameMatches(player, nameFilter)) + .OrderBy(player => player.Distance) + .FirstOrDefault(); + }); + + if (target is null) + return MccGameResult.Fail("invalid_state"); + + Location goal = new(target.X, target.Y, target.Z); + Location startLocation = client.InvokeOnMainThread(client.GetCurrentLocation); + TimeSpan? timeout = timeoutMs > 0 ? TimeSpan.FromMilliseconds(timeoutMs) : null; + bool pathFound = client.InvokeOnMainThread(() => client.MoveTo(goal, allowUnsafe, allowDirectTeleport, maxOffset, minOffset, timeout)); + + int verifyWaitMs = GetArrivalWaitMs(timeoutMs); + double tolerance = GetArrivalTolerance(maxOffset, minOffset); + Location finalLocation = client.InvokeOnMainThread(client.GetCurrentLocation); + bool arrived = false; + + if (pathFound) + { + (arrived, finalLocation) = await WaitForArrivalAsync(client, goal, verifyWaitMs, tolerance); + } + + MccMoveToPlayerResult resultData = new() + { + PathFound = pathFound, + Arrived = arrived, + Tolerance = tolerance, + VerifyWaitMs = verifyWaitMs, + Target = new MccMoveToPlayerTarget + { + PlayerName = target.Name, + EntityId = target.EntityId, + X = MccGameCommon.RoundCoordinate(target.X), + Y = MccGameCommon.RoundCoordinate(target.Y), + Z = MccGameCommon.RoundCoordinate(target.Z) + }, + StartLocation = MccGameCommon.ToCoordinate(startLocation), + FinalLocation = MccGameCommon.ToCoordinate(finalLocation), + FinalDistance = MccGameCommon.GetDistance(finalLocation, goal), + DistanceMoved = MccGameCommon.GetDistance(startLocation, finalLocation), + AllowUnsafe = allowUnsafe, + AllowDirectTeleport = allowDirectTeleport, + MaxOffset = maxOffset, + MinOffset = minOffset, + TimeoutMs = timeoutMs + }; + + return pathFound && arrived + ? MccGameResult.Ok(resultData) + : MccGameResult.Fail("action_incomplete", data: resultData); } /// @@ -1055,9 +1122,94 @@ public MccGameResult PickupItems(string itemType, double r /// /// Run on a worker thread so ChatBot callbacks can poll the result without blocking MCC updates. /// - public Task> PickupItemsAsync(string itemType, double radius = 16, int maxItems = 10, bool allowUnsafe = false, int timeoutMs = 0) + public async Task> PickupItemsAsync(string itemType, double radius = 16, int maxItems = 10, bool allowUnsafe = false, int timeoutMs = 0) { - return Task.Run(() => PickupItems(itemType, radius, maxItems, allowUnsafe, timeoutMs)); + if (string.IsNullOrWhiteSpace(itemType) || radius <= 0 || radius > 1024 || maxItems < 1 || timeoutMs < 0) + return MccGameResult.Fail("invalid_args"); + + if (!MccGameCommon.TryParseItemType(itemType.Trim(), out ItemType parsedItemType)) + return MccGameResult.Fail("invalid_args"); + + McClient? client = clientProvider(); + if (client is null) + return NotConnected(); + + if (!client.GetTerrainEnabled() || !client.GetEntityHandlingEnabled()) + return MccGameResult.Fail("feature_disabled"); + + int limit = Math.Clamp(maxItems, 1, 50); + NearbyItemSnapshot[] targets = client.InvokeOnMainThread(() => BuildNearbyItemSnapshots(client, parsedItemType, radius, limit)); + if (targets.Length == 0) + return MccGameResult.Fail("invalid_state"); + + bool inventoryEnabled = client.GetInventoryEnabled(); + int beforeCount = inventoryEnabled ? client.InvokeOnMainThread(() => GetInventoryItemCount(client, parsedItemType)) : 0; + int initialCount = beforeCount; + int verifyWaitMs = timeoutMs > 0 ? Math.Clamp(timeoutMs, MinArrivalWaitMs, MaxArrivalWaitMs) : 2500; + List attempts = new(targets.Length); + int successfulPickups = 0; + + foreach (NearbyItemSnapshot target in targets) + { + Location targetLocation = new(target.X, target.Y, target.Z); + Location startLocation = client.InvokeOnMainThread(client.GetCurrentLocation); + TimeSpan? moveTimeout = timeoutMs > 0 ? TimeSpan.FromMilliseconds(timeoutMs) : null; + bool pathFound = client.InvokeOnMainThread(() => client.MoveTo(targetLocation, allowUnsafe, false, 0, 0, moveTimeout)); + + Location finalLocation = client.InvokeOnMainThread(client.GetCurrentLocation); + bool arrived = false; + if (pathFound) + { + (arrived, finalLocation) = await WaitForArrivalAsync(client, targetLocation, verifyWaitMs, 2.0); + } + + bool entityGone = await WaitForEntityRemovalAsync(client, target.EntityId, verifyWaitMs); + int afterCount = inventoryEnabled ? client.InvokeOnMainThread(() => GetInventoryItemCount(client, parsedItemType)) : beforeCount; + int inventoryDelta = inventoryEnabled ? Math.Max(0, afterCount - beforeCount) : 0; + bool pickedUp = entityGone || inventoryDelta > 0; + if (pickedUp) + successfulPickups++; + + attempts.Add(new MccPickupAttempt + { + EntityId = target.EntityId, + ItemType = target.ItemType.ToString(), + TypeLabel = target.TypeLabel, + ExpectedCount = target.Count, + Target = MccGameCommon.ToCoordinate(target.X, target.Y, target.Z), + PathFound = pathFound, + Arrived = arrived, + EntityGone = entityGone, + InventoryDelta = inventoryDelta, + StartLocation = MccGameCommon.ToCoordinate(startLocation), + FinalLocation = MccGameCommon.ToCoordinate(finalLocation), + FinalDistance = MccGameCommon.GetDistance(finalLocation, targetLocation) + }); + + beforeCount = afterCount; + } + + int remainingNearby = client.InvokeOnMainThread(() => BuildNearbyItemSnapshots(client, parsedItemType, radius, 1000).Length); + int collectedCount = inventoryEnabled ? Math.Max(0, beforeCount - initialCount) : successfulPickups; + MccPickupItemsResult resultData = new() + { + ItemType = parsedItemType.ToString(), + Radius = radius, + MaxItems = limit, + AllowUnsafe = allowUnsafe, + TimeoutMs = verifyWaitMs, + Attempted = attempts.Count, + SuccessfulPickups = successfulPickups, + CollectedCount = collectedCount, + InitialInventoryCount = inventoryEnabled ? initialCount : null, + FinalInventoryCount = inventoryEnabled ? beforeCount : null, + RemainingNearby = remainingNearby, + Attempts = attempts.ToArray() + }; + + return successfulPickups > 0 + ? MccGameResult.Ok(resultData) + : MccGameResult.Fail("action_incomplete", data: resultData); } private static MccGameResult NotConnected() @@ -1132,6 +1284,24 @@ private static bool WaitForArrival(McClient client, Location goal, int waitMs, d } } + private static async Task<(bool Arrived, Location FinalLocation)> WaitForArrivalAsync(McClient client, Location goal, int waitMs, double tolerance) + { + DateTime deadline = DateTime.UtcNow.AddMilliseconds(waitMs); + Location finalLocation = client.InvokeOnMainThread(client.GetCurrentLocation); + + while (true) + { + finalLocation = client.InvokeOnMainThread(client.GetCurrentLocation); + if (MccGameCommon.GetDistance(finalLocation, goal) <= tolerance) + return (true, finalLocation); + + if (DateTime.UtcNow >= deadline) + return (false, finalLocation); + + await Task.Delay(ArrivalPollIntervalMs); + } + } + private static bool WaitForEntityRemoval(McClient client, int entityId, int waitMs) { DateTime deadline = DateTime.UtcNow.AddMilliseconds(waitMs); @@ -1148,6 +1318,22 @@ private static bool WaitForEntityRemoval(McClient client, int entityId, int wait } } + private static async Task WaitForEntityRemovalAsync(McClient client, int entityId, int waitMs) + { + DateTime deadline = DateTime.UtcNow.AddMilliseconds(waitMs); + while (true) + { + bool exists = client.InvokeOnMainThread(() => client.GetEntities().ContainsKey(entityId)); + if (!exists) + return true; + + if (DateTime.UtcNow >= deadline) + return false; + + await Task.Delay(ArrivalPollIntervalMs); + } + } + private static int GetInventoryItemCount(McClient client, ItemType itemType) { Container? inventory = client.GetInventory(0); diff --git a/MinecraftClient/TaskWithResult.cs b/MinecraftClient/TaskWithResult.cs index 53aec3aa4f..23b0dbc1e0 100644 --- a/MinecraftClient/TaskWithResult.cs +++ b/MinecraftClient/TaskWithResult.cs @@ -1,20 +1,24 @@ using System; using System.Threading; +using System.Threading.Tasks; namespace MinecraftClient { + internal interface IMainThreadTask + { + void ExecuteSynchronously(); + void Cancel(); + } + /// /// Holds an asynchronous task with return value /// /// Type of the return value - public class TaskWithResult + public sealed class TaskWithResult : IMainThreadTask { - private readonly AutoResetEvent resultEvent = new(false); private readonly Func task; - private T? result = default; - private Exception? exception = null; - private bool taskRun = false; - private readonly Lock taskRunLock = new(); + private readonly TaskCompletionSource completionSource = new(TaskCreationOptions.RunContinuationsAsynchronously); + private int taskState; /// /// Create a new asynchronous task with return value @@ -28,13 +32,7 @@ public TaskWithResult(Func task) /// /// Check whether the task has finished running /// - public bool HasRun - { - get - { - return taskRun; - } - } + public bool HasRun => completionSource.Task.IsCompleted; /// /// Get the task result (return value of the inner delegate) @@ -44,10 +42,10 @@ public T Result { get { - if (taskRun) - return result!; - else + if (!completionSource.Task.IsCompleted) throw new InvalidOperationException("Attempting to retrieve the result of an unfinished task"); + + return completionSource.Task.GetAwaiter().GetResult(); } } @@ -58,40 +56,39 @@ public Exception? Exception { get { - return exception; + return completionSource.Task.Exception?.InnerException; } } + public Task AsTask() + { + return completionSource.Task; + } + /// /// Execute the task in the current thread and set the property or to the returned value /// public void ExecuteSynchronously() { - // Make sur the task will not run twice - lock (taskRunLock) - { - if (taskRun) - { - throw new InvalidOperationException("Attempting to run a task twice"); - } - } + if (Interlocked.CompareExchange(ref taskState, 1, 0) != 0) + throw new InvalidOperationException("Attempting to run a task twice"); - // Run the task try { - result = task(); + completionSource.TrySetResult(task()); } catch (Exception e) { - exception = e; + completionSource.TrySetException(e); } + } - // Mark task as complete and release wait event - lock (taskRunLock) - { - taskRun = true; - } - resultEvent.Set(); + public void Cancel() + { + if (Interlocked.CompareExchange(ref taskState, 1, 0) != 0) + return; + + completionSource.TrySetException(new OperationCanceledException("Main-thread task was canceled before execution.")); } /// @@ -101,22 +98,7 @@ public void ExecuteSynchronously() /// Any exception thrown by the task public T WaitGetResult() { - // Wait only if the result is not available yet - bool mustWait = false; - lock (taskRunLock) - { - mustWait = !taskRun; - } - if (mustWait) - { - resultEvent.WaitOne(); - } - - // Receive exception from task - if (exception is not null) - throw exception; - - return result!; + return completionSource.Task.GetAwaiter().GetResult(); } } } diff --git a/MinecraftClient/Tui/TuiConsoleBackend.cs b/MinecraftClient/Tui/TuiConsoleBackend.cs index bee5975116..7a12df8927 100644 --- a/MinecraftClient/Tui/TuiConsoleBackend.cs +++ b/MinecraftClient/Tui/TuiConsoleBackend.cs @@ -3,6 +3,7 @@ using System.Diagnostics; using System.Runtime.InteropServices; using System.Threading; +using System.Threading.Tasks; using Avalonia; using Avalonia.Threading; using Consolonia; @@ -48,12 +49,11 @@ internal void RunTuiMainLoop(string[] args, Program.StartupState startupState) Dispatcher.UIThread.Post(() => view.HandleCtrlC()); }; - new Thread(() => + _ = Task.Run(() => { _viewReady.Wait(); ContinueMccStartup(args); - }) - { Name = "MCC-Main", IsBackground = true }.Start(); + }); AppBuilder builder = AppBuilder.Configure() .UseConsolonia() @@ -206,32 +206,49 @@ internal void DismissOverlay() } public string RequestImmediateInput() + { + return RequestImmediateInputAsync(CancellationToken.None).GetAwaiter().GetResult(); + } + + public Task RequestImmediateInputAsync(CancellationToken cancellationToken) { if (_shutdownRequested) { - Thread.Sleep(Timeout.Infinite); - return string.Empty; + return Task.FromCanceled(cancellationToken.CanBeCanceled + ? cancellationToken + : new CancellationToken(canceled: true)); } - var mre = new ManualResetEventSlim(false); - string? result = null; + TaskCompletionSource completion = new(TaskCreationOptions.RunContinuationsAsynchronously); void Handler(object? sender, string e) { - result = e; - mre.Set(); + MessageReceived -= Handler; + completion.TrySetResult(e); } MessageReceived += Handler; - mre.Wait(); - MessageReceived -= Handler; - return result ?? string.Empty; + if (cancellationToken.CanBeCanceled) + { + cancellationToken.Register(() => + { + MessageReceived -= Handler; + completion.TrySetCanceled(cancellationToken); + }); + } + + return completion.Task; } public string? ReadPassword() { - return RequestImmediateInput(); + return ReadPasswordAsync(CancellationToken.None).GetAwaiter().GetResult(); + } + + public async Task ReadPasswordAsync(CancellationToken cancellationToken) + { + return await RequestImmediateInputAsync(cancellationToken); } public void ClearInputBuffer() @@ -267,11 +284,11 @@ public void Shutdown() Dispatcher.UIThread.Post(() => lifetime.Shutdown()); } - new Thread(() => + _ = Task.Run(async () => { - Thread.Sleep(1000); + await Task.Delay(1000); Environment.Exit(0); - }) { Name = "TUI-Exit-Guard", IsBackground = true }.Start(); + }); } private volatile bool _shutdownRequested;