Skip to content

Commit 2a8e36d

Browse files
Found the rest of them!
1 parent 821b902 commit 2a8e36d

1 file changed

Lines changed: 33 additions & 32 deletions

File tree

src/replit_river/codegen/client.py

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
ensure_literal_type,
3939
extract_inner_type,
4040
render_type_expr,
41+
render_literal_type,
4142
)
4243

4344
_NON_ALNUM_RE = re.compile(r"[^a-zA-Z0-9_]+")
@@ -167,7 +168,7 @@ def encode_type(
167168
in_module: list[ModuleName],
168169
permit_unknown_members: bool,
169170
) -> Tuple[TypeExpression, list[ModuleName], list[FileContents], set[TypeName]]:
170-
encoder_name: Optional[str] = None # defining this up here to placate mypy
171+
encoder_name: TypeName | None = None # defining this up here to placate mypy
171172
chunks: List[FileContents] = []
172173
if isinstance(type, RiverNotType):
173174
return (TypeName("None"), [], [], set())
@@ -234,7 +235,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
234235
and prop.const is not None
235236
].pop()
236237
one_of_pending.setdefault(
237-
f"{prefix}OneOf_{discriminator_value}",
238+
f"{render_literal_type(prefix)}OneOf_{discriminator_value}",
238239
(discriminator_value, []),
239240
)[1].append(oneof_t)
240241

@@ -270,12 +271,12 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
270271
oneof_t.properties.keys()
271272
).difference(common_members)
272273
encoder_name = TypeName(
273-
f"encode_{ensure_literal_type(type_name)}"
274+
f"encode_{render_literal_type(type_name)}"
274275
)
275276
encoder_names.add(encoder_name)
276277
typeddict_encoder.append(
277278
f"""\
278-
{encoder_name}(x) # type: ignore[arg-type]
279+
{render_literal_type(encoder_name)}(x) # type: ignore[arg-type]
279280
""".strip()
280281
)
281282
if local_discriminators:
@@ -299,12 +300,12 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
299300
one_of.append(type_name)
300301
chunks.extend(contents)
301302
encoder_name = TypeName(
302-
f"encode_{ensure_literal_type(type_name)}"
303+
f"encode_{render_literal_type(type_name)}"
303304
)
304305
# TODO(dstewart): Figure out why uncommenting this breaks
305306
# generated code
306307
# encoder_names.add(encoder_name)
307-
typeddict_encoder.append(f"{encoder_name}(x)")
308+
typeddict_encoder.append(f"{render_literal_type(encoder_name)}(x)")
308309
typeddict_encoder.append(
309310
f"""
310311
if x[{repr(discriminator_name)}]
@@ -317,19 +318,19 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
317318
union = OpenUnionTypeExpr(UnionTypeExpr(one_of))
318319
else:
319320
union = UnionTypeExpr(one_of)
320-
chunks.append(FileContents(f"{prefix} = {render_type_expr(union)}"))
321+
chunks.append(FileContents(f"{render_literal_type(prefix)} = {render_type_expr(union)}"))
321322
chunks.append(FileContents(""))
322323

323324
if base_model == "TypedDict":
324-
encoder_name = TypeName(f"encode_{prefix}")
325+
encoder_name = TypeName(f"encode_{render_literal_type(prefix)}")
325326
encoder_names.add(encoder_name)
326327
chunks.append(
327328
FileContents(
328329
"\n".join(
329330
[
330331
dedent(
331332
f"""\
332-
{encoder_name}: Callable[[{repr(prefix)}], Any] = (
333+
{render_literal_type(encoder_name)}: Callable[[{repr(render_literal_type(prefix))}], Any] = (
333334
lambda x:
334335
""".rstrip()
335336
)
@@ -349,7 +350,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
349350
for i, t in enumerate(type.anyOf):
350351
type_name, _, contents, _ = encode_type(
351352
t,
352-
TypeName(f"{prefix}AnyOf_{i}"),
353+
TypeName(f"{render_literal_type(prefix)}AnyOf_{i}"),
353354
base_model,
354355
in_module,
355356
permit_unknown_members=permit_unknown_members,
@@ -366,7 +367,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
366367
match type_name:
367368
case ListTypeExpr(inner_type_name):
368369
typeddict_encoder.append(
369-
f"encode_{ensure_literal_type(inner_type_name)}(x)"
370+
f"encode_{render_literal_type(inner_type_name)}(x)"
370371
)
371372
case DictTypeExpr(_):
372373
raise ValueError(
@@ -377,23 +378,23 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
377378
typeddict_encoder.append(repr(const))
378379
case other:
379380
typeddict_encoder.append(
380-
f"encode_{ensure_literal_type(other)}(x)"
381+
f"encode_{render_literal_type(other)}(x)"
381382
)
382383
if permit_unknown_members:
383384
union = OpenUnionTypeExpr(UnionTypeExpr(any_of))
384385
else:
385386
union = UnionTypeExpr(any_of)
386387
if is_literal(type):
387388
typeddict_encoder = ["x"]
388-
chunks.append(FileContents(f"{prefix} = {render_type_expr(union)}"))
389+
chunks.append(FileContents(f"{render_literal_type(prefix)} = {render_type_expr(union)}"))
389390
if base_model == "TypedDict":
390-
encoder_name = TypeName(f"encode_{prefix}")
391+
encoder_name = TypeName(f"encode_{render_literal_type(prefix)}")
391392
encoder_names.add(encoder_name)
392393
chunks.append(
393394
FileContents(
394395
"\n".join(
395396
[
396-
f"{encoder_name}: Callable[[{repr(prefix)}], Any] = ("
397+
f"{render_literal_type(encoder_name)}: Callable[[{repr(render_literal_type(prefix))}], Any] = ("
397398
"lambda x: "
398399
]
399400
+ typeddict_encoder
@@ -491,7 +492,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
491492
match type_name:
492493
case ListTypeExpr(inner_type_name):
493494
typeddict_encoder.append(
494-
f"encode_{ensure_literal_type(inner_type_name)}(x)"
495+
f"encode_{render_literal_type(inner_type_name)}(x)"
495496
)
496497
case DictTypeExpr(_):
497498
raise ValueError(
@@ -500,11 +501,11 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
500501
case LiteralTypeExpr(const):
501502
typeddict_encoder.append(repr(const))
502503
case other:
503-
typeddict_encoder.append(f"encode_{ensure_literal_type(other)}(x)")
504+
typeddict_encoder.append(f"encode_{render_literal_type(other)}(x)")
504505
return (DictTypeExpr(type_name), module_info, type_chunks, encoder_names)
505506
assert type.type == "object", type.type
506507

507-
current_chunks: List[str] = [f"class {prefix}({base_model}):"]
508+
current_chunks: List[str] = [f"class {render_literal_type(prefix)}({base_model}):"]
508509
# For the encoder path, do we need "x" to be bound?
509510
# lambda x: ... vs lambda _: {}
510511
needs_binding = False
@@ -519,7 +520,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
519520
typeddict_encoder.append(f"{repr(name)}:")
520521
type_name, _, contents, _ = encode_type(
521522
prop,
522-
TypeName(prefix + name.title()),
523+
TypeName(prefix.value + name.title()),
523524
base_model,
524525
in_module,
525526
permit_unknown_members=permit_unknown_members,
@@ -531,17 +532,17 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
531532
typeddict_encoder.append("'not implemented'")
532533
elif isinstance(prop, RiverUnionType):
533534
encoder_name = TypeName(
534-
f"encode_{ensure_literal_type(type_name)}"
535+
f"encode_{render_literal_type(type_name)}"
535536
)
536537
encoder_names.add(encoder_name)
537-
typeddict_encoder.append(f"{encoder_name}(x[{repr(name)}])")
538+
typeddict_encoder.append(f"{render_literal_type(encoder_name)}(x[{repr(name)}])")
538539
if name not in type.required:
539540
typeddict_encoder.append(
540541
f"if {repr(name)} in x and x[{repr(name)}] else None"
541542
)
542543
elif isinstance(prop, RiverIntersectionType):
543544
encoder_name = TypeName(
544-
f"encode_{ensure_literal_type(type_name)}"
545+
f"encode_{render_literal_type(type_name)}"
545546
)
546547
encoder_names.add(encoder_name)
547548
typeddict_encoder.append(f"{encoder_name}(x[{repr(name)}])")
@@ -552,11 +553,11 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
552553
safe_name = name
553554
if prop.type == "object" and not prop.patternProperties:
554555
encoder_name = TypeName(
555-
f"encode_{ensure_literal_type(type_name)}"
556+
f"encode_{render_literal_type(type_name)}"
556557
)
557558
encoder_names.add(encoder_name)
558559
typeddict_encoder.append(
559-
f"{encoder_name}(x[{repr(safe_name)}])"
560+
f"{render_literal_type(encoder_name)}(x[{repr(safe_name)}])"
560561
)
561562
if name not in prop.required:
562563
typeddict_encoder.append(
@@ -582,14 +583,14 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
582583
match type_name:
583584
case ListTypeExpr(inner_type_name):
584585
encoder_name = TypeName(
585-
f"encode_{ensure_literal_type(inner_type_name)}"
586+
f"encode_{render_literal_type(inner_type_name)}"
586587
)
587588
encoder_names.add(encoder_name)
588589
typeddict_encoder.append(
589590
dedent(
590591
f"""\
591592
[
592-
{encoder_name}(y)
593+
{render_literal_type(encoder_name)}(y)
593594
for y in x[{repr(name)}]
594595
]
595596
""".rstrip()
@@ -679,7 +680,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
679680

680681
if base_model == "TypedDict":
681682
binding = "x" if needs_binding else "_"
682-
encoder_name = TypeName(f"encode_{prefix}")
683+
encoder_name = TypeName(f"encode_{render_literal_type(prefix)}")
683684
encoder_names.add(encoder_name)
684685
current_chunks.insert(
685686
0,
@@ -688,7 +689,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
688689
[
689690
dedent(
690691
f"""\
691-
{encoder_name}: Callable[[{repr(prefix)}], Any] = (
692+
{render_literal_type(encoder_name)}: Callable[[{repr(render_literal_type(prefix))}], Any] = (
692693
lambda {binding}:
693694
"""
694695
)
@@ -847,7 +848,7 @@ def __init__(self, client: river.Client[Any]):
847848
f"lambda xs: [encode_{init_type_name}(x) for x in xs]"
848849
)
849850
else:
850-
render_init_method = f"encode_{ensure_literal_type(init_type)}"
851+
render_init_method = f"encode_{render_literal_type(init_type)}"
851852
else:
852853
render_init_method = f"""\
853854
lambda x: TypeAdapter({render_type_expr(init_type)})
@@ -870,11 +871,11 @@ def __init__(self, client: river.Client[Any]):
870871
case ListTypeExpr(input_type_name):
871872
render_input_method = f"""\
872873
lambda xs: [
873-
encode_{ensure_literal_type(input_type_name)}(x) for x in xs
874+
encode_{render_literal_type(input_type_name)}(x) for x in xs
874875
]
875876
"""
876877
else:
877-
render_input_method = f"encode_{ensure_literal_type(input_type)}"
878+
render_input_method = f"encode_{render_literal_type(input_type)}"
878879
else:
879880
render_input_method = f"""\
880881
lambda x: TypeAdapter({render_type_expr(input_type)})
@@ -1070,7 +1071,7 @@ async def {name}(
10701071
emitted_files[file_path] = FileContents("\n".join([existing] + contents))
10711072

10721073
rendered_imports = [
1073-
f"from .{dotted_modules} import {', '.join(sorted(names))}"
1074+
f"from .{dotted_modules} import {', '.join(sorted(render_literal_type(x) for x in names))}"
10741075
for dotted_modules, names in imports.items()
10751076
]
10761077

0 commit comments

Comments
 (0)