Skip to content

Commit 002d977

Browse files
Fixing two incorrect type comparisons
1 parent c3f04e4 commit 002d977

2 files changed

Lines changed: 29 additions & 34 deletions

File tree

src/replit_river/codegen/client.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
ListTypeExpr,
3131
LiteralTypeExpr,
3232
ModuleName,
33+
NoneTypeExpr,
3334
OpenUnionTypeExpr,
3435
RenderedPath,
3536
TypeExpression,
@@ -170,7 +171,7 @@ def encode_type(
170171
encoder_name: TypeName | None = None # defining this up here to placate mypy
171172
chunks: List[FileContents] = []
172173
if isinstance(type, RiverNotType):
173-
return (TypeName("None"), [], [], set())
174+
return (NoneTypeExpr(), [], [], set())
174175
elif isinstance(type, RiverUnionType):
175176
typeddict_encoder = list[str]()
176177
encoder_names: set[TypeName] = set()
@@ -379,17 +380,15 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
379380
typeddict_encoder.append(
380381
f"encode_{render_literal_type(inner_type_name)}(x)"
381382
)
382-
case DictTypeExpr(_):
383-
raise ValueError(
384-
"What does it mean to try and encode a dict in"
385-
" this position?"
386-
)
387383
case LiteralTypeExpr(const):
388384
typeddict_encoder.append(repr(const))
385+
case TypeName(value):
386+
typeddict_encoder.append(f"encode_{value}(x)")
387+
case NoneTypeExpr():
388+
typeddict_encoder.append("None")
389389
case other:
390-
typeddict_encoder.append(
391-
f"encode_{render_literal_type(other)}(x)"
392-
)
390+
_o2: DictTypeExpr | OpenUnionTypeExpr | UnionTypeExpr = other
391+
raise ValueError(f"What does it mean to have {_o2} here?")
393392
if permit_unknown_members:
394393
union = OpenUnionTypeExpr(UnionTypeExpr(any_of))
395394
else:
@@ -471,7 +470,7 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
471470
return (TypeName("bool"), [], [], set())
472471
elif type.type == "null" or type.type == "undefined":
473472
typeddict_encoder.append("None")
474-
return (TypeName("None"), [], [], set())
473+
return (NoneTypeExpr(), [], [], set())
475474
elif type.type == "Date":
476475
typeddict_encoder.append("TODO: dstewart")
477476
return (TypeName("datetime.datetime"), [], [], set())
@@ -511,8 +510,11 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
511510
)
512511
case LiteralTypeExpr(const):
513512
typeddict_encoder.append(repr(const))
513+
case TypeName(value):
514+
typeddict_encoder.append(f"encode_{value}(x)")
514515
case other:
515-
typeddict_encoder.append(f"encode_{render_literal_type(other)}(x)")
516+
_o1: NoneTypeExpr | OpenUnionTypeExpr | UnionTypeExpr = other
517+
raise ValueError(f"What does it mean to have {_o1} here?")
516518
return (DictTypeExpr(type_name), module_info, type_chunks, encoder_names)
517519
assert type.type == "object", type.type
518520

@@ -823,7 +825,7 @@ def __init__(self, client: river.Client[Any]):
823825
module_names,
824826
permit_unknown_members=True,
825827
)
826-
if error_type == "None":
828+
if isinstance(error_type, NoneTypeExpr):
827829
error_type = TypeName("RiverError")
828830
else:
829831
serdes.append(
@@ -916,7 +918,7 @@ def __init__(self, client: river.Client[Any]):
916918
f"Unable to derive the input encoder from: {input_type}"
917919
)
918920

919-
if output_type == "None":
921+
if isinstance(output_type, NoneTypeExpr):
920922
parse_output_method = "lambda x: None"
921923

922924
if procedure.type == "rpc":

src/replit_river/codegen/typing.py

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,12 @@ def __str__(self) -> str:
1717
raise Exception("Complex type must be put through render_type_expr!")
1818

1919

20+
@dataclass(frozen=True)
21+
class NoneTypeExpr:
22+
def __str__(self) -> str:
23+
raise Exception("Complex type must be put through render_type_expr!")
24+
25+
2026
@dataclass(frozen=True)
2127
class DictTypeExpr:
2228
nested: "TypeExpression"
@@ -59,6 +65,7 @@ def __str__(self) -> str:
5965

6066
TypeExpression = (
6167
TypeName
68+
| NoneTypeExpr
6269
| DictTypeExpr
6370
| ListTypeExpr
6471
| LiteralTypeExpr
@@ -86,6 +93,8 @@ def render_type_expr(value: TypeExpression) -> str:
8693
)
8794
case TypeName(name):
8895
return name
96+
case NoneTypeExpr():
97+
return "None"
8998
case other:
9099
assert_never(other)
91100

@@ -112,33 +121,17 @@ def extract_inner_type(value: TypeExpression) -> TypeName:
112121
)
113122
case TypeName(name):
114123
return TypeName(name)
124+
case NoneTypeExpr():
125+
raise ValueError(f"Attempting to extract from a literal 'None': {value}")
115126
case other:
116127
assert_never(other)
117128

118129

119130
def ensure_literal_type(value: TypeExpression) -> TypeName:
120131
match value:
121-
case DictTypeExpr(_):
122-
raise ValueError(
123-
f"Unexpected expression when expecting a type name: {value}"
124-
)
125-
case ListTypeExpr(_):
126-
raise ValueError(
127-
f"Unexpected expression when expecting a type name: {value}"
128-
)
129-
case LiteralTypeExpr(_):
130-
raise ValueError(
131-
f"Unexpected expression when expecting a type name: {value}"
132-
)
133-
case UnionTypeExpr(_):
134-
raise ValueError(
135-
f"Unexpected expression when expecting a type name: {value}"
136-
)
137-
case OpenUnionTypeExpr(_):
138-
raise ValueError(
139-
f"Unexpected expression when expecting a type name: {value}"
140-
)
141132
case TypeName(name):
142133
return TypeName(name)
143134
case other:
144-
assert_never(other)
135+
raise ValueError(
136+
f"Unexpected expression when expecting a type name: {other}"
137+
)

0 commit comments

Comments
 (0)