3838 ensure_literal_type ,
3939 extract_inner_type ,
4040 render_type_expr ,
41+ render_literal_type ,
4142)
4243
4344_NON_ALNUM_RE = re .compile (r"[^a-zA-Z0-9_]+" )
@@ -167,7 +168,7 @@ def encode_type(
167168 in_module : list [ModuleName ],
168169 permit_unknown_members : bool ,
169170) -> Tuple [TypeExpression , list [ModuleName ], list [FileContents ], set [TypeName ]]:
170- encoder_name : Optional [ str ] = None # defining this up here to placate mypy
171+ encoder_name : TypeName | None = None # defining this up here to placate mypy
171172 chunks : List [FileContents ] = []
172173 if isinstance (type , RiverNotType ):
173174 return (TypeName ("None" ), [], [], set ())
@@ -234,7 +235,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
234235 and prop .const is not None
235236 ].pop ()
236237 one_of_pending .setdefault (
237- f"{ prefix } OneOf_{ discriminator_value } " ,
238+ f"{ render_literal_type ( prefix ) } OneOf_{ discriminator_value } " ,
238239 (discriminator_value , []),
239240 )[1 ].append (oneof_t )
240241
@@ -270,12 +271,12 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
270271 oneof_t .properties .keys ()
271272 ).difference (common_members )
272273 encoder_name = TypeName (
273- f"encode_{ ensure_literal_type (type_name )} "
274+ f"encode_{ render_literal_type (type_name )} "
274275 )
275276 encoder_names .add (encoder_name )
276277 typeddict_encoder .append (
277278 f"""\
278- { encoder_name } (x) # type: ignore[arg-type]
279+ { render_literal_type ( encoder_name ) } (x) # type: ignore[arg-type]
279280 """ .strip ()
280281 )
281282 if local_discriminators :
@@ -299,12 +300,12 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
299300 one_of .append (type_name )
300301 chunks .extend (contents )
301302 encoder_name = TypeName (
302- f"encode_{ ensure_literal_type (type_name )} "
303+ f"encode_{ render_literal_type (type_name )} "
303304 )
304305 # TODO(dstewart): Figure out why uncommenting this breaks
305306 # generated code
306307 # encoder_names.add(encoder_name)
307- typeddict_encoder .append (f"{ encoder_name } (x)" )
308+ typeddict_encoder .append (f"{ render_literal_type ( encoder_name ) } (x)" )
308309 typeddict_encoder .append (
309310 f"""
310311 if x[{ repr (discriminator_name )} ]
@@ -317,19 +318,19 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
317318 union = OpenUnionTypeExpr (UnionTypeExpr (one_of ))
318319 else :
319320 union = UnionTypeExpr (one_of )
320- chunks .append (FileContents (f"{ prefix } = { render_type_expr (union )} " ))
321+ chunks .append (FileContents (f"{ render_literal_type ( prefix ) } = { render_type_expr (union )} " ))
321322 chunks .append (FileContents ("" ))
322323
323324 if base_model == "TypedDict" :
324- encoder_name = TypeName (f"encode_{ prefix } " )
325+ encoder_name = TypeName (f"encode_{ render_literal_type ( prefix ) } " )
325326 encoder_names .add (encoder_name )
326327 chunks .append (
327328 FileContents (
328329 "\n " .join (
329330 [
330331 dedent (
331332 f"""\
332- { encoder_name } : Callable[[{ repr (prefix )} ], Any] = (
333+ { render_literal_type ( encoder_name ) } : Callable[[{ repr (render_literal_type ( prefix ) )} ], Any] = (
333334 lambda x:
334335 """ .rstrip ()
335336 )
@@ -349,7 +350,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
349350 for i , t in enumerate (type .anyOf ):
350351 type_name , _ , contents , _ = encode_type (
351352 t ,
352- TypeName (f"{ prefix } AnyOf_{ i } " ),
353+ TypeName (f"{ render_literal_type ( prefix ) } AnyOf_{ i } " ),
353354 base_model ,
354355 in_module ,
355356 permit_unknown_members = permit_unknown_members ,
@@ -366,7 +367,7 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
366367 match type_name :
367368 case ListTypeExpr (inner_type_name ):
368369 typeddict_encoder .append (
369- f"encode_{ ensure_literal_type (inner_type_name )} (x)"
370+ f"encode_{ render_literal_type (inner_type_name )} (x)"
370371 )
371372 case DictTypeExpr (_):
372373 raise ValueError (
@@ -377,23 +378,23 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
377378 typeddict_encoder .append (repr (const ))
378379 case other :
379380 typeddict_encoder .append (
380- f"encode_{ ensure_literal_type (other )} (x)"
381+ f"encode_{ render_literal_type (other )} (x)"
381382 )
382383 if permit_unknown_members :
383384 union = OpenUnionTypeExpr (UnionTypeExpr (any_of ))
384385 else :
385386 union = UnionTypeExpr (any_of )
386387 if is_literal (type ):
387388 typeddict_encoder = ["x" ]
388- chunks .append (FileContents (f"{ prefix } = { render_type_expr (union )} " ))
389+ chunks .append (FileContents (f"{ render_literal_type ( prefix ) } = { render_type_expr (union )} " ))
389390 if base_model == "TypedDict" :
390- encoder_name = TypeName (f"encode_{ prefix } " )
391+ encoder_name = TypeName (f"encode_{ render_literal_type ( prefix ) } " )
391392 encoder_names .add (encoder_name )
392393 chunks .append (
393394 FileContents (
394395 "\n " .join (
395396 [
396- f"{ encoder_name } : Callable[[{ repr (prefix )} ], Any] = ("
397+ f"{ render_literal_type ( encoder_name ) } : Callable[[{ repr (render_literal_type ( prefix ) )} ], Any] = ("
397398 "lambda x: "
398399 ]
399400 + typeddict_encoder
@@ -491,7 +492,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
491492 match type_name :
492493 case ListTypeExpr (inner_type_name ):
493494 typeddict_encoder .append (
494- f"encode_{ ensure_literal_type (inner_type_name )} (x)"
495+ f"encode_{ render_literal_type (inner_type_name )} (x)"
495496 )
496497 case DictTypeExpr (_):
497498 raise ValueError (
@@ -500,11 +501,11 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
500501 case LiteralTypeExpr (const ):
501502 typeddict_encoder .append (repr (const ))
502503 case other :
503- typeddict_encoder .append (f"encode_{ ensure_literal_type (other )} (x)" )
504+ typeddict_encoder .append (f"encode_{ render_literal_type (other )} (x)" )
504505 return (DictTypeExpr (type_name ), module_info , type_chunks , encoder_names )
505506 assert type .type == "object" , type .type
506507
507- current_chunks : List [str ] = [f"class { prefix } ({ base_model } ):" ]
508+ current_chunks : List [str ] = [f"class { render_literal_type ( prefix ) } ({ base_model } ):" ]
508509 # For the encoder path, do we need "x" to be bound?
509510 # lambda x: ... vs lambda _: {}
510511 needs_binding = False
@@ -519,7 +520,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
519520 typeddict_encoder .append (f"{ repr (name )} :" )
520521 type_name , _ , contents , _ = encode_type (
521522 prop ,
522- TypeName (prefix + name .title ()),
523+ TypeName (prefix . value + name .title ()),
523524 base_model ,
524525 in_module ,
525526 permit_unknown_members = permit_unknown_members ,
@@ -531,17 +532,17 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
531532 typeddict_encoder .append ("'not implemented'" )
532533 elif isinstance (prop , RiverUnionType ):
533534 encoder_name = TypeName (
534- f"encode_{ ensure_literal_type (type_name )} "
535+ f"encode_{ render_literal_type (type_name )} "
535536 )
536537 encoder_names .add (encoder_name )
537- typeddict_encoder .append (f"{ encoder_name } (x[{ repr (name )} ])" )
538+ typeddict_encoder .append (f"{ render_literal_type ( encoder_name ) } (x[{ repr (name )} ])" )
538539 if name not in type .required :
539540 typeddict_encoder .append (
540541 f"if { repr (name )} in x and x[{ repr (name )} ] else None"
541542 )
542543 elif isinstance (prop , RiverIntersectionType ):
543544 encoder_name = TypeName (
544- f"encode_{ ensure_literal_type (type_name )} "
545+ f"encode_{ render_literal_type (type_name )} "
545546 )
546547 encoder_names .add (encoder_name )
547548 typeddict_encoder .append (f"{ encoder_name } (x[{ repr (name )} ])" )
@@ -552,11 +553,11 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
552553 safe_name = name
553554 if prop .type == "object" and not prop .patternProperties :
554555 encoder_name = TypeName (
555- f"encode_{ ensure_literal_type (type_name )} "
556+ f"encode_{ render_literal_type (type_name )} "
556557 )
557558 encoder_names .add (encoder_name )
558559 typeddict_encoder .append (
559- f"{ encoder_name } (x[{ repr (safe_name )} ])"
560+ f"{ render_literal_type ( encoder_name ) } (x[{ repr (safe_name )} ])"
560561 )
561562 if name not in prop .required :
562563 typeddict_encoder .append (
@@ -582,14 +583,14 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
582583 match type_name :
583584 case ListTypeExpr (inner_type_name ):
584585 encoder_name = TypeName (
585- f"encode_{ ensure_literal_type (inner_type_name )} "
586+ f"encode_{ render_literal_type (inner_type_name )} "
586587 )
587588 encoder_names .add (encoder_name )
588589 typeddict_encoder .append (
589590 dedent (
590591 f"""\
591592 [
592- { encoder_name } (y)
593+ { render_literal_type ( encoder_name ) } (y)
593594 for y in x[{ repr (name )} ]
594595 ]
595596 """ .rstrip ()
@@ -679,7 +680,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
679680
680681 if base_model == "TypedDict" :
681682 binding = "x" if needs_binding else "_"
682- encoder_name = TypeName (f"encode_{ prefix } " )
683+ encoder_name = TypeName (f"encode_{ render_literal_type ( prefix ) } " )
683684 encoder_names .add (encoder_name )
684685 current_chunks .insert (
685686 0 ,
@@ -688,7 +689,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
688689 [
689690 dedent (
690691 f"""\
691- { encoder_name } : Callable[[{ repr (prefix )} ], Any] = (
692+ { render_literal_type ( encoder_name ) } : Callable[[{ repr (render_literal_type ( prefix ) )} ], Any] = (
692693 lambda { binding } :
693694 """
694695 )
@@ -847,7 +848,7 @@ def __init__(self, client: river.Client[Any]):
847848 f"lambda xs: [encode_{ init_type_name } (x) for x in xs]"
848849 )
849850 else :
850- render_init_method = f"encode_{ ensure_literal_type (init_type )} "
851+ render_init_method = f"encode_{ render_literal_type (init_type )} "
851852 else :
852853 render_init_method = f"""\
853854 lambda x: TypeAdapter({ render_type_expr (init_type )} )
@@ -870,11 +871,11 @@ def __init__(self, client: river.Client[Any]):
870871 case ListTypeExpr (input_type_name ):
871872 render_input_method = f"""\
872873 lambda xs: [
873- encode_{ ensure_literal_type (input_type_name )} (x) for x in xs
874+ encode_{ render_literal_type (input_type_name )} (x) for x in xs
874875 ]
875876 """
876877 else :
877- render_input_method = f"encode_{ ensure_literal_type (input_type )} "
878+ render_input_method = f"encode_{ render_literal_type (input_type )} "
878879 else :
879880 render_input_method = f"""\
880881 lambda x: TypeAdapter({ render_type_expr (input_type )} )
@@ -1070,7 +1071,7 @@ async def {name}(
10701071 emitted_files [file_path ] = FileContents ("\n " .join ([existing ] + contents ))
10711072
10721073 rendered_imports = [
1073- f"from .{ dotted_modules } import { ', ' .join (sorted (names ))} "
1074+ f"from .{ dotted_modules } import { ', ' .join (sorted (render_literal_type ( x ) for x in names ))} "
10741075 for dotted_modules , names in imports .items ()
10751076 ]
10761077
0 commit comments