|
2 | 2 | #include <chrono> |
3 | 3 | #include <filesystem> |
4 | 4 | #include <fstream> |
| 5 | +#include <future> |
5 | 6 | #include <iomanip> |
6 | 7 | #include <iostream> |
7 | 8 | #include <mutex> |
@@ -368,6 +369,18 @@ int main(int argc, const char** argv) { |
368 | 369 | return httplib::Server::HandlerResponse::Unhandled; |
369 | 370 | }); |
370 | 371 |
|
| 372 | + auto wait_for_generation = [](std::future<void>& ft, sd_ctx_t* sd_ctx, const httplib::Request& req) { |
| 373 | + std::future_status ft_status; |
| 374 | + do { |
| 375 | + if (!ft.valid()) |
| 376 | + break; |
| 377 | + ft_status = ft.wait_for(std::chrono::milliseconds(1000)); |
| 378 | + if (req.is_connection_closed()) { |
| 379 | + sd_cancel_generation(sd_ctx, SD_CANCEL_ALL); |
| 380 | + } |
| 381 | + } while (ft_status != std::future_status::ready); |
| 382 | + }; |
| 383 | + |
371 | 384 | // root |
372 | 385 | svr.Get("/", [&](const httplib::Request&, httplib::Response& res) { |
373 | 386 | if (!svr_params.serve_html_path.empty()) { |
@@ -510,11 +523,13 @@ int main(int argc, const char** argv) { |
510 | 523 | sd_image_t* results = nullptr; |
511 | 524 | int num_results = 0; |
512 | 525 |
|
513 | | - { |
| 526 | + std::future<void> ft = std::async(std::launch::async, [&]() { |
514 | 527 | std::lock_guard<std::mutex> lock(sd_ctx_mutex); |
515 | 528 | results = generate_image(sd_ctx, &img_gen_params); |
516 | 529 | num_results = gen_params.batch_count; |
517 | | - } |
| 530 | + }); |
| 531 | + |
| 532 | + wait_for_generation(ft, sd_ctx, req); |
518 | 533 |
|
519 | 534 | for (int i = 0; i < num_results; i++) { |
520 | 535 | if (results[i].data == nullptr) { |
@@ -756,11 +771,13 @@ int main(int argc, const char** argv) { |
756 | 771 | sd_image_t* results = nullptr; |
757 | 772 | int num_results = 0; |
758 | 773 |
|
759 | | - { |
| 774 | + std::future<void> ft = std::async(std::launch::async, [&]() { |
760 | 775 | std::lock_guard<std::mutex> lock(sd_ctx_mutex); |
761 | 776 | results = generate_image(sd_ctx, &img_gen_params); |
762 | 777 | num_results = gen_params.batch_count; |
763 | | - } |
| 778 | + }); |
| 779 | + |
| 780 | + wait_for_generation(ft, sd_ctx, req); |
764 | 781 |
|
765 | 782 | json out; |
766 | 783 | out["created"] = static_cast<long long>(std::time(nullptr)); |
@@ -1071,11 +1088,13 @@ int main(int argc, const char** argv) { |
1071 | 1088 | sd_image_t* results = nullptr; |
1072 | 1089 | int num_results = 0; |
1073 | 1090 |
|
1074 | | - { |
| 1091 | + std::future<void> ft = std::async(std::launch::async, [&]() { |
1075 | 1092 | std::lock_guard<std::mutex> lock(sd_ctx_mutex); |
1076 | 1093 | results = generate_image(sd_ctx, &img_gen_params); |
1077 | 1094 | num_results = gen_params.batch_count; |
1078 | | - } |
| 1095 | + }); |
| 1096 | + |
| 1097 | + wait_for_generation(ft, sd_ctx, req); |
1079 | 1098 |
|
1080 | 1099 | json out; |
1081 | 1100 | out["images"] = json::array(); |
|
0 commit comments