|
13 | 13 |
|
14 | 14 | import celery |
15 | 15 | import pytest |
| 16 | +from kombu.utils.json import register_type |
16 | 17 |
|
17 | 18 | from taskbadger import Action, EmailIntegration, StatusEnum |
18 | 19 | from taskbadger.celery import Task |
@@ -111,6 +112,109 @@ def add_with_task_args(self, a, b): |
111 | 112 | create.assert_called_once_with("new_name", value_max=10, actions=actions, status=StatusEnum.PENDING) |
112 | 113 |
|
113 | 114 |
|
| 115 | +def test_celery_record_args(celery_session_app, celery_session_worker, bind_settings): |
| 116 | + @celery_session_app.task(bind=True, base=Task) |
| 117 | + def add_with_task_args(self, a, b): |
| 118 | + assert self.taskbadger_task is not None |
| 119 | + return a + b |
| 120 | + |
| 121 | + celery_session_worker.reload() |
| 122 | + |
| 123 | + with ( |
| 124 | + mock.patch("taskbadger.celery.create_task_safe") as create, |
| 125 | + mock.patch("taskbadger.celery.update_task_safe"), |
| 126 | + mock.patch("taskbadger.celery.get_task"), |
| 127 | + ): |
| 128 | + create.return_value = task_for_test() |
| 129 | + |
| 130 | + result = add_with_task_args.apply_async( |
| 131 | + (2, 2), |
| 132 | + taskbadger_name="new_name", |
| 133 | + taskbadger_value_max=10, |
| 134 | + taskbadger_kwargs={"data": {"foo": "bar"}}, |
| 135 | + taskbadger_record_task_args=True, |
| 136 | + ) |
| 137 | + assert result.get(timeout=10, propagate=True) == 4 |
| 138 | + |
| 139 | + create.assert_called_once_with( |
| 140 | + "new_name", |
| 141 | + value_max=10, |
| 142 | + data={"foo": "bar", "celery_task_args": [2, 2], "celery_task_kwargs": {}}, |
| 143 | + status=StatusEnum.PENDING, |
| 144 | + ) |
| 145 | + |
| 146 | + |
| 147 | +def test_celery_record_task_kwargs(celery_session_app, celery_session_worker, bind_settings): |
| 148 | + @celery_session_app.task(bind=True, base=Task) |
| 149 | + def add_with_task_kwargs(self, a, b, c=0): |
| 150 | + assert self.taskbadger_task is not None |
| 151 | + return a + b + c |
| 152 | + |
| 153 | + celery_session_worker.reload() |
| 154 | + |
| 155 | + with ( |
| 156 | + mock.patch("taskbadger.celery.create_task_safe") as create, |
| 157 | + mock.patch("taskbadger.celery.update_task_safe"), |
| 158 | + mock.patch("taskbadger.celery.get_task"), |
| 159 | + ): |
| 160 | + create.return_value = task_for_test() |
| 161 | + |
| 162 | + actions = [Action("stale", integration=EmailIntegration(to="test@test.com"))] |
| 163 | + result = add_with_task_kwargs.delay( |
| 164 | + 2, |
| 165 | + 2, |
| 166 | + c=3, |
| 167 | + taskbadger_name="new_name", |
| 168 | + taskbadger_value_max=10, |
| 169 | + taskbadger_kwargs={"actions": actions}, |
| 170 | + taskbadger_record_task_args=True, |
| 171 | + ) |
| 172 | + assert result.get(timeout=10, propagate=True) == 7 |
| 173 | + |
| 174 | + create.assert_called_once_with( |
| 175 | + "new_name", |
| 176 | + value_max=10, |
| 177 | + data={"celery_task_args": [2, 2], "celery_task_kwargs": {"c": 3}}, |
| 178 | + actions=actions, |
| 179 | + status=StatusEnum.PENDING, |
| 180 | + ) |
| 181 | + |
| 182 | + |
| 183 | +def test_celery_record_task_args_custom_serialization(celery_session_app, celery_session_worker, bind_settings): |
| 184 | + class A: |
| 185 | + def __init__(self, a, b): |
| 186 | + self.a = a |
| 187 | + self.b = b |
| 188 | + |
| 189 | + register_type(A, "A", lambda o: [o.a, o.b], lambda o: A(*o)) |
| 190 | + |
| 191 | + @celery_session_app.task(bind=True, base=Task) |
| 192 | + def add_task_custom_serialization(self, a): |
| 193 | + assert self.taskbadger_task is not None |
| 194 | + return a.a + a.b |
| 195 | + |
| 196 | + celery_session_worker.reload() |
| 197 | + |
| 198 | + with ( |
| 199 | + mock.patch("taskbadger.celery.create_task_safe") as create, |
| 200 | + mock.patch("taskbadger.celery.update_task_safe"), |
| 201 | + mock.patch("taskbadger.celery.get_task"), |
| 202 | + ): |
| 203 | + create.return_value = task_for_test() |
| 204 | + |
| 205 | + result = add_task_custom_serialization.delay( |
| 206 | + A(2, 2), |
| 207 | + taskbadger_record_task_args=True, |
| 208 | + ) |
| 209 | + assert result.get(timeout=10, propagate=True) == 4 |
| 210 | + |
| 211 | + create.assert_called_once_with( |
| 212 | + "tests.test_celery.add_task_custom_serialization", |
| 213 | + data={"celery_task_args": [{"__type__": "A", "__value__": [2, 2]}], "celery_task_kwargs": {}}, |
| 214 | + status=StatusEnum.PENDING, |
| 215 | + ) |
| 216 | + |
| 217 | + |
114 | 218 | def test_celery_task_with_args_in_decorator(celery_session_app, celery_session_worker, bind_settings): |
115 | 219 | @celery_session_app.task( |
116 | 220 | bind=True, |
|
0 commit comments