Skip to content

Commit 6094b97

Browse files
committed
Add _generic field for serialization of generics, allow exclusion with SerializeGenerics flag, and improve handling of union type discriminators
1 parent f1227df commit 6094b97

5 files changed

Lines changed: 261 additions & 30 deletions

File tree

core/shared/src/main/scala-2/fabric/rw/RWMacros.scala

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,15 +236,40 @@ object RWMacros {
236236
q"$key -> t.$memberName.json"
237237
}
238238
val allMap = toMap ++ extraMap
239-
context.Expr[Reader[T]](q"""
239+
val typeArgs = tpe.typeArgs
240+
val typeParams = tpe.typeSymbol.asClass.typeParams
241+
val hasTypeArgs = typeArgs.nonEmpty
242+
context.Expr[Reader[T]](
243+
if (hasTypeArgs) {
244+
val genericEntries = typeParams.zip(typeArgs).map { case (param, arg) =>
245+
val name = param.name.decodedName.toString
246+
q"$name -> implicitly[RW[$arg]].definition.json"
247+
}
248+
q"""
249+
import _root_.fabric._
250+
import _root_.fabric.rw._
251+
import _root_.fabric.define._
252+
import _root_.scala.collection.immutable.VectorMap
253+
254+
new ClassR[$tpe] {
255+
override protected def t2Map(t: $tpe): Map[String, Json] = {
256+
val base = VectorMap(..$allMap)
257+
if (RW.SerializeGenerics) base + ("_generic" -> Obj(..$genericEntries))
258+
else base
259+
}
260+
}
261+
"""
262+
}
263+
else q"""
240264
import _root_.fabric._
241265
import _root_.fabric.rw._
242266
import _root_.scala.collection.immutable.VectorMap
243267

244268
new ClassR[$tpe] {
245269
override protected def t2Map(t: $tpe): Map[String, Json] = VectorMap(..$allMap)
246270
}
247-
""")
271+
"""
272+
)
248273
case None =>
249274
val caseObjects = companion.typeSignature.members.collect {
250275
case s: ModuleSymbol if s.moduleClass.asType.toType <:< tpe => s.name

core/shared/src/main/scala-3/fabric/rw/CompileRW.scala

Lines changed: 164 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -595,7 +595,8 @@ object CompileRW extends CompileRW {
595595
}
596596

597597
/** Generate RW for Scala 3 union types (A | B | C).
598-
* Extracts union member types and delegates to genSealedTraitFromChildren. */
598+
* Handles the case where multiple union members share the same base class but differ by type parameters
599+
* (e.g. `Id[String] | Id[Int]`) by using full parameterized type names as discriminators. */
599600
def genUnionMacro[T: Type](orType: Any)(using Quotes): Expr[RW[T]] = {
600601
import quotes.reflect._
601602

@@ -607,26 +608,140 @@ object CompileRW extends CompileRW {
607608

608609
val memberTypes = flattenUnion(orType.asInstanceOf[TypeRepr])
609610

610-
// Generate child RW pairs — same as sealed trait children
611-
val childExprs = memberTypes.map { memberType =>
611+
// Detect if any members share the same base class (type parameter collision)
612+
val simpleNames = memberTypes.map(t => getSimpleTypeNameFromType(t))
613+
val hasCollisions = simpleNames.distinct.size != simpleNames.size
614+
615+
if (hasCollisions) {
616+
genCollisionUnionMacro[T](memberTypes)
617+
} else {
618+
// No collisions — use standard poly generation
619+
val childExprs = memberTypes.map { memberType =>
620+
memberType.asType match {
621+
case '[t] =>
622+
val rw = Expr.summon[RW[t]].getOrElse {
623+
val childSym = memberType.typeSymbol
624+
if (childSym.isClassDef && childSym.flags.is(Flags.Case)) {
625+
genMacro[t]
626+
} else {
627+
report.errorAndAbort(s"No RW found for union member type ${memberType.show}. Ensure all union member types have an RW instance.")
628+
}
629+
}
630+
val name = getSimpleTypeNameFromType(memberType)
631+
'{ (${ Expr(name) }, $rw.asInstanceOf[RW[_]]) }
632+
}
633+
}
634+
val childRWsExpr = Expr.ofList(childExprs)
635+
genPolyRW[T](childRWsExpr)
636+
}
637+
}
638+
639+
/** Generate RW for union types where multiple members share the same base class (e.g. `Id[String] | Id[Int]`).
640+
* Uses full parameterized type names as discriminators and compile-time type matching for the write path
641+
* since runtime class inspection can't distinguish erased generic variants. */
642+
private def genCollisionUnionMacro[T: Type](memberTypes: List[Any])(using Quotes): Expr[RW[T]] = {
643+
import quotes.reflect._
644+
645+
val members = memberTypes.asInstanceOf[List[TypeRepr]]
646+
647+
// Build child RW pairs — always generate fresh RWs (not summoned) so each gets concrete _generic info
648+
val childExprs = members.map { memberType =>
612649
memberType.asType match {
613650
case '[t] =>
614-
val rw = Expr.summon[RW[t]].getOrElse {
615-
val childSym = memberType.typeSymbol
616-
if (childSym.isClassDef && childSym.flags.is(Flags.Case)) {
617-
genMacro[t]
618-
} else {
619-
report.errorAndAbort(s"No RW found for union member type ${memberType.show}. Ensure all union member types have an RW instance.")
651+
val childSym = memberType.typeSymbol
652+
// Always generate fresh to ensure _generic reflects the concrete type args
653+
val rw = if (childSym.isClassDef && childSym.flags.is(Flags.Case)) {
654+
genMacro[t]
655+
} else {
656+
Expr.summon[RW[t]].getOrElse {
657+
report.errorAndAbort(s"No RW found for union member type ${memberType.show}.")
620658
}
621659
}
622-
val name = getSimpleTypeNameFromType(memberType)
623-
'{ (${ Expr(name) }, $rw.asInstanceOf[RW[_]]) }
660+
val simpleName = getSimpleTypeNameFromType(memberType)
661+
val fullName = fullTypeName(memberType)
662+
'{ (${ Expr(simpleName) }, ${ Expr(fullName) }, $rw.asInstanceOf[RW[_]]) }
624663
}
625664
}
626-
val childRWsExpr = Expr.ofList(childExprs)
665+
val childListExpr = Expr.ofList(childExprs)
627666

628-
// Reuse the same polymorphic RW generation as sealed traits
629-
genPolyRW[T](childRWsExpr)
667+
val fullTypeNameStr = members.map(fullTypeName(_)).mkString(" | ")
668+
val fullTypeNameExpr = Expr(fullTypeNameStr)
669+
670+
'{
671+
new RW[T] {
672+
private val typeField = "type"
673+
private lazy val childRWs: List[(String, String, RW[_])] = $childListExpr
674+
675+
private def matchGeneric(json: Json, candidates: List[(String, String, RW[_])]): Option[(String, RW[_])] = {
676+
json match {
677+
case Obj(map) =>
678+
map.get("_generic") match {
679+
case Some(genericJson) =>
680+
// Match by comparing _generic content against each candidate's definition.genericTypes
681+
candidates.find { case (_, _, rw) =>
682+
val expected = Obj(rw.definition.genericTypes.map(gt => gt.name -> gt.definition.json): _*)
683+
expected == genericJson
684+
}.map(c => (c._2, c._3))
685+
case None =>
686+
// No _generic field — take first candidate
687+
candidates.headOption.map(c => (c._2, c._3))
688+
}
689+
case _ => candidates.headOption.map(c => (c._2, c._3))
690+
}
691+
}
692+
693+
override def read(value: T): Json = {
694+
val simpleName = safeTypeName(value)
695+
val candidates = childRWs.filter(_._1 == simpleName)
696+
// Use first candidate for read — the child RW will embed _generic in its output
697+
candidates.headOption match {
698+
case Some((_, _, rw)) =>
699+
rw.asInstanceOf[RW[T]].read(value) match {
700+
case obj: Obj => obj.merge(Obj(typeField -> Str(simpleName)))
701+
case other => Obj(typeField -> Str(simpleName), "value" -> other)
702+
}
703+
case None => throw RWException(s"Unknown subtype: $simpleName")
704+
}
705+
}
706+
707+
override def write(json: Json): T = json match {
708+
case obj @ Obj(map) =>
709+
map.get(typeField) match {
710+
case Some(Str(typeName, _)) =>
711+
val candidates = childRWs.filter(_._1 == typeName)
712+
val (_, rw) = if (candidates.size > 1) {
713+
// Collision — use _generic to disambiguate
714+
val cleanedJson = Obj(map - typeField)
715+
matchGeneric(cleanedJson, candidates).getOrElse(
716+
throw RWException(s"Cannot disambiguate type '$typeName' — no matching _generic found. Available: ${candidates.map(_._2).mkString(", ")}")
717+
)
718+
} else {
719+
candidates.headOption.map(c => (c._2, c._3)).getOrElse(
720+
throw RWException(s"Unknown type discriminator: $typeName")
721+
)
722+
}
723+
val cleanedMap = map - typeField
724+
val cleanedJson = if (cleanedMap.isEmpty && map.size == 2 && map.contains("value")) {
725+
map("value")
726+
} else {
727+
Obj(cleanedMap)
728+
}
729+
rw.asInstanceOf[RW[T]].write(cleanedJson)
730+
case _ =>
731+
throw RWException(s"Missing or invalid '$typeField' field in JSON for union type")
732+
}
733+
case _ =>
734+
throw RWException(s"Expected JSON object for union type, got: $json")
735+
}
736+
737+
override def definition: FabricDefinition = {
738+
val childDefs = childRWs.map { case (_, fullName, rw) =>
739+
fullName -> rw.definition
740+
}.toMap.to(VectorMap)
741+
FabricDefinition(DefType.Poly(childDefs), className = Some($fullTypeNameExpr))
742+
}
743+
}
744+
}
630745
}
631746

632747
private def cleanFullName(name: String): String =
@@ -799,6 +914,36 @@ object CompileRW extends CompileRW {
799914
val genericTypesExpr = generateGenericTypes(tpe)
800915
val fieldGenericNamesExpr = extractFieldGenericNames(tpe)
801916

917+
// Check if this type has type arguments (is a generic instantiation)
918+
val typeParamSyms = typeSymbol.primaryConstructor.paramSymss.headOption match {
919+
case Some(params) if params.nonEmpty && params.head.isTypeParam => params
920+
case _ => Nil
921+
}
922+
val hasTypeArgs = tpe match {
923+
case AppliedType(_, _) if typeParamSyms.nonEmpty => true
924+
case _ => false
925+
}
926+
927+
// Generate _generic JSON value at macro time
928+
val genericJsonExpr: Option[Expr[Json]] = if (hasTypeArgs) {
929+
tpe match {
930+
case AppliedType(_, args) =>
931+
val entries = typeParamSyms.zip(args).map { case (param, arg) =>
932+
val nameExpr = Expr(param.name)
933+
arg.asType match {
934+
case '[t] =>
935+
val rw = Expr.summon[RW[t]].getOrElse {
936+
report.errorAndAbort(s"No RW found for type parameter ${param.name} (${arg.show})")
937+
}
938+
'{ ($nameExpr, $rw.definition.json) }
939+
}
940+
}
941+
val list = Expr.ofList(entries)
942+
Some('{ Obj($list: _*) })
943+
case _ => None
944+
}
945+
} else None
946+
802947
'{
803948
new ClassRW[T] {
804949
override protected def t2Map(t: T): Map[String, Json] = {
@@ -807,8 +952,12 @@ object CompileRW extends CompileRW {
807952
case Some(gen) => '{ base ++ ${ gen('{t}) } }
808953
case None => '{ base }
809954
}}
810-
${ if (hasTransient) '{ withExtra -- $transientFieldsExpr }
955+
val withTransient = ${ if (hasTransient) '{ withExtra -- $transientFieldsExpr }
811956
else '{ withExtra } }
957+
${ genericJsonExpr match {
958+
case Some(gj) => '{ if (RW.SerializeGenerics) withTransient + ("_generic" -> $gj) else withTransient }
959+
case None => '{ withTransient }
960+
}}
812961
}
813962

814963
override protected def map2T(map: Map[String, Json]): T = {

core/shared/src/main/scala/fabric/rw/RW.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,13 @@ trait RW[T] extends Reader[T] with Writer[T] {
3939
}
4040

4141
object RW extends CompileRW {
42+
/** Controls whether generic type information (`_generic` field) is included in serialized JSON output for generic
43+
* case classes. Defaults to `true`. Set to `false` to exclude `_generic` from output, which produces cleaner JSON
44+
* but loses the ability to disambiguate erased generic variants during deserialization (e.g. in union types like
45+
* `Id[String] | Id[Int]`).
46+
*/
47+
var SerializeGenerics: Boolean = true
48+
4249
def from[T](r: T => Json, w: Json => T, d: => Definition): RW[T] = new RW[T] {
4350
override def write(value: Json): T = w(value)
4451

core/shared/src/test/scala-3/spec/Scala3Spec.scala

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,39 @@ class Scala3Spec extends AnyWordSpec with Matchers {
103103
dog.json.as[Cat | Dog | Fish] should be(Dog("Buddy", "Poodle"))
104104
fishJson.as[Cat | Dog | Fish] should be(Fish("Nemo", saltwater = true))
105105
}
106+
"handle union types with generic collision — deserialization distinguishes by _generic" in {
107+
import GenericCollisionTest._
108+
type T = Box[String] | Box[Int]
109+
given RW[T] = RW.gen
110+
111+
// Box[String] serializes with value as a string, Box[Int] with value as a number
112+
val strBox = Box[String]("hello")
113+
val intBox = Box[Int](42)
114+
115+
// Serialize via concrete RWs (not the union) to get correct _generic
116+
val strJson = Box.rw[String].read(strBox)
117+
val intJson = Box.rw[Int].read(intBox)
118+
119+
// Verify _generic is present and different
120+
strJson("_generic")("T")("type").asString should be("string")
121+
intJson("_generic")("T")("type").asString should be("numeric")
122+
123+
// Wrap with type field as the union would
124+
val unionStr = strJson.asObj.merge(Obj("type" -> Str("Box"))).asObj
125+
val unionInt = intJson.asObj.merge(Obj("type" -> Str("Box"))).asObj
126+
127+
// Deserialize through union RW — _generic disambiguates
128+
val restored1 = unionStr.as[T]
129+
restored1 should be(Box("hello"))
130+
131+
val restored2 = unionInt.as[T]
132+
restored2 should be(Box(42))
133+
134+
// Prove disambiguation worked: Box[Int] deserialized the value as Int, not String
135+
// If it used the wrong child RW, value would be "42" (String) instead of 42 (Int)
136+
restored2.asInstanceOf[Box[Int]].value should be(42)
137+
restored1.asInstanceOf[Box[String]].value should be("hello")
138+
}
106139
"handle enums with parameters (ADT enums)" in {
107140
import ParameterizedEnumTest._
108141
val rgb: Shape = Shape.Circle(5.0)
@@ -454,3 +487,12 @@ object DefaultTest {
454487
given rw: RW[Config] = RW.gen
455488
}
456489
}
490+
491+
case class Id[T](value: String) derives RW
492+
493+
object GenericCollisionTest {
494+
case class Box[T](value: T)
495+
object Box {
496+
def rw[T: RW]: RW[Box[T]] = RW.gen
497+
}
498+
}

core/shared/src/test/scala/spec/RWSpecAuto.scala

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,11 @@ class RWSpecAuto extends AnyWordSpec with Matchers {
6464
Some(Address("Norman", "Oklahoma"))
6565
)
6666
val value = w.json
67-
value should be(
68-
obj(
69-
"name" -> "Test1",
70-
"value" -> obj("city" -> "San Jose", "state" -> "California"),
71-
"other" -> obj("city" -> "Norman", "state" -> "Oklahoma")
72-
)
73-
)
67+
// _generic is included for generic case classes
68+
value("name").asString should be("Test1")
69+
value("value") should be(obj("city" -> "San Jose", "state" -> "California"))
70+
value("other") should be(obj("city" -> "Norman", "state" -> "Oklahoma"))
71+
value("_generic").asObj.value.contains("T") should be(true)
7472
val w2 = value.as[Wrapper[Address]]
7573
w2 should be(w)
7674
}
@@ -96,6 +94,19 @@ class RWSpecAuto extends AnyWordSpec with Matchers {
9694
fields("value").genericName should be(Some("T"))
9795
fields("other").genericName should be(Some("T"))
9896
}
97+
"exclude _generic when SerializeGenerics is false" in {
98+
val prev = RW.SerializeGenerics
99+
try {
100+
RW.SerializeGenerics = false
101+
val w = Wrapper("Test", "hello", None)
102+
val value = w.json
103+
value.asObj.get("_generic") should be(None)
104+
// Round-trip still works without _generic
105+
value.as[Wrapper[String]] should be(w)
106+
} finally {
107+
RW.SerializeGenerics = prev
108+
}
109+
}
99110
"verify empty genericTypes for non-generic Person" in {
100111
Person.rw.definition.genericTypes should be(Nil)
101112
}
@@ -121,12 +132,9 @@ class RWSpecAuto extends AnyWordSpec with Matchers {
121132
Some(obj("city" -> "Norman"))
122133
)
123134
val value = w.json
124-
value should be(
125-
obj(
126-
"name" -> "Test2",
127-
"value" -> obj("city" -> "San Jose"),
128-
"other" -> obj("city" -> "Norman")
129-
)
135+
value("name").asString should be("Test2")
136+
value("value") should be(obj("city" -> "San Jose"))
137+
value("other") should be(obj("city" -> "Norman")
130138
)
131139
}
132140
"verify Person's DefType" in {

0 commit comments

Comments
 (0)