diff --git a/.changepacks/changepack_log_h04jC8dn9bMOhEM5LmVP4.json b/.changepacks/changepack_log_h04jC8dn9bMOhEM5LmVP4.json new file mode 100644 index 0000000..afdc076 --- /dev/null +++ b/.changepacks/changepack_log_h04jC8dn9bMOhEM5LmVP4.json @@ -0,0 +1 @@ +{"changes":{"Cargo.toml":"Patch"},"note":"Implement Multiform","date":"2026-03-06T16:01:53.731789300Z"} \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 2cff1dc..6acfb7b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -341,7 +341,6 @@ version = "0.1.0" dependencies = [ "axum", "axum-test", - "axum_typed_multipart", "insta", "sea-orm", "serde", @@ -411,41 +410,6 @@ dependencies = [ "url", ] -[[package]] -name = "axum_typed_multipart" -version = "0.16.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74c8b2ee396b35396ec27f5b9aa101f77000ba842dc82549a381b74c3ae2db7e" -dependencies = [ - "anyhow", - "async-trait", - "axum", - "axum_typed_multipart_macros", - "bytes", - "chrono", - "futures-core", - "futures-util", - "rust_decimal", - "tempfile", - "thiserror", - "tokio", - "uuid", -] - -[[package]] -name = "axum_typed_multipart_macros" -version = "0.16.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a27cefbd055910a29c4a3710016559cece5bdb4fb78ec055a1c2e9f8c61e3aa9" -dependencies = [ - "darling 0.23.0", - "heck 0.5.0", - "proc-macro-error2", - "quote", - "syn 2.0.117", - "ubyte", -] - [[package]] name = "base64" version = "0.22.1" @@ -752,18 +716,8 @@ version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" dependencies = [ - "darling_core 0.20.11", - "darling_macro 0.20.11", -] - -[[package]] -name = "darling" -version = "0.23.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25ae13da2f202d56bd7f91c25fba009e7717a1e4a1cc98a76d844b65ae912e9d" -dependencies = [ - "darling_core 0.23.0", - "darling_macro 0.23.0", + "darling_core", + "darling_macro", ] [[package]] @@ -779,37 +733,13 @@ dependencies = [ "syn 2.0.117", ] -[[package]] -name = "darling_core" -version = "0.23.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9865a50f7c335f53564bb694ef660825eb8610e0a53d3e11bf1b0d3df31e03b0" -dependencies = [ - "ident_case", - "proc-macro2", - "quote", - "strsim", - "syn 2.0.117", -] - [[package]] name = "darling_macro" version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" dependencies = [ - "darling_core 0.20.11", - "quote", - "syn 2.0.117", -] - -[[package]] -name = "darling_macro" -version = "0.23.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac3984ec7bd6cfa798e62b4a642426a5be0e68f9401cfc2a01e3fa9ea2fcdb8d" -dependencies = [ - "darling_core 0.23.0", + "darling_core", "quote", "syn 2.0.117", ] @@ -2600,7 +2530,7 @@ version = "1.0.0-rc.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8d88ad44b6ad9788c8b9476b6b91f94c7461d1e19d39cd8ea37838b1e6ff5aa8" dependencies = [ - "darling 0.20.11", + "darling", "heck 0.4.1", "proc-macro2", "quote", @@ -3092,12 +3022,6 @@ dependencies = [ "unicode-properties", ] -[[package]] -name = "strsim" -version = "0.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" - [[package]] name = "strum" version = "0.27.2" @@ -3450,12 +3374,6 @@ dependencies = [ "syn 2.0.117", ] -[[package]] -name = "ubyte" -version = "0.10.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f720def6ce1ee2fc44d40ac9ed6d3a59c361c80a75a7aa8e75bb9baed31cf2ea" - [[package]] name = "unicode-bidi" version = "0.3.18" @@ -3539,11 +3457,10 @@ checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" [[package]] name = "vespera" -version = "0.1.40" +version = "0.1.41" dependencies = [ "axum", "axum-extra", - "axum_typed_multipart", "chrono", "serde_json", "tempfile", @@ -3555,7 +3472,7 @@ dependencies = [ [[package]] name = "vespera_core" -version = "0.1.40" +version = "0.1.41" dependencies = [ "rstest", "serde", @@ -3564,7 +3481,7 @@ dependencies = [ [[package]] name = "vespera_macro" -version = "0.1.40" +version = "0.1.41" dependencies = [ "insta", "proc-macro2", diff --git a/README.md b/README.md index dc8ca95..ccf0c8d 100644 --- a/README.md +++ b/README.md @@ -168,13 +168,14 @@ pub struct CreateUserRequest { #### Typed Multipart (Recommended) -Upload files using `TypedMultipart` from [`axum_typed_multipart`](https://crates.io/crates/axum_typed_multipart): +Upload files using vespera's built-in `TypedMultipart` extractor: ```rust -use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart}; +use vespera::multipart::{FieldData, TypedMultipart}; +use vespera::{Multipart, Schema}; use tempfile::NamedTempFile; -#[derive(TryFromMultipart, vespera::Schema)] +#[derive(Multipart, Schema)] pub struct CreateUploadRequest { pub name: String, #[form_data(limit = "10MiB")] @@ -189,8 +190,6 @@ pub async fn create_upload( Vespera automatically generates `multipart/form-data` content type in OpenAPI, and maps `FieldData` to `{ "type": "string", "format": "binary" }`. -> **Note:** `axum` must be a direct dependency of your project (not just via vespera) because `TryFromMultipart` internally references `axum::extract::multipart::Multipart`. - #### Raw Multipart (Untyped) For dynamic multipart handling where the fields aren't known at compile time, use axum's built-in `Multipart` extractor: @@ -444,10 +443,10 @@ vespera::schema_type!(Schema from Model, name = "MemoSchema"); ### Multipart Mode -Generate `TryFromMultipart` structs from existing types using the `multipart` keyword: +Generate `Multipart` structs from existing types using the `multipart` keyword: ```rust -#[derive(TryFromMultipart, vespera::Schema)] +#[derive(vespera::Multipart, vespera::Schema)] pub struct CreateUploadRequest { pub name: String, #[form_data(limit = "10MiB")] @@ -455,12 +454,12 @@ pub struct CreateUploadRequest { pub description: Option, } -// Generates a TryFromMultipart struct (no serde derives), all fields Optional +// Generates a Multipart struct (no serde derives), all fields Optional schema_type!(PatchUploadRequest from CreateUploadRequest, multipart, partial, omit = ["file"]); ``` When `multipart` is enabled: -- Derives `TryFromMultipart` instead of `Serialize`/`Deserialize` +- Derives `Multipart` instead of `Serialize`/`Deserialize` - Suppresses `#[serde(...)]` attributes (multipart parsing is not serde-based) - Preserves `#[form_data(...)]` attributes from source struct - Skips SeaORM relation fields (nested objects can't be represented in multipart forms) @@ -479,7 +478,7 @@ When `multipart` is enabled: | `name` | Custom OpenAPI schema name: `name = "UserSchema"` | | `rename_all` | Serde rename strategy: `rename_all = "camelCase"` | | `ignore` | Skip Schema derive (bare keyword, no value) | -| `multipart` | Derive `TryFromMultipart` instead of serde (bare keyword) | +| `multipart` | Derive `Multipart` instead of serde (bare keyword) | | `omit_default` | Auto-omit fields with DB defaults: `primary_key`, `default_value` (bare keyword) | --- diff --git a/SKILL.md b/SKILL.md index 51c6c24..31bc95b 100644 --- a/SKILL.md +++ b/SKILL.md @@ -345,7 +345,7 @@ Json(model.into()) // Easy conversion! | `rename` | Rename fields | API naming differs from model | | `rename_all` | Serde rename strategy | Different casing needed | | `add` | Add new fields | New fields not in model (breaks `From` impl) | -| `multipart` | Derive `TryFromMultipart` | Multipart form-data endpoints | +| `multipart` | Derive `Multipart` | Multipart form-data endpoints | **Avoid (Special Cases Only):** @@ -446,14 +446,15 @@ pub async fn patch_user( ### Multipart Mode (`multipart`) -Generate `TryFromMultipart` structs from existing multipart request types: +Generate `Multipart` structs from existing multipart request types: ```rust -use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart}; +use vespera::multipart::{FieldData, TypedMultipart}; +use vespera::{Multipart, Schema}; use tempfile::NamedTempFile; // Base multipart struct (manually defined) -#[derive(TryFromMultipart, vespera::Schema)] +#[derive(Multipart, Schema)] pub struct CreateUploadRequest { pub name: String, #[form_data(limit = "10MiB")] @@ -464,7 +465,7 @@ pub struct CreateUploadRequest { } // Derive a partial update struct via schema_type! -// - Derives TryFromMultipart (not serde) +// - Derives Multipart (not serde) // - All fields become Option (partial) // - "document" field excluded // - #[form_data(limit = "10MiB")] preserved from source @@ -475,18 +476,17 @@ schema_type!(PatchUploadRequest from CreateUploadRequest, multipart, partial, om | Aspect | Normal Mode | Multipart Mode | |--------|------------|----------------| -| Derives | `Serialize`, `Deserialize` | `TryFromMultipart` | +| Derives | `Serialize`, `Deserialize` | `Multipart` | | Struct attrs | `#[serde(rename_all=...)]` | None | | Field attrs | `#[serde(...)]` preserved | `#[form_data(...)]` preserved | | Relation fields | Included (BelongsTo/HasOne) | **Skipped** (can't represent in forms) | | `From` impl | Auto-generated | **Not generated** | -**OpenAPI rename alignment:** The schema parser reads `#[form_data(field_name = "...")]` and `#[try_from_multipart(rename_all = "...")]` as fallbacks when serde attrs are absent, ensuring OpenAPI field names match runtime multipart parsing. +**OpenAPI rename alignment:** The schema parser reads `#[form_data(field_name = "...")]` and `#[serde(rename_all = "...")]` for multipart structs, ensuring OpenAPI field names match runtime multipart parsing. **Dependencies required in your Cargo.toml:** ```toml -axum = "0.8" # Required: TryFromMultipart references axum internals -axum_typed_multipart = "0.16" # The multipart crate +vespera = "0.1" # Includes multipart support natively tempfile = "3" # For NamedTempFile file uploads ``` diff --git a/crates/vespera/Cargo.toml b/crates/vespera/Cargo.toml index 42cdbda..eff4ef7 100644 --- a/crates/vespera/Cargo.toml +++ b/crates/vespera/Cargo.toml @@ -12,10 +12,9 @@ default = ["axum-extra/typed-header", "axum-extra/form", "axum-extra/query", "ax [dependencies] vespera_core = { workspace = true } vespera_macro = { workspace = true } -axum = "0.8" +axum = { version = "0.8", features = ["multipart"] } axum-extra = { version = "0.12" } chrono = { version = "0.4", features = ["serde"] } -axum_typed_multipart = "0.16" tempfile = "3" serde_json = "1" tower-layer = "0.3" diff --git a/crates/vespera/src/lib.rs b/crates/vespera/src/lib.rs index 0b46068..5676061 100644 --- a/crates/vespera/src/lib.rs +++ b/crates/vespera/src/lib.rs @@ -20,7 +20,7 @@ pub mod openapi { pub use vespera_core::openapi::OpenApi; // Re-export macros from vespera_macro -pub use vespera_macro::{Schema, export_app, route, schema, schema_type, vespera}; +pub use vespera_macro::{Multipart, Schema, export_app, route, schema, schema_type, vespera}; // Re-export serde_json for merge feature (runtime spec merging) pub use serde_json; @@ -29,9 +29,8 @@ pub use serde_json; // This allows generated types to use chrono::DateTime without users adding chrono dependency pub use chrono; -// Re-export axum_typed_multipart for schema_type! multipart mode -// This allows generated types to use FieldData/TryFromMultipart without users adding the dependency -pub use axum_typed_multipart; +// Native multipart form data extraction (replaces axum_typed_multipart) +pub mod multipart; // Re-export tempfile for schema_type! multipart mode (NamedTempFile) pub use tempfile; diff --git a/crates/vespera/src/multipart.rs b/crates/vespera/src/multipart.rs new file mode 100644 index 0000000..6f8f8ea --- /dev/null +++ b/crates/vespera/src/multipart.rs @@ -0,0 +1,711 @@ +//! Native multipart form data extraction for Vespera. +//! +//! Replaces the `axum_typed_multipart` crate with a zero-dependency (beyond axum) +//! implementation of typed multipart extraction. All types here are referenced by +//! the `#[derive(Multipart)]` macro's generated code. +//! +//! # Key types +//! +//! - [`TypedMultipart`] — Axum extractor that parses `multipart/form-data` into `T` +//! - [`TypedMultipartError`] — Error type for multipart parsing failures +//! - [`FieldData`] — Wrapper providing file metadata alongside field contents +//! - [`FieldMetadata`] — Metadata extracted from a multipart field +//! - [`TryFromMultipartWithState`] — Trait for parsing a full multipart request +//! - [`TryFromFieldWithState`] — Trait for parsing a single multipart field + +use std::fmt; + +use axum::extract::multipart::{Field, MultipartError, MultipartRejection}; +use axum::extract::{FromRequest, Request}; +use axum::http::StatusCode; +use axum::response::{IntoResponse, Response}; + +// ═══════════════════════════════════════════════════════════════════════════════ +// Error type +// ═══════════════════════════════════════════════════════════════════════════════ + +/// Errors that can occur when parsing multipart form data. +#[derive(Debug)] +pub enum TypedMultipartError { + /// The request could not be parsed as multipart (e.g., missing Content-Type). + InvalidRequest { + /// The underlying rejection from axum's Multipart extractor. + source: MultipartRejection, + }, + /// An error occurred while reading the multipart body stream. + InvalidRequestBody { + /// The underlying multipart stream error. + source: MultipartError, + }, + /// A required field was not present in the multipart form. + MissingField { + /// Name of the missing field. + field_name: String, + }, + /// A field's value could not be parsed as the expected type. + WrongFieldType { + /// Name of the field. + field_name: String, + /// The expected type name. + wanted: String, + /// Description of the parse error. + source: String, + }, + /// A non-repeatable field appeared more than once (strict mode). + DuplicateField { + /// Name of the duplicate field. + field_name: String, + }, + /// An unrecognized field was found (strict mode only). + UnknownField { + /// Name of the unknown field. + field_name: String, + }, + /// A field's value is not a valid variant of the expected enum. + InvalidEnumValue { + /// Name of the field. + field_name: String, + /// The invalid value that was received. + value: String, + }, + /// A field without a name was encountered (strict mode only). + NamelessField, + /// A field exceeded its configured size limit. + FieldTooLarge { + /// Name of the field. + field_name: String, + /// The configured limit in bytes. + limit_bytes: usize, + }, + /// A catch-all for other errors during multipart processing. + Other { + /// Description of the error. + source: String, + }, +} + +impl fmt::Display for TypedMultipartError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::InvalidRequest { source } => { + write!(f, "Invalid multipart request: {source}") + } + Self::InvalidRequestBody { source } => { + write!(f, "Invalid multipart body: {source}") + } + Self::MissingField { field_name } => { + write!(f, "Missing field: `{field_name}`") + } + Self::WrongFieldType { + field_name, + wanted, + source, + } => { + write!( + f, + "Wrong type for field `{field_name}` (expected {wanted}): {source}" + ) + } + Self::DuplicateField { field_name } => { + write!(f, "Duplicate field: `{field_name}`") + } + Self::UnknownField { field_name } => { + write!(f, "Unknown field: `{field_name}`") + } + Self::InvalidEnumValue { field_name, value } => { + write!(f, "Invalid enum value `{value}` for field `{field_name}`") + } + Self::NamelessField => write!(f, "Encountered a field without a name"), + Self::FieldTooLarge { + field_name, + limit_bytes, + } => { + write!( + f, + "Field `{field_name}` exceeds size limit of {limit_bytes} bytes" + ) + } + Self::Other { source } => write!(f, "{source}"), + } + } +} + +impl std::error::Error for TypedMultipartError {} + +impl IntoResponse for TypedMultipartError { + fn into_response(self) -> Response { + let status = match &self { + Self::InvalidRequest { .. } + | Self::InvalidRequestBody { .. } + | Self::MissingField { .. } + | Self::DuplicateField { .. } + | Self::UnknownField { .. } + | Self::InvalidEnumValue { .. } + | Self::NamelessField => StatusCode::BAD_REQUEST, + Self::WrongFieldType { .. } => StatusCode::UNSUPPORTED_MEDIA_TYPE, + Self::FieldTooLarge { .. } => StatusCode::PAYLOAD_TOO_LARGE, + Self::Other { .. } => StatusCode::INTERNAL_SERVER_ERROR, + }; + (status, self.to_string()).into_response() + } +} + +impl From for TypedMultipartError { + fn from(source: MultipartError) -> Self { + Self::InvalidRequestBody { source } + } +} + +impl From for TypedMultipartError { + fn from(source: MultipartRejection) -> Self { + Self::InvalidRequest { source } + } +} + +// ═══════════════════════════════════════════════════════════════════════════════ +// Traits +// ═══════════════════════════════════════════════════════════════════════════════ + +/// Parse a full multipart request body into a struct. +/// +/// Typically generated by `#[derive(Multipart)]`. Each field in the struct +/// is matched against multipart field names and parsed via +/// [`TryFromFieldWithState`]. +pub trait TryFromMultipartWithState: Sized { + /// Parse the multipart stream into `Self`. + fn try_from_multipart_with_state( + multipart: &mut axum::extract::Multipart, + state: &S, + ) -> impl std::future::Future> + Send; +} + +/// Parse a single multipart field into a value. +/// +/// Built-in implementations exist for `String`, `bool`, all integer and float +/// types, `char`, `tempfile::NamedTempFile`, and `FieldData`. +pub trait TryFromFieldWithState: Sized { + /// Parse a single field into `Self`, optionally enforcing a byte-size limit. + fn try_from_field_with_state( + field: Field<'_>, + limit_bytes: Option, + state: &S, + ) -> impl std::future::Future> + Send; +} + +// ═══════════════════════════════════════════════════════════════════════════════ +// Field metadata +// ═══════════════════════════════════════════════════════════════════════════════ + +/// Metadata extracted from a multipart field part. +#[derive(Debug, Clone)] +pub struct FieldMetadata { + /// The field name (`name` attribute in the form). + pub name: Option, + /// The original filename (present for file uploads). + pub file_name: Option, + /// The MIME content type of the field. + pub content_type: Option, + /// All HTTP headers associated with this multipart part. + pub headers: axum::http::HeaderMap, +} + +impl From<&Field<'_>> for FieldMetadata { + fn from(field: &Field<'_>) -> Self { + Self { + name: field.name().map(String::from), + file_name: field.file_name().map(String::from), + content_type: field.content_type().map(String::from), + headers: field.headers().clone(), + } + } +} + +// ═══════════════════════════════════════════════════════════════════════════════ +// FieldData +// ═══════════════════════════════════════════════════════════════════════════════ + +/// A multipart field's parsed contents along with its metadata. +/// +/// Use this wrapper when you need access to the file name, content type, +/// or other headers alongside the parsed value. +/// +/// ```rust,ignore +/// use vespera::multipart::FieldData; +/// use tempfile::NamedTempFile; +/// +/// #[derive(Multipart, Schema)] +/// pub struct Upload { +/// pub file: FieldData, +/// } +/// ``` +#[derive(Debug)] +pub struct FieldData { + /// Metadata about the field (name, filename, content-type, headers). + pub metadata: FieldMetadata, + /// The parsed contents of the field. + pub contents: T, +} + +impl TryFromFieldWithState for FieldData +where + T: TryFromFieldWithState + Send, + S: Send + Sync, +{ + async fn try_from_field_with_state( + field: Field<'_>, + limit_bytes: Option, + state: &S, + ) -> Result { + let metadata = FieldMetadata::from(&field); + let contents = T::try_from_field_with_state(field, limit_bytes, state).await?; + Ok(Self { metadata, contents }) + } +} + +// ═══════════════════════════════════════════════════════════════════════════════ +// TypedMultipart extractor +// ═══════════════════════════════════════════════════════════════════════════════ + +/// Axum extractor for typed multipart form data. +/// +/// Wraps a struct `T` that implements [`TryFromMultipartWithState`] (typically +/// via `#[derive(Multipart)]`). +/// +/// ```rust,ignore +/// use vespera::multipart::{TypedMultipart, FieldData}; +/// use tempfile::NamedTempFile; +/// +/// #[derive(Multipart, Schema)] +/// pub struct UploadRequest { +/// pub name: String, +/// pub file: FieldData, +/// } +/// +/// #[vespera::route(post)] +/// pub async fn upload( +/// TypedMultipart(req): TypedMultipart, +/// ) -> Json { +/// Json(req.name) +/// } +/// ``` +pub struct TypedMultipart(pub T); + +impl std::ops::Deref for TypedMultipart { + type Target = T; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl std::ops::DerefMut for TypedMultipart { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl FromRequest for TypedMultipart +where + T: TryFromMultipartWithState, + S: Send + Sync + 'static, +{ + type Rejection = TypedMultipartError; + + async fn from_request(req: Request, state: &S) -> Result { + let mut multipart = axum::extract::Multipart::from_request(req, state) + .await + .map_err(TypedMultipartError::from)?; + let value = T::try_from_multipart_with_state(&mut multipart, state).await?; + Ok(Self(value)) + } +} + +// ═══════════════════════════════════════════════════════════════════════════════ +// Built-in TryFromFieldWithState implementations +// ═══════════════════════════════════════════════════════════════════════════════ + +// ─── Helpers ──────────────────────────────────────────────────────────────── + +/// Read all bytes from a multipart field, enforcing an optional size limit. +/// +/// When a limit is set, bytes are read incrementally via `chunk()` and the +/// cumulative size is checked after each chunk. Without a limit, `bytes()` is +/// called for a single-allocation read. +async fn read_field_data( + mut field: Field<'_>, + limit: Option, +) -> Result<(String, Vec), TypedMultipartError> { + let field_name = field.name().unwrap_or_default().to_string(); + + let data = if let Some(limit) = limit { + let mut buf = Vec::new(); + while let Some(chunk) = field.chunk().await? { + buf.extend_from_slice(&chunk); + if buf.len() > limit { + return Err(TypedMultipartError::FieldTooLarge { + field_name, + limit_bytes: limit, + }); + } + } + buf + } else { + field.bytes().await?.to_vec() + }; + + Ok((field_name, data)) +} + +/// Parse a string as a boolean using clap-style conventions. +/// +/// Accepted truthy values: `true`, `yes`, `y`, `1`, `on` +/// Accepted falsy values: `false`, `no`, `n`, `0`, `off` +fn str_to_bool(s: &str) -> Option { + match s.to_ascii_lowercase().as_str() { + "true" | "yes" | "y" | "1" | "on" => Some(true), + "false" | "no" | "n" | "0" | "off" => Some(false), + _ => None, + } +} + +// ─── String ───────────────────────────────────────────────────────────────── + +impl TryFromFieldWithState for String { + async fn try_from_field_with_state( + field: Field<'_>, + limit_bytes: Option, + _state: &S, + ) -> Result { + let (field_name, data) = read_field_data(field, limit_bytes).await?; + Self::from_utf8(data).map_err(|e| TypedMultipartError::WrongFieldType { + field_name, + wanted: "String".to_string(), + source: e.to_string(), + }) + } +} + +// ─── bool ─────────────────────────────────────────────────────────────────── + +impl TryFromFieldWithState for bool { + async fn try_from_field_with_state( + field: Field<'_>, + limit_bytes: Option, + _state: &S, + ) -> Result { + let (field_name, data) = read_field_data(field, limit_bytes).await?; + let text = std::str::from_utf8(&data).map_err(|e| TypedMultipartError::WrongFieldType { + field_name: field_name.clone(), + wanted: "bool".to_string(), + source: e.to_string(), + })?; + str_to_bool(text).ok_or_else(|| TypedMultipartError::WrongFieldType { + field_name, + wanted: "bool".to_string(), + source: format!("invalid boolean value: `{text}`"), + }) + } +} + +// ─── Numeric types ────────────────────────────────────────────────────────── + +macro_rules! impl_try_from_field_for_number { + ($($ty:ty),* $(,)?) => { + $( + impl TryFromFieldWithState for $ty { + async fn try_from_field_with_state( + field: Field<'_>, + limit_bytes: Option, + _state: &S, + ) -> Result { + let (field_name, data) = read_field_data(field, limit_bytes).await?; + let text = std::str::from_utf8(&data).map_err(|e| { + TypedMultipartError::WrongFieldType { + field_name: field_name.clone(), + wanted: stringify!($ty).to_string(), + source: e.to_string(), + } + })?; + text.trim().parse::<$ty>().map_err(|e| { + TypedMultipartError::WrongFieldType { + field_name, + wanted: stringify!($ty).to_string(), + source: e.to_string(), + } + }) + } + } + )* + }; +} + +impl_try_from_field_for_number!( + i8, i16, i32, i64, i128, u8, u16, u32, u64, u128, isize, usize, f32, f64, +); + +// ─── char ─────────────────────────────────────────────────────────────────── + +impl TryFromFieldWithState for char { + async fn try_from_field_with_state( + field: Field<'_>, + limit_bytes: Option, + _state: &S, + ) -> Result { + let (field_name, data) = read_field_data(field, limit_bytes).await?; + let text = std::str::from_utf8(&data).map_err(|e| TypedMultipartError::WrongFieldType { + field_name: field_name.clone(), + wanted: "char".to_string(), + source: e.to_string(), + })?; + let mut chars = text.chars(); + match (chars.next(), chars.next()) { + (Some(c), None) => Ok(c), + _ => Err(TypedMultipartError::WrongFieldType { + field_name, + wanted: "char".to_string(), + source: "expected exactly one character".to_string(), + }), + } + } +} + +// ─── NamedTempFile ────────────────────────────────────────────────────────── + +impl TryFromFieldWithState for tempfile::NamedTempFile { + async fn try_from_field_with_state( + mut field: Field<'_>, + limit_bytes: Option, + _state: &S, + ) -> Result { + let field_name = field.name().unwrap_or_default().to_string(); + let mut temp = Self::new().map_err(|e| TypedMultipartError::Other { + source: e.to_string(), + })?; + + let mut total = 0usize; + while let Some(chunk) = field.chunk().await? { + total += chunk.len(); + if let Some(limit) = limit_bytes + && total > limit + { + return Err(TypedMultipartError::FieldTooLarge { + field_name, + limit_bytes: limit, + }); + } + std::io::Write::write_all(&mut temp, &chunk).map_err(|e| { + TypedMultipartError::Other { + source: e.to_string(), + } + })?; + } + + Ok(temp) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use axum::http::StatusCode; + use axum::response::IntoResponse; + + #[test] + fn test_str_to_bool_truthy() { + for val in &[ + "true", "True", "TRUE", "yes", "Yes", "y", "Y", "1", "on", "ON", + ] { + assert_eq!(str_to_bool(val), Some(true), "expected true for `{val}`"); + } + } + + #[test] + fn test_str_to_bool_falsy() { + for val in &[ + "false", "False", "FALSE", "no", "No", "n", "N", "0", "off", "OFF", + ] { + assert_eq!(str_to_bool(val), Some(false), "expected false for `{val}`"); + } + } + + #[test] + fn test_str_to_bool_invalid() { + for val in &["maybe", "2", "", "yep", "nah"] { + assert_eq!(str_to_bool(val), None, "expected None for `{val}`"); + } + } + + // ─── Display tests for all error variants ─────────────────────────── + + #[test] + fn test_error_display() { + let err = TypedMultipartError::MissingField { + field_name: "name".to_string(), + }; + assert_eq!(err.to_string(), "Missing field: `name`"); + + let err = TypedMultipartError::FieldTooLarge { + field_name: "file".to_string(), + limit_bytes: 1024, + }; + assert_eq!( + err.to_string(), + "Field `file` exceeds size limit of 1024 bytes" + ); + + let err = TypedMultipartError::WrongFieldType { + field_name: "age".to_string(), + wanted: "i32".to_string(), + source: "invalid digit".to_string(), + }; + assert_eq!( + err.to_string(), + "Wrong type for field `age` (expected i32): invalid digit" + ); + } + + #[test] + fn test_error_display_duplicate_field() { + let err = TypedMultipartError::DuplicateField { + field_name: "email".to_string(), + }; + assert_eq!(err.to_string(), "Duplicate field: `email`"); + } + + #[test] + fn test_error_display_unknown_field() { + let err = TypedMultipartError::UnknownField { + field_name: "foo".to_string(), + }; + assert_eq!(err.to_string(), "Unknown field: `foo`"); + } + + #[test] + fn test_error_display_invalid_enum_value() { + let err = TypedMultipartError::InvalidEnumValue { + field_name: "status".to_string(), + value: "maybe".to_string(), + }; + assert_eq!( + err.to_string(), + "Invalid enum value `maybe` for field `status`" + ); + } + + #[test] + fn test_error_display_nameless_field() { + let err = TypedMultipartError::NamelessField; + assert_eq!(err.to_string(), "Encountered a field without a name"); + } + + #[test] + fn test_error_display_other() { + let err = TypedMultipartError::Other { + source: "something went wrong".to_string(), + }; + assert_eq!(err.to_string(), "something went wrong"); + } + + // ─── IntoResponse status code tests ───────────────────────────────── + + #[test] + fn test_into_response_duplicate_field() { + let err = TypedMultipartError::DuplicateField { + field_name: "x".to_string(), + }; + let resp = err.into_response(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + } + + #[test] + fn test_into_response_unknown_field() { + let err = TypedMultipartError::UnknownField { + field_name: "x".to_string(), + }; + let resp = err.into_response(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + } + + #[test] + fn test_into_response_invalid_enum_value() { + let err = TypedMultipartError::InvalidEnumValue { + field_name: "x".to_string(), + value: "bad".to_string(), + }; + let resp = err.into_response(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + } + + #[test] + fn test_into_response_nameless_field() { + let err = TypedMultipartError::NamelessField; + let resp = err.into_response(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + } + + #[test] + fn test_into_response_wrong_field_type() { + let err = TypedMultipartError::WrongFieldType { + field_name: "age".to_string(), + wanted: "i32".to_string(), + source: "err".to_string(), + }; + let resp = err.into_response(); + assert_eq!(resp.status(), StatusCode::UNSUPPORTED_MEDIA_TYPE); + } + + #[test] + fn test_into_response_field_too_large() { + let err = TypedMultipartError::FieldTooLarge { + field_name: "file".to_string(), + limit_bytes: 100, + }; + let resp = err.into_response(); + assert_eq!(resp.status(), StatusCode::PAYLOAD_TOO_LARGE); + } + + #[test] + fn test_into_response_other() { + let err = TypedMultipartError::Other { + source: "err".to_string(), + }; + let resp = err.into_response(); + assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR); + } + + #[test] + fn test_into_response_missing_field() { + let err = TypedMultipartError::MissingField { + field_name: "x".to_string(), + }; + let resp = err.into_response(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + } + + // ─── Error trait ──────────────────────────────────────────────────── + + #[test] + fn test_error_trait_is_implemented() { + let err: Box = Box::new(TypedMultipartError::Other { + source: "test".to_string(), + }); + assert_eq!(err.to_string(), "test"); + } + + // ─── TypedMultipart Deref / DerefMut ──────────────────────────────── + + #[test] + fn test_typed_multipart_deref() { + let tm = TypedMultipart("hello".to_string()); + // Deref: &TypedMultipart → &String + assert_eq!(&*tm, "hello"); + assert_eq!(tm.len(), 5); // auto-deref to String method + } + + #[test] + fn test_typed_multipart_deref_mut() { + let mut tm = TypedMultipart(vec![1, 2, 3]); + // DerefMut: &mut TypedMultipart> → &mut Vec + tm.push(4); + assert_eq!(&*tm, &[1, 2, 3, 4]); + } +} diff --git a/crates/vespera_macro/src/lib.rs b/crates/vespera_macro/src/lib.rs index ebed4a5..89d60b0 100644 --- a/crates/vespera_macro/src/lib.rs +++ b/crates/vespera_macro/src/lib.rs @@ -49,6 +49,7 @@ mod metadata; mod method; mod openapi_generator; +mod multipart_impl; mod parser; mod route; mod route_impl; @@ -92,6 +93,31 @@ pub fn derive_schema(input: TokenStream) -> TokenStream { TokenStream::from(expanded) } +/// Derive macro for `Multipart` with serde attribute support. +/// +/// This is vespera's re-implementation of `axum_typed_multipart`'s derive macro +/// that natively supports `#[serde(rename_all)]` and `#[serde(rename)]` for +/// field name resolution in multipart form data. +/// +/// # Supported Attributes +/// +/// **Struct-level:** +/// - `#[serde(rename_all = "camelCase")]` — rename all fields (highest priority) +/// - `#[try_from_multipart(rename_all = "camelCase")]` — fallback rename +/// - `#[try_from_multipart(strict)]` — reject unknown/duplicate fields +/// +/// **Field-level:** +/// - `#[form_data(field_name = "...")]` — explicit field name override +/// - `#[serde(rename = "...")]` — serde field rename +/// - `#[form_data(limit = "10MiB")]` — field size limit +/// - `#[form_data(default)]` — use `Default::default()` when missing +#[cfg(not(tarpaulin_include))] +#[proc_macro_derive(Multipart, attributes(serde, form_data, try_from_multipart))] +pub fn derive_multipart(input: TokenStream) -> TokenStream { + let input = syn::parse_macro_input!(input as syn::DeriveInput); + TokenStream::from(multipart_impl::process_derive(&input)) +} + /// Generate an `OpenAPI` Schema from a type with optional field filtering. /// /// This macro creates a `vespera::schema::Schema` struct at compile time diff --git a/crates/vespera_macro/src/multipart_impl.rs b/crates/vespera_macro/src/multipart_impl.rs new file mode 100644 index 0000000..923669f --- /dev/null +++ b/crates/vespera_macro/src/multipart_impl.rs @@ -0,0 +1,1172 @@ +//! Vespera's `Multipart` derive macro implementation. +//! +//! This is a re-implementation of `axum_typed_multipart`'s derive macro that +//! natively supports `#[serde(rename_all)]` and `#[serde(rename)]` attributes +//! for field name resolution in multipart form data. +//! +//! ## Why? +//! +//! `axum_typed_multipart`'s derive macro only reads `#[try_from_multipart(rename_all)]` +//! and ignores `#[serde(rename_all)]`. This causes a mismatch: the OpenAPI spec +//! (generated by `Schema` derive) shows camelCase field names, but the runtime +//! multipart parser expects snake_case Rust field names. +//! +//! ## Field Name Resolution Priority +//! +//! 1. `#[form_data(field_name = "...")]` — explicit override (highest priority) +//! 2. `#[serde(rename = "...")]` — serde field rename +//! 3. `#[serde(rename_all = "...")]` or `#[try_from_multipart(rename_all = "...")]` applied to Rust name +//! 4. Rust field name as-is (lowest priority) + +use proc_macro2::TokenStream; +use quote::quote; +use syn::{DeriveInput, Fields, Type}; + +use crate::parser::{extract_default, extract_field_rename, extract_rename_all, rename_field}; + +/// Collected codegen fragments for each struct field. +struct FieldCodegen<'a> { + declarations: Vec, + assignments: Vec, + post_loop: Vec, + idents: Vec<&'a syn::Ident>, +} + +/// How a missing field should be handled. +enum DefaultKind { + /// No default — field is required; emit `MissingField` error. + None, + /// Use `Default::default()` — from `#[serde(default)]` or `#[form_data(default)]`. + Trait, + /// Call a custom function — from `#[serde(default = "path::to::fn")]`. + Function(String), +} + +/// Process all named fields into codegen fragments. +fn process_fields<'a>( + fields: impl Iterator, + rename_all: Option<&str>, + strict: bool, + struct_default: bool, +) -> FieldCodegen<'a> { + let mut cg = FieldCodegen { + declarations: Vec::new(), + assignments: Vec::new(), + post_loop: Vec::new(), + idents: Vec::new(), + }; + + for field in fields { + let ident = field.ident.as_ref().unwrap(); + let ty = &field.ty; + let is_vec = is_vec_type(ty); + let is_option = is_option_type(ty); + let field_name = resolve_field_name(ident, &field.attrs, rename_all); + let limit_tokens = extract_limit_tokens(&field.attrs); + let default_kind = resolve_default_kind(&field.attrs, struct_default); + + // The concrete type for TryFromFieldWithState turbofish. For Option + // and Vec the derive wraps the parsed value, so the trait Self is T. + let parse_ty = if is_option || is_vec { + extract_inner_generic(ty).unwrap_or_else(|| ty.clone()) + } else { + ty.clone() + }; + + // Variable declaration + if is_vec { + cg.declarations + .push(quote! { let mut #ident: #ty = std::vec::Vec::new(); }); + } else if is_option { + cg.declarations + .push(quote! { let mut #ident: #ty = std::option::Option::None; }); + } else { + cg.declarations.push( + quote! { let mut #ident: std::option::Option<#ty> = std::option::Option::None; }, + ); + } + + // Field value parsing — explicit turbofish types are required because + // RPITIT opaque return types prevent the compiler from inferring + // `TryFromFieldWithState::Self` through `.await`. + let try_from_call = quote! { <#parse_ty as vespera::multipart::TryFromFieldWithState<__VesperaS__>>::try_from_field_with_state }; + let parse_value = quote! { #try_from_call(__field__, #limit_tokens, __state__).await? }; + + let assignment = if is_vec { + quote! { #ident.push(#parse_value); } + } else if strict { + let set_value = quote! { #ident = std::option::Option::Some(#parse_value) }; + let dup_err = quote! { return std::result::Result::Err(vespera::multipart::TypedMultipartError::DuplicateField { field_name: std::string::String::from(#field_name) }) }; + quote! { if #ident.is_none() { #set_value ; } else { #dup_err ; } } + } else { + quote! { #ident = std::option::Option::Some(#parse_value); } + }; + + let field_match = quote! { if __field_name__ == #field_name { #assignment } }; + cg.assignments.push(field_match); + + // Post-loop: required field checks / defaults + if !is_option && !is_vec { + match &default_kind { + DefaultKind::Trait => { + cg.post_loop.push(quote! { + let #ident: #ty = #ident.unwrap_or_default(); + }); + } + DefaultKind::Function(fn_path) => { + let path: syn::ExprPath = + syn::parse_str(fn_path).expect("invalid default function path"); + cg.post_loop.push(quote! { + let #ident: #ty = #ident.unwrap_or_else(#path); + }); + } + DefaultKind::None => { + cg.post_loop.push(quote! { + let #ident = #ident.ok_or( + vespera::multipart::TypedMultipartError::MissingField { + field_name: std::string::String::from(#field_name) + } + )?; + }); + } + } + } + + cg.idents.push(ident); + } + + cg +} + +/// Process the `#[derive(TryFromMultipart)]` macro input. +pub fn process_derive(input: &DeriveInput) -> TokenStream { + let struct_name = &input.ident; + let rename_all = extract_rename_all(&input.attrs); + let strict = extract_strict(&input.attrs); + let struct_default = extract_struct_default(&input.attrs); + + let fields = match &input.data { + syn::Data::Struct(data) => match &data.fields { + Fields::Named(named) => &named.named, + _ => { + return syn::Error::new_spanned( + &input.ident, + "Multipart only supports structs with named fields", + ) + .to_compile_error(); + } + }, + _ => { + return syn::Error::new_spanned( + &input.ident, + "Multipart can only be derived for structs", + ) + .to_compile_error(); + } + }; + + let mut cg = process_fields(fields.iter(), rename_all.as_deref(), strict, struct_default); + + if strict { + cg.assignments.push(quote! { + { + return std::result::Result::Err( + vespera::multipart::TypedMultipartError::UnknownField { + field_name: __field_name__ + } + ); + } + }); + } + + let missing_name_fallback = if strict { + quote! { + return std::result::Result::Err( + vespera::multipart::TypedMultipartError::NamelessField + ) + } + } else { + quote! { continue } + }; + + let FieldCodegen { + declarations, + assignments, + post_loop, + idents, + .. + } = &cg; + + quote! { + impl<__VesperaS__: Send + Sync> vespera::multipart::TryFromMultipartWithState<__VesperaS__> for #struct_name { + async fn try_from_multipart_with_state( + __multipart__: &mut vespera::axum::extract::Multipart, + __state__: &__VesperaS__, + ) -> std::result::Result { + #(#declarations)* + + while let std::option::Option::Some(__field__) = __multipart__ + .next_field().await + .map_err(vespera::multipart::TypedMultipartError::from)? { + let __field_name__ = match __field__.name() { + | std::option::Option::Some("") + | std::option::Option::None => #missing_name_fallback, + | std::option::Option::Some(__name__) => __name__.to_string(), + }; + + #(#assignments) else * + } + + #(#post_loop)* + + std::result::Result::Ok(Self { #(#idents),* }) + } + } + } +} + +// ─── Field Name Resolution ────────────────────────────────────────────────── + +/// Resolve the multipart field name using serde + form_data attributes. +/// +/// Priority: +/// 1. `#[form_data(field_name = "...")]` +/// 2. `#[serde(rename = "...")]` +/// 3. struct-level `rename_all` applied to Rust field name +/// 4. Rust field name as-is +fn resolve_field_name( + ident: &syn::Ident, + attrs: &[syn::Attribute], + rename_all: Option<&str>, +) -> String { + // 1. Explicit form_data override + if let Some(name) = extract_form_data_field_name(attrs) { + return name; + } + + // 2. Serde field rename + if let Some(name) = extract_field_rename(attrs) { + return name; + } + + // 3. Apply rename_all to Rust field name + let rust_name = strip_raw_prefix(&ident.to_string()); + rename_field(&rust_name, rename_all) +} + +// ─── Attribute Extraction ─────────────────────────────────────────────────── + +/// Extract `field_name` from `#[form_data(field_name = "...")]`. +fn extract_form_data_field_name(attrs: &[syn::Attribute]) -> Option { + for attr in attrs { + if attr.path().is_ident("form_data") { + let mut found = None; + let _ = attr.parse_nested_meta(|meta| { + if meta.path.is_ident("field_name") + && let Ok(value) = meta.value() + && let Ok(lit) = value.parse::() + { + found = Some(lit.value()); + } + Ok(()) + }); + if found.is_some() { + return found; + } + } + } + None +} + +/// Extract `strict` flag from `#[try_from_multipart(strict)]`. +fn extract_strict(attrs: &[syn::Attribute]) -> bool { + for attr in attrs { + if attr.path().is_ident("try_from_multipart") { + let mut strict = false; + let _ = attr.parse_nested_meta(|meta| { + if meta.path.is_ident("strict") { + strict = true; + } + Ok(()) + }); + if strict { + return true; + } + } + } + false +} + +/// Extract `limit` from `#[form_data(limit = "10MiB")]` and emit as `Option` tokens. +fn extract_limit_tokens(attrs: &[syn::Attribute]) -> TokenStream { + for attr in attrs { + if attr.path().is_ident("form_data") { + let mut limit_str = None; + let _ = attr.parse_nested_meta(|meta| { + if meta.path.is_ident("limit") + && let Ok(value) = meta.value() + && let Ok(lit) = value.parse::() + { + limit_str = Some(lit.value()); + } + Ok(()) + }); + if let Some(s) = limit_str { + if s == "unlimited" { + return quote! { std::option::Option::None }; + } + if let Some(bytes) = parse_byte_unit(&s) { + return quote! { std::option::Option::Some(#bytes) }; + } + } + } + } + // Default: no limit (None) + quote! { std::option::Option::None } +} + +/// Resolve the default behavior for a field. +/// +/// Priority: +/// 1. `#[form_data(default)]` — explicit form_data override (bare default) +/// 2. `#[serde(default)]` — bare default via `Default::default()` +/// 3. `#[serde(default = "fn_path")]` — custom default function +/// 4. Struct-level `#[serde(default)]` — all fields get `Default::default()` +/// 5. No default — field is required +fn resolve_default_kind(attrs: &[syn::Attribute], struct_default: bool) -> DefaultKind { + // 1. Check #[form_data(default)] + if extract_form_data_default(attrs) { + return DefaultKind::Trait; + } + + // 2-3. Check #[serde(default)] or #[serde(default = "fn")] + if let Some(serde_default) = extract_default(attrs) { + return serde_default.map_or(DefaultKind::Trait, DefaultKind::Function); + } + + // 4. Struct-level #[serde(default)] + if struct_default { + return DefaultKind::Trait; + } + + DefaultKind::None +} + +/// Extract `default` flag from `#[form_data(default)]`. +fn extract_form_data_default(attrs: &[syn::Attribute]) -> bool { + for attr in attrs { + if attr.path().is_ident("form_data") { + let mut has_default = false; + let _ = attr.parse_nested_meta(|meta| { + if meta.path.is_ident("default") { + has_default = true; + } + Ok(()) + }); + if has_default { + return true; + } + } + } + false +} + +/// Check if the struct has `#[serde(default)]` at the struct level. +fn extract_struct_default(attrs: &[syn::Attribute]) -> bool { + // Reuse extract_default — if it returns Some(None), it's bare #[serde(default)] + // For struct-level, we only support bare default (no custom function) + extract_default(attrs).is_some() +} + +// ─── Type Utilities ───────────────────────────────────────────────────────── + +/// Extract the first generic type argument from a type like `Option` or `Vec`. +fn extract_inner_generic(ty: &Type) -> Option { + let Type::Path(type_path) = ty else { + return None; + }; + let segment = type_path.path.segments.last()?; + if let syn::PathArguments::AngleBracketed(args) = &segment.arguments + && let Some(syn::GenericArgument::Type(inner)) = args.args.first() + { + return Some(inner.clone()); + } + None +} + +/// Check if a type matches `Option`. +fn is_option_type(ty: &Type) -> bool { + matches_type_name( + ty, + &["Option", "std::option::Option", "core::option::Option"], + ) +} + +/// Check if a type matches `Vec`. +fn is_vec_type(ty: &Type) -> bool { + matches_type_name(ty, &["Vec", "std::vec::Vec"]) +} + +/// Check if a type's path matches any of the given names. +fn matches_type_name(ty: &Type, names: &[&str]) -> bool { + let path = match ty { + Type::Path(type_path) if type_path.qself.is_none() => &type_path.path, + _ => return false, + }; + let sig = path + .segments + .iter() + .map(|s| s.ident.to_string()) + .collect::>() + .join("::"); + names.contains(&sig.as_str()) +} + +/// Strip leading `r#` from raw identifiers. +fn strip_raw_prefix(s: &str) -> String { + s.strip_prefix("r#").unwrap_or(s).to_string() +} + +// ─── Byte Unit Parser ─────────────────────────────────────────────────────── + +/// Parse a human-readable byte unit string into bytes. +/// +/// Supports: `"10MiB"`, `"1GB"`, `"500KB"`, `"1024"`, `"unlimited"`. +fn parse_byte_unit(s: &str) -> Option { + let s = s.trim(); + + // Binary and decimal suffixes, longest first to avoid prefix collisions + let suffixes: &[(&str, usize)] = &[ + ("GiB", 1024 * 1024 * 1024), + ("MiB", 1024 * 1024), + ("KiB", 1024), + ("GB", 1_000_000_000), + ("MB", 1_000_000), + ("KB", 1_000), + ("B", 1), + ]; + + for (suffix, multiplier) in suffixes { + if let Some(num_str) = s.strip_suffix(suffix) { + return num_str.trim().parse::().ok().map(|n| n * multiplier); + } + } + + // Plain number (bytes) + s.parse::().ok() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_byte_unit() { + assert_eq!(parse_byte_unit("10MiB"), Some(10 * 1024 * 1024)); + assert_eq!(parse_byte_unit("50MiB"), Some(50 * 1024 * 1024)); + assert_eq!(parse_byte_unit("1GB"), Some(1_000_000_000)); + assert_eq!(parse_byte_unit("500KB"), Some(500_000)); + assert_eq!(parse_byte_unit("1024"), Some(1024)); + assert_eq!(parse_byte_unit("0"), Some(0)); + assert_eq!(parse_byte_unit("invalid"), None); + } + + #[test] + fn test_parse_byte_unit_all_suffixes() { + assert_eq!(parse_byte_unit("1GiB"), Some(1024 * 1024 * 1024)); + assert_eq!(parse_byte_unit("2KiB"), Some(2 * 1024)); + assert_eq!(parse_byte_unit("3MB"), Some(3_000_000)); + assert_eq!(parse_byte_unit("4B"), Some(4)); + assert_eq!(parse_byte_unit(" 5MiB "), Some(5 * 1024 * 1024)); + } + + #[test] + fn test_strip_raw_prefix() { + assert_eq!(strip_raw_prefix("r#type"), "type"); + assert_eq!(strip_raw_prefix("normal"), "normal"); + } + + // ─── extract_inner_generic ────────────────────────────────────────── + + #[test] + fn test_extract_inner_generic_option() { + let ty: syn::Type = syn::parse_str("Option").unwrap(); + let inner = extract_inner_generic(&ty).unwrap(); + assert_eq!(quote!(#inner).to_string(), "String"); + } + + #[test] + fn test_extract_inner_generic_vec() { + let ty: syn::Type = syn::parse_str("Vec").unwrap(); + let inner = extract_inner_generic(&ty).unwrap(); + assert_eq!(quote!(#inner).to_string(), "i32"); + } + + #[test] + fn test_extract_inner_generic_no_generics() { + let ty: syn::Type = syn::parse_str("String").unwrap(); + assert!(extract_inner_generic(&ty).is_none()); + } + + #[test] + fn test_extract_inner_generic_non_path() { + let ty: syn::Type = syn::parse_str("(i32, String)").unwrap(); + assert!(extract_inner_generic(&ty).is_none()); + } + + // ─── is_option_type / is_vec_type ─────────────────────────────────── + + #[test] + fn test_is_option_type() { + let ty: syn::Type = syn::parse_str("Option").unwrap(); + assert!(is_option_type(&ty)); + + let ty: syn::Type = syn::parse_str("std::option::Option").unwrap(); + assert!(is_option_type(&ty)); + + let ty: syn::Type = syn::parse_str("Vec").unwrap(); + assert!(!is_option_type(&ty)); + + let ty: syn::Type = syn::parse_str("String").unwrap(); + assert!(!is_option_type(&ty)); + } + + #[test] + fn test_is_vec_type() { + let ty: syn::Type = syn::parse_str("Vec").unwrap(); + assert!(is_vec_type(&ty)); + + let ty: syn::Type = syn::parse_str("std::vec::Vec").unwrap(); + assert!(is_vec_type(&ty)); + + let ty: syn::Type = syn::parse_str("Option").unwrap(); + assert!(!is_vec_type(&ty)); + + let ty: syn::Type = syn::parse_str("String").unwrap(); + assert!(!is_vec_type(&ty)); + } + + // ─── matches_type_name ────────────────────────────────────────────── + + #[test] + fn test_matches_type_name_simple() { + let ty: syn::Type = syn::parse_str("Option").unwrap(); + assert!(matches_type_name(&ty, &["Option"])); + assert!(!matches_type_name(&ty, &["Vec"])); + } + + #[test] + fn test_matches_type_name_qualified() { + let ty: syn::Type = syn::parse_str("std::option::Option").unwrap(); + assert!(matches_type_name(&ty, &["std::option::Option"])); + assert!(!matches_type_name(&ty, &["Option"])); // qualified doesn't match simple + } + + #[test] + fn test_matches_type_name_non_path() { + let ty: syn::Type = syn::parse_str("(i32, String)").unwrap(); + assert!(!matches_type_name(&ty, &["Option", "Vec"])); + } + + // ─── extract_form_data_field_name ─────────────────────────────────── + + fn parse_field(code: &str) -> syn::Field { + let input: syn::DeriveInput = syn::parse_str(&format!("struct T {{ {code} }}")).unwrap(); + match &input.data { + syn::Data::Struct(s) => match &s.fields { + Fields::Named(n) => n.named.first().unwrap().clone(), + _ => unreachable!(), + }, + _ => unreachable!(), + } + } + + fn parse_attrs(code: &str) -> Vec { + parse_field(code).attrs + } + + #[test] + fn test_extract_form_data_field_name_present() { + let attrs = parse_attrs(r#"#[form_data(field_name = "custom")] pub x: String"#); + assert_eq!( + extract_form_data_field_name(&attrs), + Some("custom".to_string()) + ); + } + + #[test] + fn test_extract_form_data_field_name_absent() { + let attrs = parse_attrs("pub x: String"); + assert_eq!(extract_form_data_field_name(&attrs), None); + } + + #[test] + fn test_extract_form_data_field_name_other_form_data_attr() { + let attrs = parse_attrs(r#"#[form_data(limit = "100")] pub x: String"#); + assert_eq!(extract_form_data_field_name(&attrs), None); + } + + // ─── extract_strict ───────────────────────────────────────────────── + + fn parse_struct_attrs(code: &str) -> Vec { + let input: syn::DeriveInput = syn::parse_str(code).unwrap(); + input.attrs + } + + #[test] + fn test_extract_strict_present() { + let attrs = parse_struct_attrs("#[try_from_multipart(strict)] struct T { }"); + assert!(extract_strict(&attrs)); + } + + #[test] + fn test_extract_strict_absent() { + let attrs = parse_struct_attrs("struct T { }"); + assert!(!extract_strict(&attrs)); + } + + #[test] + fn test_extract_strict_other_attr() { + let attrs = + parse_struct_attrs("#[try_from_multipart(rename_all = \"camelCase\")] struct T { }"); + assert!(!extract_strict(&attrs)); + } + + // ─── extract_form_data_default ────────────────────────────────────── + + #[test] + fn test_extract_form_data_default_present() { + let attrs = parse_attrs("#[form_data(default)] pub x: i32"); + assert!(extract_form_data_default(&attrs)); + } + + #[test] + fn test_extract_form_data_default_absent() { + let attrs = parse_attrs("pub x: i32"); + assert!(!extract_form_data_default(&attrs)); + } + + #[test] + fn test_extract_form_data_default_other_form_data() { + let attrs = parse_attrs(r#"#[form_data(limit = "100")] pub x: i32"#); + assert!(!extract_form_data_default(&attrs)); + } + + // ─── extract_struct_default ───────────────────────────────────────── + + #[test] + fn test_extract_struct_default_present() { + let attrs = parse_struct_attrs("#[serde(default)] struct T { }"); + assert!(extract_struct_default(&attrs)); + } + + #[test] + fn test_extract_struct_default_absent() { + let attrs = parse_struct_attrs("struct T { }"); + assert!(!extract_struct_default(&attrs)); + } + + // ─── resolve_default_kind ─────────────────────────────────────────── + + #[test] + fn test_resolve_default_kind_none() { + let attrs = parse_attrs("pub x: i32"); + assert!(matches!( + resolve_default_kind(&attrs, false), + DefaultKind::None + )); + } + + #[test] + fn test_resolve_default_kind_serde_default() { + let attrs = parse_attrs("#[serde(default)] pub x: i32"); + assert!(matches!( + resolve_default_kind(&attrs, false), + DefaultKind::Trait + )); + } + + #[test] + fn test_resolve_default_kind_serde_default_fn() { + let attrs = parse_attrs(r#"#[serde(default = "my_fn")] pub x: i32"#); + assert!( + matches!(resolve_default_kind(&attrs, false), DefaultKind::Function(ref f) if f == "my_fn") + ); + } + + #[test] + fn test_resolve_default_kind_form_data_default() { + let attrs = parse_attrs("#[form_data(default)] pub x: i32"); + assert!(matches!( + resolve_default_kind(&attrs, false), + DefaultKind::Trait + )); + } + + #[test] + fn test_resolve_default_kind_struct_level() { + let attrs = parse_attrs("pub x: i32"); + assert!(matches!( + resolve_default_kind(&attrs, true), + DefaultKind::Trait + )); + } + + #[test] + fn test_resolve_default_kind_form_data_overrides_struct_default() { + // form_data(default) takes priority, but result is the same (Trait) + let attrs = parse_attrs("#[form_data(default)] pub x: i32"); + assert!(matches!( + resolve_default_kind(&attrs, true), + DefaultKind::Trait + )); + } + + // ─── resolve_field_name ───────────────────────────────────────────── + + #[test] + fn test_resolve_field_name_plain() { + let field = parse_field("pub my_field: String"); + let name = resolve_field_name(field.ident.as_ref().unwrap(), &field.attrs, None); + assert_eq!(name, "my_field"); + } + + #[test] + fn test_resolve_field_name_rename_all() { + let field = parse_field("pub my_field: String"); + let name = resolve_field_name( + field.ident.as_ref().unwrap(), + &field.attrs, + Some("camelCase"), + ); + assert_eq!(name, "myField"); + } + + #[test] + fn test_resolve_field_name_serde_rename() { + let field = parse_field(r#"#[serde(rename = "custom")] pub my_field: String"#); + let name = resolve_field_name( + field.ident.as_ref().unwrap(), + &field.attrs, + Some("camelCase"), + ); + assert_eq!(name, "custom"); // explicit rename beats rename_all + } + + #[test] + fn test_resolve_field_name_form_data_field_name() { + let field = parse_field( + r#"#[form_data(field_name = "override")] #[serde(rename = "serde_name")] pub my_field: String"#, + ); + let name = resolve_field_name( + field.ident.as_ref().unwrap(), + &field.attrs, + Some("camelCase"), + ); + assert_eq!(name, "override"); // form_data field_name beats everything + } + + // ─── extract_limit_tokens ─────────────────────────────────────────── + + #[test] + fn test_extract_limit_tokens_none() { + let attrs = parse_attrs("pub x: String"); + let tokens = extract_limit_tokens(&attrs); + assert_eq!(tokens.to_string(), "std :: option :: Option :: None"); + } + + #[test] + fn test_extract_limit_tokens_with_value() { + let attrs = parse_attrs(r#"#[form_data(limit = "100")] pub x: String"#); + let tokens = extract_limit_tokens(&attrs); + assert_eq!( + tokens.to_string(), + "std :: option :: Option :: Some (100usize)" + ); + } + + #[test] + fn test_extract_limit_tokens_unlimited() { + let attrs = parse_attrs(r#"#[form_data(limit = "unlimited")] pub x: String"#); + let tokens = extract_limit_tokens(&attrs); + assert_eq!(tokens.to_string(), "std :: option :: Option :: None"); + } + + #[test] + fn test_extract_limit_tokens_mib() { + let attrs = parse_attrs(r#"#[form_data(limit = "10MiB")] pub x: String"#); + let tokens = extract_limit_tokens(&attrs); + let expected = 10 * 1024 * 1024; + assert_eq!( + tokens.to_string(), + format!("std :: option :: Option :: Some ({expected}usize)") + ); + } + + // ─── process_derive ───────────────────────────────────────────────── + + #[test] + fn test_process_derive_basic_struct() { + let input: syn::DeriveInput = + syn::parse_str("struct MyForm { pub name: String, pub age: i32 }").unwrap(); + let tokens = process_derive(&input); + let code = tokens.to_string(); + assert!( + code.contains("TryFromMultipartWithState"), + "should generate trait impl" + ); + assert!(code.contains("MyForm"), "should reference the struct name"); + assert!(code.contains("\"name\""), "should reference field name"); + assert!(code.contains("\"age\""), "should reference field name"); + } + + #[test] + fn test_process_derive_with_option_field() { + let input: syn::DeriveInput = + syn::parse_str("struct MyForm { pub name: String, pub bio: Option }").unwrap(); + let tokens = process_derive(&input); + let code = tokens.to_string(); + assert!(code.contains("TryFromMultipartWithState")); + // Option fields get initialized to None, no MissingField check + assert!(code.contains("Option :: None")); + } + + #[test] + fn test_process_derive_with_vec_field() { + let input: syn::DeriveInput = + syn::parse_str("struct MyForm { pub name: String, pub tags: Vec }").unwrap(); + let tokens = process_derive(&input); + let code = tokens.to_string(); + assert!( + code.contains("Vec :: new"), + "Vec fields should be initialized with Vec::new()" + ); + assert!(code.contains("push"), "Vec fields should use push()"); + } + + #[test] + fn test_process_derive_strict_mode() { + let input: syn::DeriveInput = + syn::parse_str("#[try_from_multipart(strict)] struct MyForm { pub name: String }") + .unwrap(); + let tokens = process_derive(&input); + let code = tokens.to_string(); + assert!( + code.contains("DuplicateField"), + "strict mode should check for duplicates" + ); + assert!( + code.contains("UnknownField"), + "strict mode should reject unknown fields" + ); + assert!( + code.contains("NamelessField"), + "strict mode should reject nameless fields" + ); + } + + #[test] + fn test_process_derive_with_rename_all() { + let input: syn::DeriveInput = syn::parse_str( + r#"#[serde(rename_all = "camelCase")] struct MyForm { pub user_name: String }"#, + ) + .unwrap(); + let tokens = process_derive(&input); + let code = tokens.to_string(); + assert!( + code.contains("\"userName\""), + "rename_all should convert to camelCase" + ); + } + + #[test] + fn test_process_derive_with_serde_default() { + let input: syn::DeriveInput = + syn::parse_str("#[serde(default)] struct MyForm { pub count: i32 }").unwrap(); + let tokens = process_derive(&input); + let code = tokens.to_string(); + assert!( + code.contains("unwrap_or_default"), + "struct-level default should use unwrap_or_default" + ); + } + + #[test] + fn test_process_derive_with_field_default_fn() { + let input: syn::DeriveInput = + syn::parse_str(r#"struct MyForm { #[serde(default = "my_default")] pub val: String }"#) + .unwrap(); + let tokens = process_derive(&input); + let code = tokens.to_string(); + assert!( + code.contains("unwrap_or_else"), + "field default fn should use unwrap_or_else" + ); + assert!( + code.contains("my_default"), + "should reference the default function" + ); + } + + #[test] + fn test_process_derive_non_struct_errors() { + let input: syn::DeriveInput = syn::parse_str("enum Foo { A, B }").unwrap(); + let tokens = process_derive(&input); + let code = tokens.to_string(); + assert!( + code.contains("compile_error"), + "enums should produce compile error" + ); + } + + #[test] + fn test_process_derive_tuple_struct_errors() { + let input: syn::DeriveInput = syn::parse_str("struct Foo(String, i32);").unwrap(); + let tokens = process_derive(&input); + let code = tokens.to_string(); + assert!( + code.contains("compile_error"), + "tuple structs should produce compile error" + ); + } + + #[test] + fn test_process_derive_form_data_field_name() { + let input: syn::DeriveInput = syn::parse_str( + r#"struct MyForm { #[form_data(field_name = "custom")] pub data: String }"#, + ) + .unwrap(); + let tokens = process_derive(&input); + let code = tokens.to_string(); + assert!( + code.contains("\"custom\""), + "form_data field_name should be used" + ); + } + + #[test] + fn test_process_derive_form_data_default() { + let input: syn::DeriveInput = + syn::parse_str("struct MyForm { #[form_data(default)] pub count: i32 }").unwrap(); + let tokens = process_derive(&input); + let code = tokens.to_string(); + assert!( + code.contains("unwrap_or_default"), + "form_data(default) should use unwrap_or_default" + ); + } + + #[test] + fn test_process_derive_non_strict_no_duplicate_check() { + let input: syn::DeriveInput = syn::parse_str("struct MyForm { pub name: String }").unwrap(); + let tokens = process_derive(&input); + let code = tokens.to_string(); + assert!( + !code.contains("DuplicateField"), + "non-strict should not check for duplicates" + ); + assert!( + !code.contains("UnknownField"), + "non-strict should not check for unknown fields" + ); + } + + // ─── process_fields direct tests ──────────────────────────────────── + // + // Exercise process_fields directly to ensure quote! token construction + // for each branch (parse_value, strict assignment, field matching) is + // fully traced by the coverage tool. + + fn parse_fields_from(code: &str) -> syn::DeriveInput { + syn::parse_str(code).unwrap() + } + + fn get_named_fields( + input: &syn::DeriveInput, + ) -> &syn::punctuated::Punctuated { + match &input.data { + syn::Data::Struct(s) => match &s.fields { + Fields::Named(n) => &n.named, + _ => panic!("expected named fields"), + }, + _ => panic!("expected struct"), + } + } + + #[test] + fn test_process_fields_required_field_generates_parse_value() { + let input = parse_fields_from("struct T { pub name: String }"); + let fields = get_named_fields(&input); + let cg = process_fields(fields.iter(), None, false, false); + + // parse_value is interpolated into each assignment + let assignment_code = cg + .assignments + .iter() + .map(ToString::to_string) + .collect::>() + .join(" "); + assert!( + assignment_code.contains("TryFromFieldWithState"), + "parse_value should contain turbofish call" + ); + assert!( + assignment_code.contains("try_from_field_with_state"), + "should call try_from_field_with_state" + ); + assert!( + assignment_code.contains("\"name\""), + "should match on field name" + ); + + // post_loop should have MissingField check for required fields + let post_code = cg + .post_loop + .iter() + .map(ToString::to_string) + .collect::>() + .join(" "); + assert!( + post_code.contains("MissingField"), + "required field should have MissingField check" + ); + } + + #[test] + fn test_process_fields_strict_required_field_generates_duplicate_check() { + let input = parse_fields_from("struct T { pub name: String, pub age: i32 }"); + let fields = get_named_fields(&input); + let cg = process_fields(fields.iter(), None, true, false); + + // strict mode: assignments should contain is_none + DuplicateField check + let assignment_code = cg + .assignments + .iter() + .map(ToString::to_string) + .collect::>() + .join(" "); + assert!( + assignment_code.contains("is_none"), + "strict assignment should check is_none" + ); + assert!( + assignment_code.contains("DuplicateField"), + "strict assignment should have DuplicateField" + ); + assert!( + assignment_code.contains("\"name\""), + "should match name field" + ); + assert!( + assignment_code.contains("\"age\""), + "should match age field" + ); + + // Both fields should have parse_value with turbofish + assert!( + assignment_code.contains("TryFromFieldWithState"), + "should contain turbofish" + ); + } + + #[test] + fn test_process_fields_vec_field_generates_push() { + let input = parse_fields_from("struct T { pub tags: Vec }"); + let fields = get_named_fields(&input); + let cg = process_fields(fields.iter(), None, false, false); + + let decl_code = cg + .declarations + .iter() + .map(ToString::to_string) + .collect::>() + .join(" "); + assert!( + decl_code.contains("Vec :: new"), + "Vec field should initialize with Vec::new()" + ); + + let assignment_code = cg + .assignments + .iter() + .map(ToString::to_string) + .collect::>() + .join(" "); + assert!( + assignment_code.contains("push"), + "Vec field assignment should use push" + ); + + // Vec fields should NOT have post_loop (no MissingField check) + assert!( + cg.post_loop.is_empty(), + "Vec fields should not have post-loop checks" + ); + } + + #[test] + fn test_process_fields_option_field_no_missing_check() { + let input = parse_fields_from("struct T { pub bio: Option }"); + let fields = get_named_fields(&input); + let cg = process_fields(fields.iter(), None, false, false); + + let decl_code = cg + .declarations + .iter() + .map(ToString::to_string) + .collect::>() + .join(" "); + assert!( + decl_code.contains("Option :: None"), + "Option field should initialize to None" + ); + + // Option fields should NOT have post_loop + assert!( + cg.post_loop.is_empty(), + "Option fields should not have post-loop checks" + ); + } + + #[test] + fn test_process_fields_strict_vec_field_uses_push_not_duplicate() { + let input = parse_fields_from("struct T { pub tags: Vec }"); + let fields = get_named_fields(&input); + let cg = process_fields(fields.iter(), None, true, false); + + // Even in strict mode, Vec fields use push (not duplicate check) + let assignment_code = cg + .assignments + .iter() + .map(ToString::to_string) + .collect::>() + .join(" "); + assert!( + assignment_code.contains("push"), + "Vec in strict mode should still use push" + ); + assert!( + !assignment_code.contains("DuplicateField"), + "Vec should not have duplicate check" + ); + } + + #[test] + fn test_process_fields_mixed_types() { + let input = parse_fields_from( + "struct T { pub name: String, pub tags: Vec, pub bio: Option }", + ); + let fields = get_named_fields(&input); + let cg = process_fields(fields.iter(), None, false, false); + + assert_eq!(cg.idents.len(), 3, "should have 3 fields"); + assert_eq!(cg.declarations.len(), 3, "should have 3 declarations"); + assert_eq!(cg.assignments.len(), 3, "should have 3 assignments"); + // Only 'name' is required (not Option, not Vec), so 1 post_loop + assert_eq!( + cg.post_loop.len(), + 1, + "only required field should have post-loop" + ); + } +} diff --git a/crates/vespera_macro/src/parser/schema/type_schema.rs b/crates/vespera_macro/src/parser/schema/type_schema.rs index 10a3a47..bb43199 100644 --- a/crates/vespera_macro/src/parser/schema/type_schema.rs +++ b/crates/vespera_macro/src/parser/schema/type_schema.rs @@ -275,7 +275,7 @@ fn parse_type_impl( "NaiveTime" | "Time" => string_with_format("time"), // Duration types "Duration" => string_with_format("duration"), - // File upload types (axum_typed_multipart / tempfile) + // File upload types (vespera::multipart / tempfile) // FieldData → string with binary format "FieldData" | "NamedTempFile" => string_with_format("binary"), // Standard library types that should not be referenced diff --git a/crates/vespera_macro/src/schema_macro/input.rs b/crates/vespera_macro/src/schema_macro/input.rs index 4720bd5..11a4234 100644 --- a/crates/vespera_macro/src/schema_macro/input.rs +++ b/crates/vespera_macro/src/schema_macro/input.rs @@ -121,7 +121,7 @@ pub struct SchemaTypeInput { /// Serde `rename_all` strategy (e.g., "camelCase", "`snake_case`", "`PascalCase`") /// If not specified, defaults to "camelCase" when source has no `rename_all` pub rename_all: Option, - /// Whether to generate a multipart/form-data struct (derives `TryFromMultipart` instead of serde) + /// Whether to generate a multipart/form-data struct (derives `Multipart` instead of serde) /// Use `multipart` bare keyword to set this to true. pub multipart: bool, /// Whether to omit fields that have database defaults (sea_orm `default_value` or `primary_key`). @@ -295,7 +295,7 @@ impl Parse for SchemaTypeInput { rename_all = Some(rename_all_lit.value()); } "multipart" => { - // bare `multipart` - derive TryFromMultipart instead of serde + // bare `multipart` - derive Multipart instead of serde multipart = true; } "omit_default" => { diff --git a/crates/vespera_macro/src/schema_macro/mod.rs b/crates/vespera_macro/src/schema_macro/mod.rs index 5f3789f..13ed25b 100644 --- a/crates/vespera_macro/src/schema_macro/mod.rs +++ b/crates/vespera_macro/src/schema_macro/mod.rs @@ -580,14 +580,14 @@ pub fn generate_schema_type_code( // Generate the new struct (with inline types for circular relations first) let generated_tokens = if input.multipart { - // Multipart mode: derive TryFromMultipart instead of serde - // Still emit #[serde(rename_all = ...)] so Schema derive can read it for OpenAPI field naming - // (Schema derive registers `serde` as a helper attribute, so this is valid without Serialize/Deserialize) + // Multipart mode: derive Multipart instead of serde + // Emit #[serde(rename_all = ...)] so Multipart applies the rename at runtime + // AND Schema derive reads it via extract_rename_all() fallback for OpenAPI field naming quote! { #(#inline_type_definitions)* #(#struct_doc_attrs)* - #[derive(vespera::axum_typed_multipart::TryFromMultipart, #clone_derive #schema_derive)] + #[derive(vespera::Multipart, #clone_derive #schema_derive)] #schema_name_attr #[serde(rename_all = #effective_rename_all)] pub struct #new_type_name { diff --git a/crates/vespera_macro/src/schema_macro/seaorm.rs b/crates/vespera_macro/src/schema_macro/seaorm.rs index 2045aa0..ab553df 100644 --- a/crates/vespera_macro/src/schema_macro/seaorm.rs +++ b/crates/vespera_macro/src/schema_macro/seaorm.rs @@ -65,9 +65,9 @@ pub fn convert_seaorm_type_to_chrono(ty: &Type, source_module_path: &[String]) - } "DateTimeUtc" => quote! { vespera::chrono::DateTime }, "DateTimeLocal" => quote! { vespera::chrono::DateTime }, - // axum_typed_multipart types - resolve via vespera re-exports + // Multipart types - resolve via vespera::multipart "FieldData" => { - // Preserve inner generic: FieldData → vespera::axum_typed_multipart::FieldData + // Preserve inner generic: FieldData → vespera::multipart::FieldData if let syn::PathArguments::AngleBracketed(args) = &segment.arguments { let inner_args: Vec<_> = args .args @@ -82,9 +82,9 @@ pub fn convert_seaorm_type_to_chrono(ty: &Type, source_module_path: &[String]) - } }) .collect(); - quote! { vespera::axum_typed_multipart::FieldData<#(#inner_args),*> } + quote! { vespera::multipart::FieldData<#(#inner_args),*> } } else { - quote! { vespera::axum_typed_multipart::FieldData } + quote! { vespera::multipart::FieldData } } } "NamedTempFile" => quote! { vespera::tempfile::NamedTempFile }, @@ -658,13 +658,13 @@ mod tests { #[test] fn test_convert_seaorm_type_field_data_with_generic() { - // FieldData → vespera::axum_typed_multipart::FieldData + // FieldData → vespera::multipart::FieldData let ty: syn::Type = syn::parse_str("FieldData").unwrap(); let tokens = convert_seaorm_type_to_chrono(&ty, &[]); let output = tokens.to_string(); assert!( - output.contains("vespera :: axum_typed_multipart :: FieldData"), - "Should resolve FieldData via vespera re-export: {output}" + output.contains("vespera :: multipart :: FieldData"), + "Should resolve FieldData via vespera::multipart: {output}" ); assert!( output.contains("vespera :: tempfile :: NamedTempFile"), @@ -674,12 +674,12 @@ mod tests { #[test] fn test_convert_seaorm_type_field_data_without_generic() { - // FieldData (no generics) → vespera::axum_typed_multipart::FieldData + // FieldData (no generics) → vespera::multipart::FieldData let ty: syn::Type = syn::parse_str("FieldData").unwrap(); let tokens = convert_seaorm_type_to_chrono(&ty, &[]); let output = tokens.to_string(); assert!( - output.contains("vespera :: axum_typed_multipart :: FieldData"), + output.contains("vespera :: multipart :: FieldData"), "Should resolve bare FieldData: {output}" ); // Should NOT contain nested generic @@ -696,7 +696,7 @@ mod tests { let tokens = convert_seaorm_type_to_chrono(&ty, &[]); let output = tokens.to_string(); assert!( - output.contains("vespera :: axum_typed_multipart :: FieldData"), + output.contains("vespera :: multipart :: FieldData"), "Should still resolve FieldData: {output}" ); } diff --git a/crates/vespera_macro/src/schema_macro/tests.rs b/crates/vespera_macro/src/schema_macro/tests.rs index ab5ed61..ed702b9 100644 --- a/crates/vespera_macro/src/schema_macro/tests.rs +++ b/crates/vespera_macro/src/schema_macro/tests.rs @@ -1859,7 +1859,7 @@ fn test_extract_belongs_to_from_field_with_equals_value() { #[test] fn test_generate_schema_type_code_multipart_basic() { - // Tests: multipart mode generates TryFromMultipart derive, suppresses From impl + // Tests: multipart mode generates Multipart derive, suppresses From impl let storage = to_storage(vec![create_test_struct_metadata( "UploadRequest", "pub struct UploadRequest { pub name: String, pub description: Option }", @@ -1872,8 +1872,8 @@ fn test_generate_schema_type_code_multipart_basic() { assert!(result.is_ok()); let (tokens, _metadata) = result.unwrap(); let output = tokens.to_string(); - // Should derive TryFromMultipart - assert!(output.contains("TryFromMultipart")); + // Should derive Multipart + assert!(output.contains("Multipart")); // Should NOT have From impl (multipart suppresses it) assert!(!output.contains("impl From")); // Should have the struct fields @@ -1896,8 +1896,8 @@ fn test_generate_schema_type_code_multipart_with_rename() { assert!(result.is_ok()); let (tokens, _metadata) = result.unwrap(); let output = tokens.to_string(); - // Should derive TryFromMultipart - assert!(output.contains("TryFromMultipart")); + // Should derive Multipart + assert!(output.contains("Multipart")); // Should have renamed field assert!(output.contains("document_path")); // Original name should NOT appear as field @@ -1953,8 +1953,8 @@ fn test_generate_schema_type_code_multipart_skips_relations() { // Regular fields should be present assert!(output.contains("id")); assert!(output.contains("title")); - // Should derive TryFromMultipart - assert!(output.contains("TryFromMultipart")); + // Should derive Multipart + assert!(output.contains("Multipart")); } #[test] @@ -1972,8 +1972,8 @@ fn test_generate_schema_type_code_multipart_partial() { assert!(result.is_ok()); let (tokens, _metadata) = result.unwrap(); let output = tokens.to_string(); - // Should derive TryFromMultipart - assert!(output.contains("TryFromMultipart")); + // Should derive Multipart + assert!(output.contains("Multipart")); // Fields should be wrapped in Option (partial) assert!(output.contains("Option")); // Should NOT have From impl diff --git a/examples/axum-example/Cargo.toml b/examples/axum-example/Cargo.toml index 7d745b9..76e59ca 100644 --- a/examples/axum-example/Cargo.toml +++ b/examples/axum-example/Cargo.toml @@ -13,7 +13,6 @@ serde_json = "1" tower-http = { version = "0.6", features = ["cors"] } sea-orm = { version = "^2.0.0-rc.36", features = ["sqlx-sqlite", "runtime-tokio-rustls", "macros", "with-uuid"] } uuid = { version = "1", features = ["v4", "serde"] } -axum_typed_multipart = "0.16" tempfile = "3" third = { path = "../third" } diff --git a/examples/axum-example/openapi.json b/examples/axum-example/openapi.json index 490ddea..bf41bd6 100644 --- a/examples/axum-example/openapi.json +++ b/examples/axum-example/openapi.json @@ -3774,7 +3774,7 @@ "format": "binary", "nullable": true }, - "is_active": { + "isActive": { "type": "boolean", "nullable": true }, diff --git a/examples/axum-example/src/routes/typed_form.rs b/examples/axum-example/src/routes/typed_form.rs index d6e00ea..90e293d 100644 --- a/examples/axum-example/src/routes/typed_form.rs +++ b/examples/axum-example/src/routes/typed_form.rs @@ -1,9 +1,9 @@ -use axum_typed_multipart::{FieldData, TryFromMultipart, TypedMultipart}; use serde::Serialize; use tempfile::NamedTempFile; use vespera::axum::Json; use vespera::axum::http::StatusCode; -use vespera::{Schema, route}; +use vespera::multipart::{FieldData, TypedMultipart}; +use vespera::{Multipart, Schema, route}; // ============== Request/Response DTOs ============== @@ -19,7 +19,7 @@ pub struct FileUploadResponse { pub created_at: String, } -#[derive(Debug, TryFromMultipart, Schema)] +#[derive(Debug, Multipart, Schema)] pub struct CreateFileUploadRequest { pub name: String, #[form_data(limit = "10MiB")] @@ -29,7 +29,8 @@ pub struct CreateFileUploadRequest { pub tags: Option, } -#[derive(Debug, TryFromMultipart, Schema)] +#[derive(Debug, Multipart, Schema)] +#[serde(rename_all = "camelCase")] pub struct UpdateFileUploadRequest { pub name: Option, #[form_data(limit = "10MiB")] @@ -40,7 +41,7 @@ pub struct UpdateFileUploadRequest { pub is_active: Option, } -// Generated via schema_type! with multipart: derives TryFromMultipart + Schema, +// Generated via schema_type! with multipart: derives Multipart + Schema, // partial makes all fields Option, omits the "document" field, preserves form_data attrs. // Note: multipart automatically sets clone = false (FieldData doesn't implement Clone). vespera::schema_type!(PatchFileUploadRequest from UpdateFileUploadRequest, multipart, partial, omit = ["document"]); diff --git a/examples/axum-example/tests/integration_test.rs b/examples/axum-example/tests/integration_test.rs index 04ccbf1..ce6ba81 100644 --- a/examples/axum-example/tests/integration_test.rs +++ b/examples/axum-example/tests/integration_test.rs @@ -1,8 +1,10 @@ use axum_example::{create_app, create_app_with_layer}; use axum_test::TestServer; +use axum_test::multipart::{MultipartForm, Part}; use serde::{Deserialize, Serialize}; use serde_json::json; -use vespera::{Schema, schema}; +use vespera::multipart::{FieldData, TypedMultipart}; +use vespera::{Multipart, Schema, schema}; #[tokio::test] async fn test_health_endpoint() { @@ -723,3 +725,1146 @@ async fn test_memo_update_with_added_id_field() { // Verify added field assert_eq!(result["id"], 42, "id should be present (added field)"); } + +// Tests for TypedMultipart (Multipart) request body extraction + +#[tokio::test] +async fn test_typed_form_list_file_uploads() { + let app = create_app().await; + let server = TestServer::new(app); + + let response = server.get("/typed-form").await; + + response.assert_status_ok(); + let uploads: serde_json::Value = response.json(); + + assert!(uploads.is_array()); + let uploads = uploads.as_array().unwrap(); + assert_eq!(uploads.len(), 1); + + let upload = &uploads[0]; + assert_eq!(upload["id"], 1); + assert_eq!(upload["name"], "Sample Upload"); + assert_eq!(upload["thumbnailUrl"], "https://example.com/thumb.jpg"); + assert_eq!(upload["documentUrl"], "https://example.com/doc.pdf"); + assert_eq!(upload["isActive"], true); + assert!(upload["tags"].is_array()); + assert_eq!(upload["tags"].as_array().unwrap().len(), 2); +} + +#[tokio::test] +async fn test_typed_form_create_file_upload() { + let app = create_app().await; + let server = TestServer::new(app); + + let form = MultipartForm::new() + .add_text("name", "Test Upload") + .add_text("tags", "rust, axum, vespera"); + + let response = server.post("/typed-form").multipart(form).await; + + response.assert_status_ok(); + let result: serde_json::Value = response.json(); + + assert_eq!(result["id"], 1); + assert_eq!(result["name"], "Test Upload"); + assert_eq!(result["isActive"], true); + + // Tags should be parsed from comma-separated string + let tags = result["tags"].as_array().unwrap(); + assert_eq!(tags.len(), 3); + assert_eq!(tags[0], "rust"); + assert_eq!(tags[1], "axum"); + assert_eq!(tags[2], "vespera"); + + // No files uploaded, so URLs should be null + assert!(result["thumbnailUrl"].is_null()); + assert!(result["documentUrl"].is_null()); +} + +#[tokio::test] +async fn test_typed_form_create_file_upload_with_files() { + let app = create_app().await; + let server = TestServer::new(app); + + let thumbnail_part = Part::bytes(b"fake image data".as_slice()).file_name("thumb.jpg"); + let document_part = Part::bytes(b"fake pdf data".as_slice()).file_name("doc.pdf"); + + let form = MultipartForm::new() + .add_text("name", "Upload With Files") + .add_part("thumbnail", thumbnail_part) + .add_part("document", document_part) + .add_text("tags", "files"); + + let response = server.post("/typed-form").multipart(form).await; + + response.assert_status_ok(); + let result: serde_json::Value = response.json(); + + assert_eq!(result["name"], "Upload With Files"); + assert_eq!(result["thumbnailUrl"], "uploaded_thumbnail_url"); + assert_eq!(result["documentUrl"], "uploaded_document_url"); + + let tags = result["tags"].as_array().unwrap(); + assert_eq!(tags.len(), 1); + assert_eq!(tags[0], "files"); +} + +#[tokio::test] +async fn test_typed_form_update_file_upload() { + let app = create_app().await; + let server = TestServer::new(app); + + // UpdateFileUploadRequest has #[serde(rename_all = "camelCase")] which vespera::Multipart respects. + // The Multipart derive reads serde attrs to match field names at runtime. + let form = MultipartForm::new() + .add_text("name", "Updated Upload") + .add_text("tags", "updated, tags") + .add_text("isActive", "false"); + + let response = server.put("/typed-form/42").multipart(form).await; + + response.assert_status_ok(); + let result: serde_json::Value = response.json(); + + assert_eq!(result["id"], 42); + assert_eq!(result["name"], "Updated Upload"); + // Response uses FileUploadResponse with #[serde(rename_all = "camelCase")] + assert_eq!(result["isActive"], false); + + let tags = result["tags"].as_array().unwrap(); + assert_eq!(tags.len(), 2); + assert_eq!(tags[0], "updated"); + assert_eq!(tags[1], "tags"); +} + +#[tokio::test] +async fn test_typed_form_update_file_upload_with_file() { + let app = create_app().await; + let server = TestServer::new(app); + + let thumbnail_part = Part::bytes(b"new image".as_slice()).file_name("new_thumb.jpg"); + + let form = MultipartForm::new() + .add_text("name", "Updated With File") + .add_part("thumbnail", thumbnail_part); + + let response = server.put("/typed-form/7").multipart(form).await; + + response.assert_status_ok(); + let result: serde_json::Value = response.json(); + + assert_eq!(result["id"], 7); + assert_eq!(result["name"], "Updated With File"); + assert_eq!(result["thumbnailUrl"], "updated_thumbnail_url"); + // Document not provided in this update + assert!(result["documentUrl"].is_null()); +} + +#[tokio::test] +async fn test_typed_form_patch_file_upload() { + let app = create_app().await; + let server = TestServer::new(app); + + // PatchFileUploadRequest is generated via schema_type! with multipart + partial + omit = ["document"] + // All fields are Option, so we only send the ones we want to update + let form = MultipartForm::new().add_text("name", "Patched Upload"); + + let response = server.patch("/typed-form/99").multipart(form).await; + + response.assert_status_ok(); + let result: serde_json::Value = response.json(); + + assert_eq!(result["id"], 99); + assert_eq!(result["name"], "Patched Upload"); + // document_url is always None for patch (field omitted from PatchFileUploadRequest) + assert!(result["documentUrl"].is_null()); +} + +#[tokio::test] +async fn test_typed_form_patch_file_upload_with_thumbnail() { + let app = create_app().await; + let server = TestServer::new(app); + + let thumbnail_part = Part::bytes(b"patched image".as_slice()).file_name("patched.jpg"); + + let form = MultipartForm::new() + .add_text("name", "Patched With Thumb") + .add_part("thumbnail", thumbnail_part) + .add_text("tags", "patched"); + + let response = server.patch("/typed-form/55").multipart(form).await; + + response.assert_status_ok(); + let result: serde_json::Value = response.json(); + + assert_eq!(result["id"], 55); + assert_eq!(result["name"], "Patched With Thumb"); + assert_eq!(result["thumbnailUrl"], "patched_thumbnail_url"); + assert!(result["documentUrl"].is_null()); + + let tags = result["tags"].as_array().unwrap(); + assert_eq!(tags.len(), 1); + assert_eq!(tags[0], "patched"); +} + +#[tokio::test] +async fn test_typed_form_create_minimal() { + let app = create_app().await; + let server = TestServer::new(app); + + // Only required field is "name" — all others are optional + let form = MultipartForm::new().add_text("name", "Minimal Upload"); + + let response = server.post("/typed-form").multipart(form).await; + + response.assert_status_ok(); + let result: serde_json::Value = response.json(); + + assert_eq!(result["name"], "Minimal Upload"); + assert!(result["thumbnailUrl"].is_null()); + assert!(result["documentUrl"].is_null()); + assert!(result["tags"].as_array().unwrap().is_empty()); +} + +#[tokio::test] +async fn test_openapi_contains_typed_form_routes() { + let openapi_content = std::fs::read_to_string("openapi.json").unwrap(); + let openapi: serde_json::Value = serde_json::from_str(&openapi_content).unwrap(); + + let paths = openapi.get("paths").unwrap(); + + // Verify typed-form routes exist + assert!( + paths.get("/typed-form").is_some(), + "Missing /typed-form route in OpenAPI spec" + ); + assert!( + paths.get("/typed-form/{id}").is_some(), + "Missing /typed-form/{{id}} route in OpenAPI spec" + ); + + // Verify POST /typed-form uses multipart/form-data content type + let post_op = &paths["/typed-form"]["post"]; + let request_body = post_op.get("requestBody").unwrap(); + let content = request_body.get("content").unwrap(); + assert!( + content.get("multipart/form-data").is_some(), + "POST /typed-form should use multipart/form-data content type" + ); + + // Verify PUT /typed-form/{id} uses multipart/form-data + let put_op = &paths["/typed-form/{id}"]["put"]; + let request_body = put_op.get("requestBody").unwrap(); + let content = request_body.get("content").unwrap(); + assert!( + content.get("multipart/form-data").is_some(), + "PUT /typed-form/{{id}} should use multipart/form-data content type" + ); + + // Verify PATCH /typed-form/{id} uses multipart/form-data + let patch_op = &paths["/typed-form/{id}"]["patch"]; + let request_body = patch_op.get("requestBody").unwrap(); + let content = request_body.get("content").unwrap(); + assert!( + content.get("multipart/form-data").is_some(), + "PATCH /typed-form/{{id}} should use multipart/form-data content type" + ); +} + +#[tokio::test] +async fn test_openapi_contains_typed_form_schemas() { + let openapi_content = std::fs::read_to_string("openapi.json").unwrap(); + let openapi: serde_json::Value = serde_json::from_str(&openapi_content).unwrap(); + + let schemas = openapi + .get("components") + .and_then(|c| c.get("schemas")) + .unwrap(); + + // Verify TypedMultipart request/response schemas exist + assert!( + schemas.get("CreateFileUploadRequest").is_some(), + "Missing CreateFileUploadRequest schema" + ); + assert!( + schemas.get("UpdateFileUploadRequest").is_some(), + "Missing UpdateFileUploadRequest schema" + ); + assert!( + schemas.get("PatchFileUploadRequest").is_some(), + "Missing PatchFileUploadRequest schema (generated via schema_type! multipart)" + ); + assert!( + schemas.get("FileUploadResponse").is_some(), + "Missing FileUploadResponse schema" + ); +} + +// ============== #[form_data(limit = "...")] enforcement tests ============== +// +// These tests use a standalone Multipart struct with small limits to verify +// that the `#[form_data(limit)]` attribute is correctly enforced at runtime +// for both text fields and file (NamedTempFile) fields. + +/// Test struct with intentionally small limits for limit enforcement testing. +#[derive(Debug, Multipart)] +#[allow(dead_code)] +struct FormDataLimitTestRequest { + /// No limit — accepts any size. + pub name: String, + /// 100-byte limit on a text field. + #[form_data(limit = "100")] + pub data: Option, + /// 50-byte limit on a file upload field. + #[form_data(limit = "50")] + pub file: Option>, +} + +async fn form_data_limit_handler( + TypedMultipart(req): TypedMultipart, +) -> axum::Json { + axum::Json(req.name) +} + +fn create_limit_test_app() -> axum::Router { + axum::Router::new().route("/limit-test", axum::routing::post(form_data_limit_handler)) +} + +#[tokio::test] +async fn test_form_data_limit_text_field_within_limit() { + let server = TestServer::new(create_limit_test_app()); + + // 5 bytes text — well within 100-byte limit + let form = MultipartForm::new() + .add_text("name", "test") + .add_text("data", "short"); + + let response = server.post("/limit-test").multipart(form).await; + response.assert_status_ok(); +} + +#[tokio::test] +async fn test_form_data_limit_text_field_at_boundary() { + let server = TestServer::new(create_limit_test_app()); + + // Exactly 100 bytes — should succeed (limit check is `> limit`, not `>=`) + let exact = "x".repeat(100); + let form = MultipartForm::new() + .add_text("name", "test") + .add_text("data", &exact); + + let response = server.post("/limit-test").multipart(form).await; + response.assert_status_ok(); +} + +#[tokio::test] +async fn test_form_data_limit_text_field_exceeds_limit() { + let server = TestServer::new(create_limit_test_app()); + + // 101 bytes — exceeds 100-byte limit → HTTP 413 PAYLOAD_TOO_LARGE + let over = "x".repeat(101); + let form = MultipartForm::new() + .add_text("name", "test") + .add_text("data", &over); + + let response = server.post("/limit-test").multipart(form).await; + response.assert_status(axum::http::StatusCode::PAYLOAD_TOO_LARGE); + let body = response.text(); + assert!( + body.contains("data"), + "Error should mention the field name 'data': {body}" + ); +} + +#[tokio::test] +async fn test_form_data_limit_file_field_within_limit() { + let server = TestServer::new(create_limit_test_app()); + + // 50 bytes file — exactly at 50-byte limit + let small_file = Part::bytes(vec![0u8; 50]).file_name("small.bin"); + let form = MultipartForm::new() + .add_text("name", "test") + .add_part("file", small_file); + + let response = server.post("/limit-test").multipart(form).await; + response.assert_status_ok(); +} + +#[tokio::test] +async fn test_form_data_limit_file_field_exceeds_limit() { + let server = TestServer::new(create_limit_test_app()); + + // 51 bytes file — exceeds 50-byte limit → HTTP 413 PAYLOAD_TOO_LARGE + let big_file = Part::bytes(vec![0u8; 51]).file_name("big.bin"); + let form = MultipartForm::new() + .add_text("name", "test") + .add_part("file", big_file); + + let response = server.post("/limit-test").multipart(form).await; + response.assert_status(axum::http::StatusCode::PAYLOAD_TOO_LARGE); + let body = response.text(); + assert!( + body.contains("file"), + "Error should mention the field name 'file': {body}" + ); +} + +#[tokio::test] +async fn test_form_data_no_limit_field_accepts_large_data() { + let server = TestServer::new(create_limit_test_app()); + + // "name" has no #[form_data(limit)] — should accept large values + let long_name = "x".repeat(10_000); + let form = MultipartForm::new().add_text("name", &long_name); + + let response = server.post("/limit-test").multipart(form).await; + response.assert_status_ok(); + + let result: String = response.json(); + assert_eq!(result.len(), 10_000); +} + +#[tokio::test] +async fn test_form_data_limit_unlimited_keyword() { + // Verify that parse_byte_unit handles "unlimited" (code path: returns None) + // Tested indirectly: a field without a limit already behaves as unlimited. + // This test confirms the same behavior with all fields provided. + let server = TestServer::new(create_limit_test_app()); + + let form = MultipartForm::new() + .add_text("name", "test") + .add_text("data", "y".repeat(50)) + .add_part("file", Part::bytes(vec![1u8; 30]).file_name("f.bin")); + + let response = server.post("/limit-test").multipart(form).await; + response.assert_status_ok(); +} + +// ============== #[serde(rename)] and #[serde(default)] tests ============== +// +// These tests verify that `#[derive(Multipart)]` correctly handles serde +// attributes for field renaming and default values. + +fn default_greeting() -> String { + "hello".to_string() +} + +/// Test struct with serde rename and default attributes. +#[derive(Debug, Multipart)] +#[serde(rename_all = "camelCase")] +#[allow(dead_code)] +struct SerdeAttrTestRequest { + /// Uses camelCase rename from struct-level rename_all. + pub user_name: String, + /// Explicit field rename overrides rename_all. + #[serde(rename = "customTag")] + pub tag_value: String, + /// `#[serde(default)]` uses `Default::default()` when missing. + #[serde(default)] + pub score: i32, + /// `#[serde(default = "fn")]` calls custom function when missing. + #[serde(default = "default_greeting")] + pub greeting: String, +} + +async fn serde_attr_handler( + TypedMultipart(req): TypedMultipart, +) -> axum::Json { + axum::Json(serde_json::json!({ + "userName": req.user_name, + "tagValue": req.tag_value, + "score": req.score, + "greeting": req.greeting, + })) +} + +/// Test struct with struct-level `#[serde(default)]`. +#[derive(Debug, Multipart)] +#[serde(default)] +#[allow(dead_code)] +struct StructDefaultTestRequest { + pub name: String, + pub count: i32, + pub active: bool, +} + +async fn struct_default_handler( + TypedMultipart(req): TypedMultipart, +) -> axum::Json { + axum::Json(serde_json::json!({ + "name": req.name, + "count": req.count, + "active": req.active, + })) +} + +fn create_serde_test_app() -> axum::Router { + axum::Router::new() + .route("/serde-test", axum::routing::post(serde_attr_handler)) + .route( + "/struct-default-test", + axum::routing::post(struct_default_handler), + ) +} + +// ─── serde(rename_all) tests ──────────────────────────────────────────────── + +#[tokio::test] +async fn test_serde_rename_all_camel_case() { + let server = TestServer::new(create_serde_test_app()); + + // Field "user_name" is renamed to "userName" by rename_all = "camelCase" + let form = MultipartForm::new() + .add_text("userName", "Alice") + .add_text("customTag", "rust"); + + let response = server.post("/serde-test").multipart(form).await; + response.assert_status_ok(); + + let result: serde_json::Value = response.json(); + assert_eq!(result["userName"], "Alice"); + assert_eq!(result["tagValue"], "rust"); +} + +#[tokio::test] +async fn test_serde_rename_all_rust_name_rejected() { + let server = TestServer::new(create_serde_test_app()); + + // Using Rust field name "user_name" instead of "userName" should fail + let form = MultipartForm::new() + .add_text("user_name", "Alice") + .add_text("customTag", "rust"); + + let response = server.post("/serde-test").multipart(form).await; + // "userName" is missing → MissingField error + response.assert_status(axum::http::StatusCode::BAD_REQUEST); +} + +// ─── serde(rename = "...") tests ──────────────────────────────────────────── + +#[tokio::test] +async fn test_serde_rename_explicit() { + let server = TestServer::new(create_serde_test_app()); + + // "tag_value" is renamed to "customTag" by #[serde(rename = "customTag")] + let form = MultipartForm::new() + .add_text("userName", "Alice") + .add_text("customTag", "explicit"); + + let response = server.post("/serde-test").multipart(form).await; + response.assert_status_ok(); + + let result: serde_json::Value = response.json(); + assert_eq!(result["tagValue"], "explicit"); +} + +#[tokio::test] +async fn test_serde_rename_camel_case_of_field_rejected() { + let server = TestServer::new(create_serde_test_app()); + + // "tagValue" (camelCase of Rust name) should NOT work — explicit rename takes priority + let form = MultipartForm::new() + .add_text("userName", "Alice") + .add_text("tagValue", "wrong"); + + let response = server.post("/serde-test").multipart(form).await; + // "customTag" is missing → MissingField error + response.assert_status(axum::http::StatusCode::BAD_REQUEST); +} + +// ─── serde(default) field-level tests ─────────────────────────────────────── + +#[tokio::test] +async fn test_serde_default_uses_default_trait() { + let server = TestServer::new(create_serde_test_app()); + + // Omit "score" (has #[serde(default)]) — should get i32::default() = 0 + let form = MultipartForm::new() + .add_text("userName", "Alice") + .add_text("customTag", "test"); + + let response = server.post("/serde-test").multipart(form).await; + response.assert_status_ok(); + + let result: serde_json::Value = response.json(); + assert_eq!(result["score"], 0, "score should default to 0"); +} + +#[tokio::test] +async fn test_serde_default_fn_uses_custom_function() { + let server = TestServer::new(create_serde_test_app()); + + // Omit "greeting" (has #[serde(default = "default_greeting")]) + // Should get "hello" from the custom function + let form = MultipartForm::new() + .add_text("userName", "Alice") + .add_text("customTag", "test"); + + let response = server.post("/serde-test").multipart(form).await; + response.assert_status_ok(); + + let result: serde_json::Value = response.json(); + assert_eq!( + result["greeting"], "hello", + "greeting should default to 'hello' from default_greeting()" + ); +} + +#[tokio::test] +async fn test_serde_default_overridden_when_provided() { + let server = TestServer::new(create_serde_test_app()); + + // Provide both default fields — explicit values should win + let form = MultipartForm::new() + .add_text("userName", "Alice") + .add_text("customTag", "test") + .add_text("score", "42") + .add_text("greeting", "world"); + + let response = server.post("/serde-test").multipart(form).await; + response.assert_status_ok(); + + let result: serde_json::Value = response.json(); + assert_eq!(result["score"], 42); + assert_eq!(result["greeting"], "world"); +} + +// ─── Vec field, strict mode, form_data(field_name), numeric/char tests ─── + +/// Test struct with Vec field for repeated multipart fields. +#[derive(Debug, Multipart)] +#[allow(dead_code)] +struct VecFieldTestRequest { + pub name: String, + pub tags: Vec, +} + +async fn vec_field_handler( + TypedMultipart(req): TypedMultipart, +) -> axum::Json { + axum::Json(serde_json::json!({ + "name": req.name, + "tags": req.tags, + })) +} + +/// Test struct with strict mode enabled. +#[derive(Debug, Multipart)] +#[try_from_multipart(strict)] +#[allow(dead_code)] +struct StrictModeTestRequest { + pub name: String, + pub age: i32, +} + +async fn strict_mode_handler( + TypedMultipart(req): TypedMultipart, +) -> axum::Json { + axum::Json(serde_json::json!({ + "name": req.name, + "age": req.age, + })) +} + +/// Test struct with form_data(field_name) override. +#[derive(Debug, Multipart)] +#[allow(dead_code)] +struct FieldNameOverrideTestRequest { + pub name: String, + #[form_data(field_name = "custom_field")] + pub data: String, +} + +async fn field_name_override_handler( + TypedMultipart(req): TypedMultipart, +) -> axum::Json { + axum::Json(serde_json::json!({ + "name": req.name, + "data": req.data, + })) +} + +/// Test struct with form_data(default) attribute. +#[derive(Debug, Multipart)] +#[allow(dead_code)] +struct FormDataDefaultTestRequest { + pub name: String, + #[form_data(default)] + pub count: i32, +} + +async fn form_data_default_handler( + TypedMultipart(req): TypedMultipart, +) -> axum::Json { + axum::Json(serde_json::json!({ + "name": req.name, + "count": req.count, + })) +} + +/// Test struct with numeric and char fields for type parsing coverage. +#[derive(Debug, Multipart)] +#[allow(dead_code)] +struct NumericCharTestRequest { + pub name: String, + pub count: i32, + pub score: f64, + pub initial: char, +} + +async fn numeric_char_handler( + TypedMultipart(req): TypedMultipart, +) -> axum::Json { + axum::Json(serde_json::json!({ + "name": req.name, + "count": req.count, + "score": req.score, + "initial": req.initial.to_string(), + })) +} + +fn create_coverage_test_app() -> axum::Router { + axum::Router::new() + .route("/vec-test", axum::routing::post(vec_field_handler)) + .route("/strict-test", axum::routing::post(strict_mode_handler)) + .route( + "/field-name-test", + axum::routing::post(field_name_override_handler), + ) + .route( + "/form-data-default-test", + axum::routing::post(form_data_default_handler), + ) + .route( + "/numeric-char-test", + axum::routing::post(numeric_char_handler), + ) +} + +// ─── Vec field tests ───────────────────────────────────────────────────── + +#[tokio::test] +async fn test_vec_field_multiple_values() { + let server = TestServer::new(create_coverage_test_app()); + + let form = MultipartForm::new() + .add_text("name", "Alice") + .add_text("tags", "rust") + .add_text("tags", "web") + .add_text("tags", "api"); + + let response = server.post("/vec-test").multipart(form).await; + response.assert_status_ok(); + + let result: serde_json::Value = response.json(); + assert_eq!(result["name"], "Alice"); + assert_eq!(result["tags"], json!(["rust", "web", "api"])); +} + +#[tokio::test] +async fn test_vec_field_empty() { + let server = TestServer::new(create_coverage_test_app()); + + // No "tags" fields — Vec should be empty + let form = MultipartForm::new().add_text("name", "Bob"); + + let response = server.post("/vec-test").multipart(form).await; + response.assert_status_ok(); + + let result: serde_json::Value = response.json(); + assert_eq!(result["tags"], json!([])); +} + +#[tokio::test] +async fn test_vec_field_single_value() { + let server = TestServer::new(create_coverage_test_app()); + + let form = MultipartForm::new() + .add_text("name", "Charlie") + .add_text("tags", "solo"); + + let response = server.post("/vec-test").multipart(form).await; + response.assert_status_ok(); + + let result: serde_json::Value = response.json(); + assert_eq!(result["tags"], json!(["solo"])); +} + +// ─── Strict mode tests ────────────────────────────────────────────────────── + +#[tokio::test] +async fn test_strict_mode_valid_request() { + let server = TestServer::new(create_coverage_test_app()); + + let form = MultipartForm::new() + .add_text("name", "Alice") + .add_text("age", "30"); + + let response = server.post("/strict-test").multipart(form).await; + response.assert_status_ok(); + + let result: serde_json::Value = response.json(); + assert_eq!(result["name"], "Alice"); + assert_eq!(result["age"], 30); +} + +#[tokio::test] +async fn test_strict_mode_unknown_field() { + let server = TestServer::new(create_coverage_test_app()); + + // "extra" is not a field in StrictModeTestRequest → UnknownField error + let form = MultipartForm::new() + .add_text("name", "Alice") + .add_text("age", "30") + .add_text("extra", "rejected"); + + let response = server.post("/strict-test").multipart(form).await; + response.assert_status(axum::http::StatusCode::BAD_REQUEST); + let body = response.text(); + assert!( + body.contains("Unknown field"), + "Should mention unknown field: {body}" + ); +} + +#[tokio::test] +async fn test_strict_mode_duplicate_field() { + let server = TestServer::new(create_coverage_test_app()); + + // Sending "name" twice in strict mode → DuplicateField error + let form = MultipartForm::new() + .add_text("name", "Alice") + .add_text("name", "Bob") + .add_text("age", "30"); + + let response = server.post("/strict-test").multipart(form).await; + response.assert_status(axum::http::StatusCode::BAD_REQUEST); + let body = response.text(); + assert!( + body.contains("Duplicate field"), + "Should mention duplicate field: {body}" + ); +} + +// ─── form_data(field_name) tests ──────────────────────────────────────────── + +#[tokio::test] +async fn test_form_data_field_name_override() { + let server = TestServer::new(create_coverage_test_app()); + + // "data" field is mapped to "custom_field" via form_data(field_name) + let form = MultipartForm::new() + .add_text("name", "Alice") + .add_text("custom_field", "payload"); + + let response = server.post("/field-name-test").multipart(form).await; + response.assert_status_ok(); + + let result: serde_json::Value = response.json(); + assert_eq!(result["data"], "payload"); +} + +#[tokio::test] +async fn test_form_data_field_name_rust_name_rejected() { + let server = TestServer::new(create_coverage_test_app()); + + // Using Rust field name "data" instead of "custom_field" → MissingField + let form = MultipartForm::new() + .add_text("name", "Alice") + .add_text("data", "payload"); + + let response = server.post("/field-name-test").multipart(form).await; + response.assert_status(axum::http::StatusCode::BAD_REQUEST); +} + +// ─── form_data(default) tests ─────────────────────────────────────────────── + +#[tokio::test] +async fn test_form_data_default_uses_default_trait() { + let server = TestServer::new(create_coverage_test_app()); + + // Omit "count" (has #[form_data(default)]) → Default::default() = 0 + let form = MultipartForm::new().add_text("name", "Alice"); + + let response = server.post("/form-data-default-test").multipart(form).await; + response.assert_status_ok(); + + let result: serde_json::Value = response.json(); + assert_eq!(result["count"], 0); +} + +#[tokio::test] +async fn test_form_data_default_overridden_when_provided() { + let server = TestServer::new(create_coverage_test_app()); + + let form = MultipartForm::new() + .add_text("name", "Alice") + .add_text("count", "42"); + + let response = server.post("/form-data-default-test").multipart(form).await; + response.assert_status_ok(); + + let result: serde_json::Value = response.json(); + assert_eq!(result["count"], 42); +} + +// ─── Numeric and char field parsing tests ─────────────────────────────────── + +#[tokio::test] +async fn test_numeric_char_valid_values() { + let server = TestServer::new(create_coverage_test_app()); + + let form = MultipartForm::new() + .add_text("name", "Alice") + .add_text("count", "42") + .add_text("score", "9.75") + .add_text("initial", "A"); + + let response = server.post("/numeric-char-test").multipart(form).await; + response.assert_status_ok(); + + let result: serde_json::Value = response.json(); + assert_eq!(result["count"], 42); + assert!((result["score"].as_f64().unwrap() - 9.75).abs() < f64::EPSILON); + assert_eq!(result["initial"], "A"); +} + +#[tokio::test] +async fn test_numeric_field_invalid_value() { + let server = TestServer::new(create_coverage_test_app()); + + // "not_a_number" for i32 field → WrongFieldType + let form = MultipartForm::new() + .add_text("name", "Alice") + .add_text("count", "not_a_number") + .add_text("score", "9.75") + .add_text("initial", "A"); + + let response = server.post("/numeric-char-test").multipart(form).await; + response.assert_status(axum::http::StatusCode::UNSUPPORTED_MEDIA_TYPE); +} + +#[tokio::test] +async fn test_float_field_invalid_value() { + let server = TestServer::new(create_coverage_test_app()); + + // "abc" for f64 field → WrongFieldType + let form = MultipartForm::new() + .add_text("name", "Alice") + .add_text("count", "10") + .add_text("score", "abc") + .add_text("initial", "A"); + + let response = server.post("/numeric-char-test").multipart(form).await; + response.assert_status(axum::http::StatusCode::UNSUPPORTED_MEDIA_TYPE); +} + +#[tokio::test] +async fn test_char_field_multiple_chars() { + let server = TestServer::new(create_coverage_test_app()); + + // "AB" for char field → WrongFieldType (expects exactly one character) + let form = MultipartForm::new() + .add_text("name", "Alice") + .add_text("count", "10") + .add_text("score", "1.0") + .add_text("initial", "AB"); + + let response = server.post("/numeric-char-test").multipart(form).await; + response.assert_status(axum::http::StatusCode::UNSUPPORTED_MEDIA_TYPE); +} + +#[tokio::test] +async fn test_char_field_empty_string() { + let server = TestServer::new(create_coverage_test_app()); + + // "" for char field → WrongFieldType (expects exactly one character) + let form = MultipartForm::new() + .add_text("name", "Alice") + .add_text("count", "10") + .add_text("score", "1.0") + .add_text("initial", ""); + + let response = server.post("/numeric-char-test").multipart(form).await; + response.assert_status(axum::http::StatusCode::UNSUPPORTED_MEDIA_TYPE); +} + +// ─── serde(default) struct-level tests ────────────────────────────────────── + +#[tokio::test] +async fn test_struct_level_serde_default_all_omitted() { + let server = TestServer::new(create_serde_test_app()); + + // No recognized fields — struct has #[serde(default)], all get Default::default(). + // Send an unrecognized field to produce a valid multipart body (non-strict ignores it). + let form = MultipartForm::new().add_text("_ignored", ""); + + let response = server.post("/struct-default-test").multipart(form).await; + response.assert_status_ok(); + + let result: serde_json::Value = response.json(); + assert_eq!(result["name"], "", "String::default() is empty string"); + assert_eq!(result["count"], 0, "i32::default() is 0"); + assert_eq!(result["active"], false, "bool::default() is false"); +} + +#[tokio::test] +async fn test_struct_level_serde_default_partial() { + let server = TestServer::new(create_serde_test_app()); + + // Only provide "name" — other fields should get defaults + let form = MultipartForm::new().add_text("name", "Bob"); + + let response = server.post("/struct-default-test").multipart(form).await; + response.assert_status_ok(); + + let result: serde_json::Value = response.json(); + assert_eq!(result["name"], "Bob"); + assert_eq!(result["count"], 0); + assert_eq!(result["active"], false); +} + +#[tokio::test] +async fn test_struct_level_serde_default_all_provided() { + let server = TestServer::new(create_serde_test_app()); + + // Provide all fields — explicit values should win + let form = MultipartForm::new() + .add_text("name", "Charlie") + .add_text("count", "99") + .add_text("active", "true"); + + let response = server.post("/struct-default-test").multipart(form).await; + response.assert_status_ok(); + + let result: serde_json::Value = response.json(); + assert_eq!(result["name"], "Charlie"); + assert_eq!(result["count"], 99); + assert_eq!(result["active"], true); +} + +// ============== Multipart error path coverage tests ========================== +// +// These tests trigger real axum MultipartRejection / MultipartError paths +// to cover From impls and Display arms for InvalidRequest/InvalidRequestBody. + +#[tokio::test] +async fn test_multipart_rejection_non_multipart_content_type() { + // Sending JSON to a multipart handler triggers MultipartRejection → InvalidRequest + let server = TestServer::new(create_coverage_test_app()); + + let response = server + .post("/strict-test") + .content_type("application/json") + .bytes(b"{\"name\":\"x\",\"age\":1}".to_vec().into()) + .await; + + response.assert_status(axum::http::StatusCode::BAD_REQUEST); + let body = response.text(); + assert!( + body.contains("Invalid multipart request"), + "Should use InvalidRequest Display: {body}" + ); +} + +#[tokio::test] +async fn test_multipart_rejection_missing_content_type() { + // Sending raw bytes with no content type triggers MultipartRejection + let server = TestServer::new(create_coverage_test_app()); + + let response = server + .post("/vec-test") + .bytes(b"not multipart".to_vec().into()) + .await; + + // axum rejects with 4xx because there's no multipart content-type + assert!( + response.status_code().is_client_error(), + "Should be a client error, got {}", + response.status_code() + ); +} + +#[tokio::test] +async fn test_numeric_field_non_utf8_bytes() { + // Send non-UTF-8 bytes for a numeric (i32) field → WrongFieldType from from_utf8 error + let server = TestServer::new(create_coverage_test_app()); + + let invalid_utf8 = Part::bytes(vec![0xFF, 0xFE, 0xFD]).file_name("bad.bin"); + let form = MultipartForm::new() + .add_text("name", "Alice") + .add_part("count", invalid_utf8) + .add_text("score", "1.0") + .add_text("initial", "A"); + + let response = server.post("/numeric-char-test").multipart(form).await; + response.assert_status(axum::http::StatusCode::UNSUPPORTED_MEDIA_TYPE); +} + +#[tokio::test] +async fn test_multipart_error_malformed_body_stream() { + // Send a valid multipart Content-Type so from_request succeeds, + // but with a corrupted body so next_field() returns Err(MultipartError). + // This covers From (line 135) and InvalidRequestBody Display (lines 93-94). + let server = TestServer::new(create_coverage_test_app()); + + let boundary = "TESTBOUNDARY"; + // Start a valid boundary, then inject invalid header bytes (0xFF is not valid in HTTP headers). + // multer will attempt to parse these as field headers and fail. + let mut body = Vec::new(); + body.extend_from_slice(b"--TESTBOUNDARY\r\n"); + body.extend_from_slice(&[0x01, 0x02, 0xFF, 0xFE]); // invalid header bytes + body.extend_from_slice(b"\r\n\r\ndata\r\n--TESTBOUNDARY--"); + + let response = server + .post("/strict-test") + .content_type(&format!("multipart/form-data; boundary={boundary}")) + .bytes(body.into()) + .await; + + // multer rejects the invalid header bytes → MultipartError → From → InvalidRequestBody + response.assert_status(axum::http::StatusCode::BAD_REQUEST); + let body = response.text(); + assert!( + body.contains("Invalid multipart body"), + "Expected InvalidRequestBody Display output, got: {body}" + ); +} + +#[tokio::test] +async fn test_missing_required_field() { + // Send only "name" but omit required "age" → MissingField error from post-loop check + let server = TestServer::new(create_coverage_test_app()); + + let form = MultipartForm::new().add_text("name", "Alice"); + let response = server.post("/strict-test").multipart(form).await; + + response.assert_status(axum::http::StatusCode::BAD_REQUEST); + let body = response.text(); + assert!( + body.contains("Missing field"), + "Expected MissingField error, got: {body}" + ); + assert!( + body.contains("age"), + "Should name the missing field 'age', got: {body}" + ); +} + +#[tokio::test] +async fn test_missing_multiple_required_fields() { + // Send only "name" to numeric-char-test which requires name, count, score, initial. + // Non-strict endpoint: the unmatched fields simply stay None → MissingField in post-loop. + let server = TestServer::new(create_coverage_test_app()); + + let form = MultipartForm::new().add_text("name", "Alice"); + let response = server.post("/numeric-char-test").multipart(form).await; + + response.assert_status(axum::http::StatusCode::BAD_REQUEST); + let body = response.text(); + assert!( + body.contains("Missing field"), + "Expected MissingField error, got: {body}" + ); +} diff --git a/examples/axum-example/tests/snapshots/integration_test__openapi.snap b/examples/axum-example/tests/snapshots/integration_test__openapi.snap index 69ff2d1..be6d493 100644 --- a/examples/axum-example/tests/snapshots/integration_test__openapi.snap +++ b/examples/axum-example/tests/snapshots/integration_test__openapi.snap @@ -3778,7 +3778,7 @@ expression: "std::fs::read_to_string(\"openapi.json\").unwrap()" "format": "binary", "nullable": true }, - "is_active": { + "isActive": { "type": "boolean", "nullable": true }, diff --git a/openapi.json b/openapi.json index 490ddea..bf41bd6 100644 --- a/openapi.json +++ b/openapi.json @@ -3774,7 +3774,7 @@ "format": "binary", "nullable": true }, - "is_active": { + "isActive": { "type": "boolean", "nullable": true },