Skip to content

Commit 4621ec7

Browse files
fix: #[trace] compilation error when using trait objects (#130)
Co-authored-by: Andy Lok <andylokandy@hotmail.com>
1 parent 99f2c9e commit 4621ec7

6 files changed

Lines changed: 124 additions & 8 deletions

File tree

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

fastrace-macro/src/lib.rs

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ use syn::punctuated::Punctuated;
1919
use syn::spanned::Spanned;
2020
use syn::*;
2121

22+
use crate::visit_mut::VisitMut;
23+
2224
/// An attribute macro designed to eliminate boilerplate code.
2325
///
2426
/// This macro automatically creates a span for the annotated function. The span name defaults to
@@ -124,14 +126,13 @@ pub fn trace(
124126

125127
let args = parse_macro_input!(args as Args);
126128
let input = parse_macro_input!(item as ItemFn);
127-
128129
let func_name = input.sig.ident.to_string();
129-
// check for async_trait-like patterns in the block, and instrument
130-
// the future instead of the wrapper
130+
131+
// Check for async_trait-like patterns in the block, and instrument
132+
// the future instead of the wrapper.
131133
let func_body = if let Some(internal_fun) =
132134
get_async_trait_info(&input.block, input.sig.asyncness.is_some())
133135
{
134-
// let's rewrite some statements!
135136
match internal_fun.kind {
136137
// async-trait <= 0.1.43
137138
AsyncTraitKind::Function => {
@@ -141,23 +142,26 @@ pub fn trace(
141142
}
142143
// async-trait >= 0.1.44
143144
AsyncTraitKind::Async(async_expr) => {
144-
// fallback if we couldn't find the '__async_trait' binding, might be
145-
// useful for crates exhibiting the same behaviors as async-trait
146145
let instrumented_block =
147-
gen_block(&func_name, &async_expr.block, true, false, &args);
146+
gen_block(&func_name, &async_expr.block, true, false, &args, None);
148147
let async_attrs = &async_expr.attrs;
149148
quote::quote! {
150149
Box::pin(#(#async_attrs) * #instrumented_block)
151150
}
152151
}
153152
}
154153
} else {
154+
let output_ty = match input.sig.output {
155+
ReturnType::Type(_, ref ty) => (**ty).clone(),
156+
ReturnType::Default => parse_quote! { () },
157+
};
155158
gen_block(
156159
&func_name,
157160
&input.block,
158161
input.sig.asyncness.is_some(),
159162
input.sig.asyncness.is_some(),
160163
&args,
164+
Some(output_ty),
161165
)
162166
};
163167

@@ -371,10 +375,14 @@ fn gen_block(
371375
async_context: bool,
372376
async_keyword: bool,
373377
args: &Args,
378+
output_ty: Option<Type>,
374379
) -> proc_macro2::TokenStream {
375380
let name = gen_name(block.span(), func_name, args);
376381
let properties = gen_properties(block.span(), args);
377382
let crate_path = &args.crate_path;
383+
let output_ty_hint = output_ty
384+
.map(erase_impl_trait)
385+
.unwrap_or_else(|| parse_quote! { _ });
378386

379387
// Generate the instrumented function body.
380388
// If the function is an `async fn`, this will wrap it in an async block.
@@ -392,7 +400,10 @@ fn gen_block(
392400
{
393401
let __span__ = #crate_path::Span::enter_with_local_parent( #name ) #properties;
394402
#crate_path::future::FutureExt::in_span(
395-
async move { #block },
403+
async move {
404+
let __ret__: #output_ty_hint = #block;
405+
__ret__
406+
},
396407
__span__,
397408
)
398409
}
@@ -547,3 +558,25 @@ fn path_to_string(path: &Path) -> String {
547558
}
548559
res
549560
}
561+
562+
/// Replaces any `impl Trait` with `_` so it can be used as the type in
563+
/// a `let` statement's LHS.
564+
struct ImplTraitEraser;
565+
566+
impl VisitMut for ImplTraitEraser {
567+
fn visit_type_mut(&mut self, t: &mut Type) {
568+
if let Type::ImplTrait(..) = t {
569+
*t = syn::TypeInfer {
570+
underscore_token: Token![_](t.span()),
571+
}
572+
.into();
573+
} else {
574+
syn::visit_mut::visit_type_mut(self, t);
575+
}
576+
}
577+
}
578+
579+
fn erase_impl_trait(mut ty: Type) -> Type {
580+
ImplTraitEraser.visit_type_mut(&mut ty);
581+
ty
582+
}

tests/macros/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ log = { workspace = true }
1717
logcall = { version = "0.1" }
1818
tokio = { workspace = true }
1919
trybuild = { version = "1.0" }
20+
futures = { version = "0.3" }
21+
async-stream = { version = "0.3" }
2022

2123
# The procedural macro `trace` only supports async-trait higher than or equal to 0.1.52
2224
async-trait = { version = "0.1.52" }
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
use async_trait::async_trait;
2+
use std::future::Future;
3+
use std::pin::Pin;
4+
5+
use fastrace::trace;
6+
7+
#[derive(Debug)]
8+
pub struct InnerError;
9+
10+
#[derive(Debug)]
11+
pub struct OuterError(InnerError);
12+
13+
pub type MyFuture = Pin<Box<dyn Future<Output = Result<u32, OuterError>> + Send>>;
14+
15+
#[async_trait]
16+
pub trait MyTrait {
17+
async fn f() -> Result<MyFuture, OuterError>;
18+
}
19+
20+
pub struct MyStruct;
21+
22+
#[async_trait]
23+
impl MyTrait for MyStruct {
24+
#[trace]
25+
async fn f() -> Result<MyFuture, OuterError> {
26+
let inner = async { Err(InnerError) };
27+
28+
let mapped = async move { inner.await.map_err(OuterError) };
29+
30+
Ok(Box::pin(mapped))
31+
}
32+
}
33+
34+
#[tokio::main]
35+
async fn main() {
36+
let _ = MyStruct::f().await;
37+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
use std::future::Future;
2+
use std::pin::Pin;
3+
4+
use fastrace::trace;
5+
6+
#[derive(Debug)]
7+
pub struct InnerError;
8+
9+
#[derive(Debug)]
10+
pub struct OuterError(InnerError);
11+
12+
pub type MyFuture = Pin<Box<dyn Future<Output = Result<u32, OuterError>> + Send>>;
13+
14+
#[trace]
15+
pub async fn f() -> Result<MyFuture, OuterError> {
16+
let inner = async { Err(InnerError) };
17+
18+
let mapped = async move { inner.await.map_err(OuterError) };
19+
20+
Ok(Box::pin(mapped))
21+
}
22+
23+
#[tokio::main]
24+
async fn main() {
25+
let _ = f().await;
26+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
use fastrace::trace;
2+
use futures::Stream;
3+
4+
#[trace]
5+
async fn stream() -> impl Stream<Item = i64> {
6+
async_stream::stream! {
7+
for i in 0..100 {
8+
yield i;
9+
}
10+
}
11+
}
12+
13+
#[tokio::main]
14+
async fn main() {
15+
let _ = stream().await;
16+
}

0 commit comments

Comments
 (0)