Skip to content

Commit f8005e2

Browse files
committed
Enable codecs re-using or discarding parameters
1 parent d6937f0 commit f8005e2

File tree

2 files changed

+53
-32
lines changed

2 files changed

+53
-32
lines changed

core/shared/src/main/scala/porcupine/codec.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import scala.deriving.Mirror
2727
trait Encoder[A]:
2828
outer =>
2929

30-
def parameters: Int
30+
def parameters: State[Int, List[Int]]
3131

3232
def encode(a: A): List[LiteValue]
3333

@@ -53,7 +53,7 @@ object Encoder:
5353

5454
def product[A, B](fa: Encoder[A], fb: Encoder[B]) = new:
5555
def parameters =
56-
fa.parameters + fb.parameters
56+
(fa.parameters, fb.parameters).mapN(_ ++ _)
5757

5858
def encode(ab: (A, B)) =
5959
val (a, b) = ab
@@ -130,7 +130,7 @@ object Codec:
130130
apply: T => LiteValue,
131131
unapply: PartialFunction[LiteValue, T],
132132
) extends Codec[T] {
133-
override def parameters: Int = 1
133+
override def parameters: State[Int, List[Int]] = State(idx => (idx + 1, List(idx)))
134134
override def encode(a: T): List[LiteValue] = apply(a) :: Nil
135135
override def decode: StateT[Either[Throwable, *], List[LiteValue], T] = StateT {
136136
case unapply(l) :: tail => Right((tail, l))
@@ -154,7 +154,7 @@ object Codec:
154154
new Simple("NULL", _ => LiteValue.Null, { case LiteValue.Null => None })
155155

156156
def unit: Codec[Unit] = new:
157-
def parameters: Int = 0
157+
def parameters = State.pure(List.empty)
158158
def encode(u: Unit) = Nil
159159
def decode = StateT.pure(())
160160

@@ -165,7 +165,7 @@ object Codec:
165165

166166
def product[A, B](fa: Codec[A], fb: Codec[B]) = new:
167167
def parameters =
168-
fa.parameters + fb.parameters
168+
(fa.parameters, fb.parameters).mapN(_ ++ _)
169169

170170
def encode(ab: (A, B)) =
171171
val (a, b) = ab

core/shared/src/main/scala/porcupine/sql.scala

Lines changed: 48 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -35,41 +35,61 @@ object Query:
3535
Query(fab.sql, fab.encoder.contramap(f), fab.decoder.map(g))
3636

3737
final class Fragment[A](
38-
val parts: List[Either[String, Int]],
38+
val part: Fragment.Part,
3939
val encoder: Encoder[A],
4040
):
41-
def sql: String = parts.foldMap {
42-
case Left(s) => s
43-
case Right(i) => ("?, " * (i - 1)) ++ "?"
44-
}
41+
def sql: String = part.compile.runA(1).value
4542

4643
def command: Query[A, Unit] = Query(sql, encoder, Codec.unit)
4744

4845
def query[B](decoder: Decoder[B]): Query[A, B] = Query(sql, encoder, decoder)
4946

50-
def apply(a: A): Fragment[Unit] = Fragment(parts, encoder.contramap(_ => a))
47+
def apply(a: A): Fragment[Unit] = Fragment(part, encoder.contramap(_ => a))
5148

5249
def stripMargin: Fragment[A] = stripMargin('|')
5350

5451
def stripMargin(marginChar: Char): Fragment[A] =
55-
val head = parts.headOption
56-
val tail = parts.tail
57-
val ps = head.map {
58-
_.leftMap(_.stripMargin(marginChar))
59-
}.toList ++ tail.map {
60-
_.leftMap(str =>
61-
str.takeWhile(_ != '\n') + str.dropWhile(_ != '\n').stripMargin(marginChar),
62-
)
63-
}
64-
Fragment(ps, encoder)
52+
Fragment(part.stripMargin(true, marginChar), encoder)
6553

6654
object Fragment:
55+
sealed trait Part:
56+
def compile: State[Int, String]
57+
def concatenate(other: Part): Part = other match {
58+
case Part.Concatenate(values) => Part.Concatenate(this :: values)
59+
case _ => Part.Concatenate(List(this, other))
60+
}
61+
def stripMargin(head: Boolean, marginChar: Char): Part
62+
63+
object Part:
64+
final case class Literal(x: String) extends Part:
65+
def compile = State.pure(x)
66+
def stripMargin(head: Boolean, marginChar: Char) =
67+
if (head) Literal(x.stripMargin(marginChar))
68+
else Literal(x.takeWhile(_ != '\n') ++ x.dropWhile(_ != '\n').stripMargin(marginChar))
69+
final case class Concatenate(values: List[Part]) extends Part:
70+
def compile = values.traverse(_.compile).map(_.combineAll)
71+
override def concatenate(other: Part) = other match {
72+
case Concatenate(values) => Concatenate(this.values ++ values)
73+
case _ => Concatenate(this.values :+ other)
74+
}
75+
def stripMargin(head: Boolean, marginChar: Char) =
76+
values match {
77+
case h :: t =>
78+
Concatenate(
79+
h.stripMargin(head, marginChar) :: t.map(_.stripMargin(false, marginChar)),
80+
)
81+
case other => this
82+
}
83+
final case class Parameters(advance: State[Int, List[Int]]) extends Part:
84+
def compile = advance.map(_.map(idx => s"?$idx").mkString(", "))
85+
def stripMargin(head: Boolean, marginChar: Char) = this
86+
6787
given ContravariantMonoidal[Fragment] = new:
68-
val unit = Fragment(List.empty, Codec.unit)
88+
val unit = Fragment(Part.Concatenate(List.empty), Codec.unit)
6989
def product[A, B](fa: Fragment[A], fb: Fragment[B]) =
70-
Fragment(fa.parts ++ fb.parts, (fa.encoder, fb.encoder).tupled)
90+
Fragment(fa.part.concatenate(fb.part), (fa.encoder, fb.encoder).tupled)
7191
def contramap[A, B](fa: Fragment[A])(f: B => A) =
72-
Fragment(fa.parts, fa.encoder.contramap(f))
92+
Fragment(fa.part, fa.encoder.contramap(f))
7393

7494
given Monoid[Fragment[Unit]] = new:
7595
def empty = ContravariantMonoidal[Fragment].unit
@@ -93,13 +113,13 @@ private def sqlImpl(
93113

94114
// TODO appending to `List` is slow
95115
val fragment =
96-
parts.zipAll(args, '{ "" }, '{ "" }).foldLeft('{ List.empty[Either[String, Int]] }) {
97-
case ('{ $acc: List[Either[String, Int]] }, ('{ $p: String }, '{ $s: String })) =>
98-
'{ $acc :+ Left($p) :+ Left($s) }
99-
case ('{ $acc: List[Either[String, Int]] }, ('{ $p: String }, '{ $e: Encoder[t] })) =>
100-
'{ $acc :+ Left($p) :+ Right($e.parameters) }
101-
case ('{ $acc: List[Either[String, Int]] }, ('{ $p: String }, '{ $f: Fragment[t] })) =>
102-
'{ $acc :+ Left($p) :++ $f.parts }
116+
parts.zipAll(args, '{ "" }, '{ "" }).foldLeft('{ List.empty[Fragment.Part] }) {
117+
case ('{ $acc: List[Fragment.Part] }, ('{ $p: String }, '{ $s: String })) =>
118+
'{ $acc :+ Fragment.Part.Literal($p) :+ Fragment.Part.Literal($s) }
119+
case ('{ $acc: List[Fragment.Part] }, ('{ $p: String }, '{ $e: Encoder[t] })) =>
120+
'{ $acc :+ Fragment.Part.Literal($p) :+ Fragment.Part.Parameters($e.parameters) }
121+
case ('{ $acc: List[Fragment.Part] }, ('{ $p: String }, '{ $f: Fragment[t] })) =>
122+
'{ $acc :+ Fragment.Part.Literal($p) :+ $f.part }
103123
}
104124

105125
val encoder = args.collect {
@@ -125,5 +145,6 @@ private def sqlImpl(
125145
}
126146

127147
(fragment, encoder) match
128-
case ('{ $s: List[Either[String, Int]] }, '{ $e: Encoder[a] }) => '{ Fragment[a]($s, $e) }
148+
case ('{ $s: List[Fragment.Part] }, '{ $e: Encoder[a] }) =>
149+
'{ Fragment[a](Fragment.Part.Concatenate($s), $e) }
129150
case _ => sys.error("porcupine pricked itself")

0 commit comments

Comments
 (0)