Full Haskell-like Type Class resolution in Java

TL;DR #

This post explores how far Java’s type system can be pushed by building a small Haskell-style type class instance resolution engine entirely in plain Java.

Using only reflection, generic type metadata, and a tiny first-order unifier, we can automatically resolve type class instances (including higher-kinded ones), support overlapping instances, and derive witnesses for arbitrarily nested generic types.

It’s not intended for production, but as an experiment it reveals just how much structure Java actually preserves at runtime — and how surprisingly close it can get to type-level programming without language changes.

Goal: First-Order Type Classes #

Given:

@TypeClass
interface Show<A> {
  String show(A value);

  // Convenience shortcut
  static <A> String show(Show<A> showA, A value) {
    return showA.show(value);
  }

  @TypeClass.Witness
  static Show<Integer> showInteger() { ... }

  @TypeClass.Witness
  static <A> Show<List<A>> showList(Show<A> showA) { ... }
}

Instead of:

Show<List<List<Integer>>> w =
    Show.showList(Show.showList(Show.showInteger()));

println(Show.show(w, List.of(List.of(1, 2), List.of(3, 4))));
// Prints: [[1, 2], [3, 4]]

We want:

println(Show.show(witness(), List.of(List.of(1, 2), List.of(3, 4))));
// Prints: [[1, 2], [3, 4]]

Which means that the witness() method was able to automatically produce the value of type Show<List<List<Integer>>> for us!

How do we get there?

First, we must understand some aspects of how Java generics work at runtime.

Note: I use the words 'witness' and 'type class instance' interchangeably.

Aside: Java Generics and Type Erasure #

It is a well-known fact that Java generics are erased by the compiler, so generic type arguments are generally not available at runtime.

However, there are a few scenarios in which generics do get reified!

Note: 'Type Reification' is the process by which type information is made concrete and available at runtime.

Given:

interface Stream<T> { ... }

interface IntStream extends Stream<Integer> { ... }

Then:

System.out.println(IntStream.class.getInterfaces()[0]);
// Prints: Stream

System.out.println(IntStream.class.getGenericInterfaces()[0]);
// Prints: Stream<Integer>

This means that Java did preserve the generic type arguments for the supertype Stream<Integer>. (We just need to know where to look for it.)

Note that this also occurs with anonymous classes:

var s = new Stream<String>() { ... };

System.out.println(s.getClass().getGenericInterfaces()[0]);
// Prints: Stream<String>
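To make the contrast with erasure concrete, here is a small self-contained sketch. The Box and IntBox interfaces and the helper names are hypothetical stand-ins for the Stream and IntStream examples above; only the reflection calls come from the JDK.

```java
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.ArrayList;

final class ReifiedSupertypes {
  // Hypothetical stand-ins for the Stream and IntStream interfaces above.
  interface Box<T> {}

  interface IntBox extends Box<Integer> {}

  /** Returns the type argument recorded on the first generic super-interface of cls. */
  static Type firstTypeArgument(Class<?> cls) {
    ParameterizedType superInterface = (ParameterizedType) cls.getGenericInterfaces()[0];
    return superInterface.getActualTypeArguments()[0];
  }

  /** Type argument recorded for an anonymous class that implements Box<String>. */
  static Type anonymousTypeArgument() {
    Box<String> anonymous = new Box<String>() {};
    return firstTypeArgument(anonymous.getClass());
  }

  /** Plain instances share one runtime class, so their type arguments are gone. */
  static boolean instancesAreErased() {
    return new ArrayList<String>().getClass() == new ArrayList<Integer>().getClass();
  }

  public static void main(String[] args) {
    System.out.println(firstTypeArgument(IntBox.class));
    System.out.println(anonymousTypeArgument());
    System.out.println(instancesAreErased());
  }
}
```

The supertype declarations keep their arguments (Integer and String), while two ArrayList instances with different arguments are indistinguishable by class.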

This leads us to our first workaround:

Aside: Capturing types #

We define:

interface Ty<T> {
  default Type type() {
    return requireNonNull(
        ((ParameterizedType) getClass().getGenericInterfaces()[0])
            .getActualTypeArguments()[0]);
  }
}

Which lets us write:

Type type = new Ty<Map<String, List<Integer>>>() {}.type();

System.out.println(type);
// Prints: Map<String, List<Integer>>

Why do we need this?

Because our witness() method will need a runtime representation of the type that it is trying to instantiate.

So our use case has turned into:

Show.show(witness(new Ty<>() {}), List.of(List.of(1, 2), List.of(3, 4)));

Note that we do not have to specify the type argument for Ty<>. Thankfully, type inference does that for us.

Type inference works in this case because the List.of(...) parameter has a well-defined type, and the call to <A> Show.show(Show<A>, A) lets type inference flow from the second argument to the first.
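The inference flow described above can be checked in isolation. This is a minimal sketch: Ty is copied from the definition above, while capturedTypeOf and example are hypothetical helpers that play the role of Show.show and witness without resolving anything.

```java
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.List;
import java.util.Objects;

final class TyInference {
  // Same capture trick as the Ty<T> interface defined above.
  interface Ty<T> {
    default Type type() {
      return Objects.requireNonNull(
          ((ParameterizedType) getClass().getGenericInterfaces()[0])
              .getActualTypeArguments()[0]);
    }
  }

  // Hypothetical helper: A is inferred from `value` and then reused for Ty<A>.
  static <A> Type capturedTypeOf(Ty<A> ty, A value) {
    return ty.type();
  }

  // The diamond in `new Ty<>() {}` is filled in by inference, not by us.
  static Type example() {
    return capturedTypeOf(new Ty<>() {}, List.of(List.of(1, 2), List.of(3, 4)));
  }

  public static void main(String[] args) {
    System.out.println(example().getTypeName());
  }
}
```

The captured type is the parameterized type List<List<Integer>>, even though it was never written out in the source.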

Aside: Understanding java.lang.reflect.Type's hierarchy #

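The standard java.lang.reflect.Type hierarchy consists of the Type interface and its five standard subtypes: Class, ParameterizedType, GenericArrayType, TypeVariable, and WildcardType.

Where: Class represents a non-generic or raw type such as String or List, as well as primitives and their arrays; ParameterizedType represents a generic type applied to arguments, such as List<String>; GenericArrayType represents an array whose component type is parameterized or a type variable, such as A[] or List<String>[]; TypeVariable represents a declared type parameter, such as A; and WildcardType represents a wildcard, such as ? or ? extends Number.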

Moreover, a java.lang.reflect.Method can be queried via TypeVariable<?>[] getTypeParameters(), Type[] getGenericParameterTypes(), and Type getGenericReturnType().

For example:

class Example {
  static <A> List<A> hello(String s, A[] arr, Optional<?> opt) { ... }
}

Inspecting the Example's hello method at runtime would yield something like:

Method(
  name: "hello",
  typeParameters:
    [TypeVariable(A)],
  genericReturnType:
    ParameterizedType(
      rawType: List.class,
      typeArguments: [TypeVariable(A)],
    ),
  genericParameterTypes:
    [
      String.class,
      GenericArrayType(componentType: TypeVariable(A)),
      ParameterizedType(
        rawType: Optional.class,
        typeArguments: [WildcardType()]
      )
    ]
)

Note that the three occurrences of TypeVariable(A) refer to the same type variable: they compare equal (via equals()) to each other, but not to type variables declared elsewhere, even if those are also named A.
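The equality behaviour of type variables is easy to verify with reflection. In this sketch, hello mirrors the example above, while other, method, and the two check methods are hypothetical names added for illustration.

```java
import java.lang.reflect.GenericArrayType;
import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.List;
import java.util.Optional;

final class TypeVariableIdentity {
  static <A> List<A> hello(String s, A[] arr, Optional<?> opt) {
    return List.of();
  }

  // A second method that also happens to name its type parameter A.
  static <A> List<A> other(A a) {
    return List.of(a);
  }

  static Method method(String name) {
    for (Method m : TypeVariableIdentity.class.getDeclaredMethods()) {
      if (m.getName().equals(name)) {
        return m;
      }
    }
    throw new IllegalArgumentException("No such method: " + name);
  }

  /** The declared A, the A in List<A>, and the A in A[] all compare equal. */
  static boolean sameVariableEverywhere() {
    Method hello = method("hello");
    Type declared = hello.getTypeParameters()[0];
    Type inReturn =
        ((ParameterizedType) hello.getGenericReturnType()).getActualTypeArguments()[0];
    Type inArray =
        ((GenericArrayType) hello.getGenericParameterTypes()[1]).getGenericComponentType();
    return declared.equals(inReturn) && declared.equals(inArray);
  }

  /** The A of hello and the A of other are different variables. */
  static boolean differsAcrossMethods() {
    return !method("hello").getTypeParameters()[0]
        .equals(method("other").getTypeParameters()[0]);
  }

  public static void main(String[] args) {
    System.out.println(sameVariableEverywhere());
    System.out.println(differsAcrossMethods());
  }
}
```

This property is what lets a TypeVariable serve directly as a key in a substitution map later on.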

Subgoal: Parsing java.lang.reflect.Type into an AST #

We want to do this for a few reasons:

Here's the AST:

sealed interface ParsedType {
  record Var(TypeVariable<?> java) implements ParsedType {}

  record App(ParsedType fun, ParsedType arg) implements ParsedType {}

  record ArrayOf(ParsedType elementType) implements ParsedType {}

  record Const(Class<?> java) implements ParsedType {}

  record Primitive(Class<?> java) implements ParsedType {}
}

And these are our parsing rules:

sealed interface ParsedType {
  // ...

  static ParsedType parse(Type java) {
    return switch (java) {
      case Class<?> arr when arr.isArray() ->
          new ArrayOf(parse(arr.getComponentType()));
      case Class<?> prim when prim.isPrimitive() ->
          new Primitive(prim);
      case Class<?> c ->
          new Const(c);
      case TypeVariable<?> v ->
          new Var(v);
      case ParameterizedType p ->
          parseAll(p.getActualTypeArguments()).stream()
              .reduce(parse(p.getRawType()), App::new);
      case GenericArrayType a ->
          new ArrayOf(parse(a.getGenericComponentType()));
      case WildcardType w ->
          throw new IllegalArgumentException("Cannot parse wildcard type: " + w);
      default ->
          throw new IllegalArgumentException("Unsupported type: " + java);
    };
  }
}

Note: the rule for ParameterizedType will take a type like T<A, B> and turn it into App(App(Const(T), A), B).

This means that the generic type T is first applied to A and then to B.

For example, Map<Integer, List<String>> becomes:

App(
  App(Const(Map.class), Const(Integer.class)),
  App(Const(List.class), Const(String.class))
)

That's it! Really, it's not much, but the added uniformity will help our code down the line.
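To see the curried shape come out of real reflection data, here is a cut-down sketch. It only handles Class and ParameterizedType, and it uses instanceof patterns rather than the pattern switch above. The names CurriedParse, Ast, sample, and sampleType are hypothetical, and Const prints as a simple class name for brevity.

```java
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.List;
import java.util.Map;

final class CurriedParse {
  sealed interface Ast {}

  record Const(Class<?> java) implements Ast {
    @Override
    public String toString() {
      return java.getSimpleName();
    }
  }

  record App(Ast fun, Ast arg) implements Ast {
    @Override
    public String toString() {
      return "App(" + fun + ", " + arg + ")";
    }
  }

  /** Applies the raw type to its arguments one at a time, left to right. */
  static Ast parse(Type type) {
    if (type instanceof Class<?> c) {
      return new Const(c);
    }
    if (type instanceof ParameterizedType p) {
      Ast result = parse(p.getRawType());
      for (Type argument : p.getActualTypeArguments()) {
        result = new App(result, parse(argument));
      }
      return result;
    }
    throw new IllegalArgumentException("Unsupported type: " + type);
  }

  // Field used only to obtain a java.lang.reflect.Type for Map<Integer, List<String>>.
  static Map<Integer, List<String>> sample;

  static Type sampleType() {
    try {
      return CurriedParse.class.getDeclaredField("sample").getGenericType();
    } catch (NoSuchFieldException e) {
      throw new IllegalStateException(e);
    }
  }

  public static void main(String[] args) {
    System.out.println(parse(sampleType()));
  }
}
```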

Recap #

Until this point we have:

Now, let's move on to the problem of type class instance resolution.

Subgoal: Witness resolution #

Let's consider the following scenario:

@TypeClass
interface Show<A> {
  String show(A value);

  // Convenience shortcut
  static <A> String show(Show<A> showA, A value) {
    return showA.show(value);
  }

  @TypeClass.Witness
  static Show<Integer> showInteger() { ... }

  @TypeClass.Witness
  static <A> Show<List<A>> showList(Show<A> showA) { ... }
}

record Pair<A, B>(A fst, B snd) {
  @TypeClass.Witness
  public static <A, B> Show<Pair<A, B>> show(Show<A> showA, Show<B> showB) { ... }
}

We observe that, for example, in order to summon a witness for Show<List<Pair<Integer, List<Integer>>>>, we must apply several witness constructors recursively.

What is a witness constructor? It is a public static method annotated with @TypeClass.Witness.

But first, we must find the witness constructors!

How do we do that? Do we need to do a runtime scan of all loaded classes?

That would be way too complicated.

In order to reduce our search scope, we can define the following convention:

When trying to resolve a witness C<T> for some type class C and a concrete type T, then we only look for witness constructors within the definitions of C and T and nowhere else.

For example, when looking for Show<Integer>, then we scan the methods of Show and Integer.

Generally, we will prefer to define witness constructors within concrete types and not within type classes. However, for built-in types like Integer, we have no choice and we must define their witness constructors within the relevant type classes, as we have done with static Show<Integer> showInteger().

Now that we are able to find the relevant witness constructors, how do we know which of them we must invoke and in which order?

First, we must understand type unification:

Aside: Type Unification #

Type unification is the process of finding a substitution that makes two types identical.

For example, given Pair<[A], String> and Pair<Integer, String>, then the substitution {A -> Integer} would make the first type identical to the second.

Here, I have placed the type A between square brackets to indicate that it is a type variable. We only substitute type variables!

Unification may fail when it encounters incompatible types. For example, List<Integer> and List<String> cannot be unified because no substitution exists that would make them identical.

Conversely, when two types are already identical, unification succeeds with an empty substitution {}.
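The three cases above (a successful binding, a clash, and identical types) can be reproduced with a toy unifier. This is a sketch over a hypothetical string-based Term type rather than ParsedType. It only binds variables on the left-hand side, and it rejects a variable that would need two different bindings, which is an assumption about how conflicting substitutions should be merged.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

final class MiniUnify {
  sealed interface Term {}

  record Var(String name) implements Term {}

  record Con(String name) implements Term {}

  record App(Term fun, Term arg) implements Term {}

  /** One-way unification: only variables occurring in `pattern` are bound. */
  static Optional<Map<Var, Term>> unify(Term pattern, Term target) {
    Map<Var, Term> subst = new HashMap<>();
    return match(pattern, target, subst) ? Optional.of(subst) : Optional.empty();
  }

  private static boolean match(Term pattern, Term target, Map<Var, Term> subst) {
    if (pattern instanceof Var v) {
      // A variable may be bound once; a later occurrence must agree with it.
      Term bound = subst.putIfAbsent(v, target);
      return bound == null || bound.equals(target);
    }
    if (pattern instanceof Con c) {
      return c.equals(target);
    }
    if (pattern instanceof App a && target instanceof App b) {
      return match(a.fun(), b.fun(), subst) && match(a.arg(), b.arg(), subst);
    }
    return false;
  }

  /** Builds Pair<a, b> in curried form. */
  static Term pair(Term a, Term b) {
    return new App(new App(new Con("Pair"), a), b);
  }

  public static void main(String[] args) {
    Term pattern = pair(new Var("A"), new Con("String"));
    Term target = pair(new Con("Integer"), new Con("String"));
    System.out.println(unify(pattern, target));
  }
}
```

Unifying Pair<[A], String> with Pair<Integer, String> yields the single binding A -> Integer, while Pair<[A], [A]> against Pair<Integer, String> fails because A cannot be both.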

Listing: Type Unification algorithm for ParsedType #

class Unification {
  public static Maybe<Map<ParsedType.Var, ParsedType>> unify(ParsedType t1, ParsedType t2) {
    return switch (Pair.of(t1, t2)) {
      case Pair<ParsedType, ParsedType>(ParsedType.Var var1, ParsedType.Primitive p) ->
          Maybe.nothing(); // no primitives in generics
      case Pair<ParsedType, ParsedType>(ParsedType.Var var1, var t) ->
          Maybe.just(Map.of(var1, t));
      case Pair<ParsedType, ParsedType>(ParsedType.Const const1, ParsedType.Const const2)
          when const1.equals(const2) ->
          Maybe.just(Map.of());
      case Pair<ParsedType, ParsedType>(
              ParsedType.App(var fun1, var arg1),
              ParsedType.App(var fun2, var arg2)) ->
          Maybe.apply(Maps::merge, unify(fun1, fun2), unify(arg1, arg2));
      case Pair<ParsedType, ParsedType>(
              ParsedType.ArrayOf(var elem1),
              ParsedType.ArrayOf(var elem2)) ->
          unify(elem1, elem2);
      case Pair<ParsedType, ParsedType>(
              ParsedType.Primitive(var prim1),
              ParsedType.Primitive(var prim2))
          when prim1.equals(prim2) ->
          Maybe.just(Map.of());
      default ->
          Maybe.nothing();
    };
  }

  public static ParsedType substitute(Map<ParsedType.Var, ParsedType> map, ParsedType type) {
    return switch (type) {
      case ParsedType.Var var ->
          map.getOrDefault(var, var);
      case ParsedType.App(var fun, var arg) ->
          new ParsedType.App(substitute(map, fun), substitute(map, arg));
      case ParsedType.ArrayOf var ->
          new ParsedType.ArrayOf(substitute(map, var.elementType()));
      case ParsedType.Primitive p ->
          p;
      case ParsedType.Const c ->
          c;
    };
  }

  public static List<ParsedType> substituteAll(
      Map<ParsedType.Var, ParsedType> map, List<ParsedType> types) {
    return types.stream().map(t -> substitute(map, t)).toList();
  }
}

Extra: A type representation for static methods #

record FuncType(Method java, List<ParsedType> paramTypes, ParsedType returnType) {
  public static FuncType parse(Method method) {
    if (!Modifier.isStatic(method.getModifiers())) {
      throw new IllegalArgumentException("Method must be static: " + method);
    }
    return new FuncType(
        method,
        ParsedType.parseAll(method.getGenericParameterTypes()),
        ParsedType.parse(method.getGenericReturnType()));
  }
}

Subgoal: Witness resolution, part 2 #

Why is type unification relevant?

Type unification will guide the witness resolution process in two ways:

For example, when resolving Show<List<Integer>>, we check:

  1. Which witness constructors can we find?
  2. Which witness constructors apply to our goal?
  3. Recurse until we have no further goals.

That outlines the witness resolution algorithm.
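Before looking at the reflective version, the outline can be sketched on a toy model. Here the rules are hard-coded instead of discovered through annotations, terms are a hypothetical string-based type, and the result is the text of the constructor calls rather than a real witness. None of these names belong to the actual implementation.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

final class MiniResolver {
  sealed interface Term {}
  record Var(String name) implements Term {}
  record Con(String name) implements Term {}
  record App(Term fun, Term arg) implements Term {}
  record Rule(String name, Term head, List<Term> premises) {}

  static final Term A = new Var("A");
  static Term show(Term t) { return new App(new Con("Show"), t); }
  static Term list(Term t) { return new App(new Con("List"), t); }

  // Step 1 (simplified): the known witness constructors, as head <- premises.
  static final List<Rule> RULES = List.of(
      new Rule("showInteger", show(new Con("Integer")), List.of()),
      new Rule("showList", show(list(A)), List.of(show(A))));

  static boolean match(Term pattern, Term target, Map<Var, Term> subst) {
    if (pattern instanceof Var v) {
      Term bound = subst.putIfAbsent(v, target);
      return bound == null || bound.equals(target);
    }
    if (pattern instanceof Con c) return c.equals(target);
    if (pattern instanceof App a && target instanceof App b) {
      return match(a.fun(), b.fun(), subst) && match(a.arg(), b.arg(), subst);
    }
    return false;
  }

  static Term substitute(Map<Var, Term> subst, Term t) {
    if (t instanceof Var v) return subst.getOrDefault(v, v);
    if (t instanceof App a) return new App(substitute(subst, a.fun()), substitute(subst, a.arg()));
    return t;
  }

  static String resolve(Term goal) {
    // Step 2: keep the rules whose head unifies with the goal; require exactly one.
    Rule chosen = null;
    Map<Var, Term> chosenSubst = null;
    for (Rule rule : RULES) {
      Map<Var, Term> subst = new HashMap<>();
      if (match(rule.head(), goal, subst)) {
        if (chosen != null) throw new IllegalStateException("Ambiguous: " + goal);
        chosen = rule;
        chosenSubst = subst;
      }
    }
    if (chosen == null) throw new IllegalStateException("Not found: " + goal);
    // Step 3: the substituted premises become the next goals.
    List<String> args = new ArrayList<>();
    for (Term premise : chosen.premises()) {
      args.add(resolve(substitute(chosenSubst, premise)));
    }
    return chosen.name() + "(" + String.join(", ", args) + ")";
  }

  public static void main(String[] args) {
    System.out.println(resolve(show(list(list(new Con("Integer"))))));
  }
}
```

Resolving Show<List<List<Integer>>> this way rebuilds the expression showList(showList(showInteger())) that we wrote by hand at the start.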

Recap #

Until this point we have:

Now, let's put it all together!

Subgoal: Witness resolution implementation #

Let's start with the witness constructor lookup code:

@Retention(RetentionPolicy.RUNTIME)
@interface TypeClass {
  @Retention(RetentionPolicy.RUNTIME)
  @interface Witness {}
}

class TypeClasses {
  // ...

  private static List<InstanceConstructor> findRules(ParsedType target) {
    return switch (target) {
      case ParsedType.App(var fun, var arg) ->
          Lists.concat(findRules(fun), findRules(arg));
      case ParsedType.Const(var java) ->
          rulesOf(java);
      case ParsedType.Var(var java) ->
          List.of();
      case ParsedType.ArrayOf(var elem) ->
          List.of();
      case ParsedType.Primitive(var java) ->
          List.of();
    };
  }

  private static List<InstanceConstructor> rulesOf(Class<?> cls) {
    return Arrays.stream(cls.getDeclaredMethods())
        .filter(TypeClasses::isWitnessMethod)
        .map(FuncType::parse)
        .map(InstanceConstructor::new)
        .toList();
  }

  private static boolean isWitnessMethod(Method m) {
    return m.accessFlags().contains(PUBLIC)
        && m.accessFlags().contains(STATIC)
        && m.isAnnotationPresent(TypeClass.Witness.class);
  }

  private record InstanceConstructor(FuncType func) {}

  // ...
}

Now, let's move on to how we choose relevant witness constructors and find our subsequent goals:

class TypeClasses {
  // ...

  private static List<Candidate> findCandidates(ParsedType target) {
    return findRules(target).stream()
        .flatMap(
            rule ->
                rule
                    .tryMatch(target)
                    .map(requirements -> new Candidate(rule, requirements))
                    .stream())
        .toList();
  }

  private record Candidate(WitnessRule rule, List<ParsedType> requirements) {}

  // Spoiler: we will have another subtype of WitnessRule later on (;
  private sealed interface WitnessRule {
    Maybe<List<ParsedType>> tryMatch(ParsedType target);

    Object instantiate(List<Object> dependencies);
  }

  private record InstanceConstructor(FuncType func) implements WitnessRule {
    @Override
    public Maybe<List<ParsedType>> tryMatch(ParsedType target) {
      return Unification.unify(func.returnType(), target)
          .map(map -> Unification.substituteAll(map, func.paramTypes()));
    }

    @Override
    public Object instantiate(List<Object> dependencies) {
      try {
        return func.java().invoke(null, dependencies.toArray());
      } catch (Exception e) {
        throw new RuntimeException(e);
      }
    }
  }

  // ...
}

Then, the code that drives the recursive resolution:

class TypeClasses {
  // ...

  private static Either<SummonError, Object> summon(ParsedType target) {
    return switch (ZeroOneMore.of(findCandidates(target))) {
      case ZeroOneMore.One<Candidate>(Candidate(var rule, var requirements)) ->
          summonAll(requirements)
              .map(rule::instantiate)
              .mapLeft(error -> new SummonError.Nested(target, error));
      case ZeroOneMore.Zero<Candidate>() ->
          Either.left(new SummonError.NotFound(target));
      case ZeroOneMore.More<Candidate>(var candidates) ->
          Either.left(new SummonError.Ambiguous(target, candidates));
    };
  }

  private static Either<SummonError, List<Object>> summonAll(List<ParsedType> targets) {
    return Either.traverse(targets, TypeClasses::summon);
  }

  private sealed interface SummonError {
    record NotFound(ParsedType target) implements SummonError {}

    record Ambiguous(ParsedType target, List<Candidate> candidates) implements SummonError {}

    record Nested(ParsedType target, SummonError cause) implements SummonError {}
  }

  // ...
}

And, finally, the public entry point for all of it:

class TypeClasses {
  public static <T> T witness(Ty<T> ty) {
    return switch (summon(ParsedType.parse(ty.type()))) {
      case Either.Left<SummonError, Object>(SummonError error) ->
          throw new WitnessResolutionException(error);
      case Either.Right<SummonError, Object>(Object instance) -> {
        @SuppressWarnings("unchecked")
        T typedInstance = (T) instance;
        yield typedInstance;
      }
    };
  }

  public static class WitnessResolutionException extends RuntimeException {
    private WitnessResolutionException(SummonError error) {
      super(error.format());
    }
  }

  // ...
}

Surprisingly, that's it!

You can find a mostly accurate compilation of all of the above code in this Gist. (This was the first version of the type class system that I built.)

Now, let's see it in action.

Examples: First-Order Type Classes & Instances #

Type Class: Show #

Given:

@TypeClass
interface Show<A> {
  String show(A a);

  static <A> String show(Show<A> showA, A a) {
    return showA.show(a);
  }

  @TypeClass.Witness
  static Show<Integer> integerShow() {
    return i -> Integer.toString(i);
  }


  @TypeClass.Witness
  static Show<String> stringShow() {
    return s -> "\"" + s + "\"";
  }

  @TypeClass.Witness
  static <A> Show<Optional<A>> optionalShow(Show<A> showA) {
    return optA -> optA.map(a -> "Some(" + showA.show(a) + ")").orElse("None");
  }

  @TypeClass.Witness
  static <A> Show<List<A>> listShow(Show<A> showA) { ... }

  @TypeClass.Witness
  static <K, V> Show<Map<K, V>> mapShow(Show<K> showK, Show<V> showV) { ... }
}

Then:

Map<String, List<Optional<Integer>>> m1 =
    Map.of(
        "a", List.of(Optional.of(1), Optional.empty()),
        "b", List.of(Optional.of(2), Optional.of(3)));

println(Show.show(witness(new Ty<>() {}), m1));
// Prints: {"a": [Some(1), None], "b": [Some(2), Some(3)]}

Note that this requires the recursive instantiation of 5 witness constructors. Here's what some debug logs show:

Instantiating: () -> Show[A](String)
Instantiating: () -> Show[A](Integer)
Instantiating: ∀ A. (Show[A](A)) -> Show[A](Optional[T](A))
Instantiating: ∀ A. (Show[A](A)) -> Show[A](List[E](A))
Instantiating: ∀ K V. (Show[A](K), Show[A](V)) -> Show[A](Map[K, V](K)(V))

Type Class: Monoid & Num #

Given:

@TypeClass
interface Monoid<A> {
  A combine(A a1, A a2);

  A identity();

  static <A> A combineAll(Monoid<A> monoid, List<A> elements) {
    A result = monoid.identity();
    for (A element : elements) {
      result = monoid.combine(result, element);
    }
    return result;
  }
}

@TypeClass
interface Num<A> {
  A add(A a1, A a2);

  A mul(A a1, A a2);

  A zero();

  A one();

  @TypeClass.Witness
  static Num<Integer> integerNum() {
    return new Num<>() {
      @Override
      public Integer add(Integer a1, Integer a2) {
        return a1 + a2;
      }

      @Override
      public Integer mul(Integer a1, Integer a2) {
        return a1 * a2;
      }

      @Override
      public Integer zero() {
        return 0;
      }

      @Override
      public Integer one() {
        return 1;
      }
    };
  }
}

record Sum<A>(A value) {
  @TypeClass.Witness
  public static <A> Monoid<Sum<A>> monoid(Num<A> num) {
    return new Monoid<>() {
      @Override
      public Sum<A> combine(Sum<A> s1, Sum<A> s2) {
        return new Sum<>(num.add(s1.value(), s2.value()));
      }

      @Override
      public Sum<A> identity() {
        return new Sum<>(num.zero());
      }
    };
  }
}

Then:

var sums = List.of(new Sum<>(3), new Sum<>(5), new Sum<>(10));

println(Monoid.combineAll(witness(new Ty<>() {}), sums));
// Prints: Sum[value=18]

Type Class: PrintAll #

This example is based on: https://wiki.haskell.org/Varargs

It abuses type classes in order to implement variadic functions. Please read the link above to understand how this works.

It's really striking how it just works with our system!

Given:

@TypeClass
interface PrintAll<T> {
  T printAll(List<String> strings);

  static <T> T of(PrintAll<T> printAll) {
    return printAll.printAll(List.of());
  }

  @TypeClass.Witness
  static PrintAll<Void> base() {
    return strings -> {
      for (String s : strings) {
        System.out.println(s);
      }
      return null;
    };
  }

  @TypeClass.Witness
  static <A, R> PrintAll<Function<A, R>> func(Show<A> showA, PrintAll<R> printR) {
    return strings -> a -> printR.printAll(Lists.concat(strings, List.of(showA.show(a))));
  }
}

Then:

Function<String, Function<List<String>, Function<Integer, Void>>> printer =
    PrintAll.of(witness(new Ty<>() {}));

printer.apply("Items:").apply(List.of("apple", "banana", "cherry")).apply(42);
// Prints:
// "Items:"
// ["apple", "banana", "cherry"]
// 42

Type Class: Type Equality! #

Reified type equality is a very neat construct that I'd like to write more about soon.

For now, let's appreciate how this neatly encodes Haskell's own type equality in Java.

Given:

@TypeClass
sealed interface TyEq<A, B> {
  A castR(B b);

  B castL(A a);

  static <T> TyEq<T, T> refl() {
    return new Refl<>();
  }

  record Refl<T>() implements TyEq<T, T> {

    @Override
    public T castR(T t) {
      return t;
    }

    @Override
    public T castL(T t) {
      return t;
    }
  }

  @TypeClass.Witness
  static <T> TyEq<T, T> tyEqRefl() {
    return refl();
  }
}

Of course, we can manually construct refl() instances. And these can be really useful. (I will write about this soon.)

But we can also request them as witnesses in a witness constructor:

@TypeClass
interface SumAllInt<A> {
  A sum(List<Integer> list);

  static <T> T of(SumAllInt<T> sumAllInt) {
    return sumAllInt.sum(List.of());
  }

  @TypeClass.Witness
  static SumAllInt<Integer> base() {
    return list -> list.stream().mapToInt(Integer::intValue).sum();
  }

  @TypeClass.Witness
  static <A, R> SumAllInt<Function<A, R>> func(TyEq<A, Integer> eq, SumAllInt<R> sumR) {
    return list -> a -> sumR.sum(Lists.concat(list, List.of(eq.castL(a))));
  }
}

Similar to the PrintAll example, this lets us summon variadic functions:

Function<Integer, Function<Integer, Function<Integer, Integer>>> sum =
    SumAllInt.of(witness(new Ty<>() {}));
println(sum.apply(1).apply(2).apply(3));

However, notice how in the func rule, we requested TyEq<A, Integer>. This constrains the A type argument to just Integer.

Funnily enough, this is actually only necessary in Haskell due to how integer literals are overloaded. So I implemented this here for no gain. But it was cool that it just worked!

Goal: Higher-Order Type Classes #

Consider the Functor type class in Haskell:

class Functor f where
  fmap :: (a -> b) -> f a -> f b

Notice how f is not a type?

It is not one because it is being applied to types a and b, respectively.

This means that f is a type constructor.

A type constructor is a sort of type-level function that builds new types from existing types.

Here, we mean function in the mathematical sense: a relation that is total (every input has an output) and single-valued (every input has exactly one output).

In Java, class List<T> can be seen as a type constructor. Though we don't use type application syntax as in Haskell, we do use type parameterization syntax: List<Integer>. In Java, generic types cannot be partially applied. That is, for a generic type C<T1, T2, ..., TN>, we must provide all N type arguments at once.

Let's try representing the Functor type class with a Java interface:

interface Functor<F> {
  <A, B> F<B> fmap(Function<A, B> f, F<A> fa);
}

Uh-oh:

The type F is not generic; it cannot be parameterized with arguments <B>

Indeed, Java requires that the type parameter F is a type, and not a type constructor.

Are we cooked?

Not quite.

Aside: What is a kind? #

Simply: values have types; types have kinds.

42 has type Integer.

Integer has kind * (read: star).

The syntax * for the kind of plain types comes from Haskell.

We could also say Integer has kind Type, but that may be confusing given how overloaded the word 'type' is in this context.

List<Integer> also has kind *.

So what is the kind of List itself?

List has kind * -> *.

That is, it is a type-level function that accepts a type and returns a type.

Aside: Higher-Kinded Types in Java #

I am not sure where this workaround originated, but it is rather elegant.

Consider:

interface TApp<F, A> {}

abstract class TagBase {}

sealed interface Maybe<A> extends TApp<Maybe.Tag, A> {
  record Nothing<A>() implements Maybe<A> {}
  
  record Just<A>(A value) implements Maybe<A> {}

  final class Tag extends TagBase {}

  static <A> Maybe<A> unwrap(TApp<Maybe.Tag, A> m) {
    return (Maybe<A>) m; // unsafe if we misuse TApp and tag types
  }
}

Let's unpack this:

How is this useful?

Now we do have a mechanism (more of a convention) to represent Functor!

Check it out:

interface Functor<F> {
  <A, B> TApp<F, B> fmap(Function<A, B> f, TApp<F, A> fa);
}

Remember: TApp<F, A> means F<A>.

And we can also define the witness constructor:

sealed interface Maybe<A> extends TApp<Maybe.Tag, A> {
  // ...

  default <B> Maybe<B> map(Function<A, B> f) { ... }

  @TypeClass.Witness
  static Functor<Maybe.Tag> functor() {
    return new Functor<>() {
      @Override
      public <A, B> TApp<Maybe.Tag, B> fmap(Function<A, B> f, TApp<Maybe.Tag, A> fa) {
        return unwrap(fa).map(f);
      }
    };
  }
}

Let's unpack that:

That's it!
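Putting the pieces together, here is a self-contained sketch of the encoding in use. TApp, Functor, and Maybe follow the definitions above (without TagBase and the annotations), while HktDemo and lengths are hypothetical names. The point is that lengths is written once against Functor<F> and never mentions Maybe.

```java
import java.util.function.Function;

final class HktDemo {
  interface TApp<F, A> {}

  interface Functor<F> {
    <A, B> TApp<F, B> fmap(Function<A, B> f, TApp<F, A> fa);
  }

  sealed interface Maybe<A> extends TApp<Maybe.Tag, A> {
    record Nothing<A>() implements Maybe<A> {}

    record Just<A>(A value) implements Maybe<A> {}

    final class Tag {}

    // Safe as long as Maybe is the only type that implements TApp<Tag, A>.
    static <A> Maybe<A> unwrap(TApp<Tag, A> m) {
      return (Maybe<A>) m;
    }

    static Functor<Tag> functor() {
      return new Functor<>() {
        @Override
        public <A, B> TApp<Tag, B> fmap(Function<A, B> f, TApp<Tag, A> fa) {
          Maybe<A> m = Maybe.unwrap(fa);
          if (m instanceof Just<A> j) {
            return new Just<>(f.apply(j.value()));
          }
          return new Nothing<>();
        }
      };
    }
  }

  /** Generic code written once against Functor<F>, for any tag F. */
  static <F> TApp<F, Integer> lengths(Functor<F> functor, TApp<F, String> strings) {
    return functor.fmap(String::length, strings);
  }

  public static void main(String[] args) {
    System.out.println(lengths(Maybe.functor(), new Maybe.Just<>("hello")));
  }
}
```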

Now, we must extend our type class system to understand these typing conventions.

Subgoal: Extending ParsedType to support HKTs #

This is actually not very involved. Check it out:

sealed interface ParsedType {
  // ...

  static ParsedType parse(Type java) {
    return switch (java) {
      // New:
      case Class<?> tag
        when parseTagType(tag) instanceof Maybe.Just<Class<?>>(var tagged) ->
          new Const(tagged);
      // New:
      case ParameterizedType p
          when parseAppType(p)
              instanceof Maybe.Just<Pair<Type, Type>>(Pair<Type, Type>(var fun, var arg)) ->
          new App(parse(fun), parse(arg));
      // etc
    };
  }

  private static Maybe<Class<?>> parseTagType(Class<?> c) {
    return switch (c.getEnclosingClass()) {
      // Interfaces have no superclass, so compare from the non-null side.
      case Class<?> enclosing when TagBase.class.equals(c.getSuperclass()) ->
          Maybe.just(enclosing);
      case null -> Maybe.nothing();
      default -> Maybe.nothing();
    };
  }

  private static Maybe<Pair<Type, Type>> parseAppType(ParameterizedType t) {
    return switch (t.getRawType()) {
      case Class<?> raw when raw.equals(TApp.class) ->
          Maybe.just(Pair.of(t.getActualTypeArguments()[0], t.getActualTypeArguments()[1]));
      default -> Maybe.nothing();
    };
  }
}

We only had to add two new pattern-match cases to parse:

That's it!

I was also surprised! The witness resolution code needs no changes whatsoever!

Although, I was worried about potential bugs coming from misuses of Tags and TApp. So I came up with a lightweight embedded kind-checking mechanism.

Aside: Kind-checking Java-embedded HKTs #

Given:

interface Kind<K extends Kind.Base> {
  sealed interface Base {}

  // KStar = *
  final class KStar implements Base {}

  // KArr k = * -> k
  final class KArr<K extends Base> implements Base {}
}

abstract class TagBase<K extends Kind.Base> implements Kind<K> {}

// Full application of a unary type constructor
// TApp :: (* -> *) -> * -> *
interface TApp<Tag extends Kind<KArr<KStar>>, A> extends Kind<KStar> {}

// Partial application of a binary type constructor
// TPar :: (* -> * -> *) -> * -> (* -> *)
interface TPar<Tag extends Kind<KArr<KArr<KStar>>>, A> extends Kind<KArr<KStar>> {}

Then:

sealed interface Maybe<A> extends TApp<Maybe.Tag, A> {
  // ...

  final class Tag extends TagBase<KArr<KStar>> {}

  static <A> Maybe<A> unwrap(TApp<Tag, A> value) {
    return (Maybe<A>) value;
  }
}

And also:

@FunctionalInterface
interface State<S, A> extends TApp<TPar<State.Tag, S>, A> {
  // ...

  @TypeClass.Witness
  static <S> Functor<TPar<Tag, S>> functor() { ... }

  final class Tag extends TagBase<KArr<KArr<KStar>>> {}

  static <S, A> State<S, A> unwrap(TApp<TPar<Tag, S>, A> value) {
    return (State<S, A>) value;
  }
}

Notice in this case that State has two type parameters:

Examples: Higher-Kinded Type Classes #

Type Class: Functor, Applicative, Monad #

As expected:

@TypeClass
interface Functor<F extends Kind<KArr<KStar>>> {
  <A, B> TApp<F, B> map(Function<A, B> f, TApp<F, A> fa);
}

@TypeClass
interface Applicative<F extends Kind<KArr<KStar>>> extends Functor<F> {
  <A> TApp<F, A> pure(A a);

  <A, B> TApp<F, B> ap(TApp<F, Function<A, B>> ff, TApp<F, A> fa);

  @Override
  default <A, B> TApp<F, B> map(Function<A, B> f, TApp<F, A> fa) {
    return ap(pure(f), fa);
  }
}

@TypeClass
interface Monad<M extends Kind<KArr<KStar>>> extends Applicative<M> {
  <A, B> TApp<M, B> flatMap(Function<A, TApp<M, B>> f, TApp<M, A> fa);

  @Override
  default <A, B> TApp<M, B> map(Function<A, B> f, TApp<M, A> fa) {
    return flatMap(a -> pure(f.apply(a)), fa);
  }

  @Override
  default <A, B> TApp<M, B> ap(TApp<M, Function<A, B>> ff, TApp<M, A> fa) {
    return flatMap(a -> map(f -> f.apply(a), ff), fa);
  }
}

Type Class: Traversable #

Given:

@TypeClass
interface Traversable<T extends Kind<KArr<KStar>>> {
  <F extends Kind<KArr<KStar>>, A, B> TApp<F, ? extends TApp<T, B>> traverse(
      Applicative<F> applicative, Function<A, TApp<F, B>> f, TApp<T, A> ta);

  static <F extends Kind<KArr<KStar>>, T extends Kind<KArr<KStar>>, A, B>
      TApp<F, ? extends TApp<T, B>> traverse(
          Traversable<T> traversable,
          Applicative<F> applicative,
          TApp<T, A> tA,
          Function<A, TApp<F, B>> f) {
    return traversable.traverse(applicative, f, tA);
  }
}

Then:

// Unfortunately for HKTs, we must wrap regular Java lists
record JavaList<A>(List<A> toList) implements TApp<JavaList.Tag, A> {
  // ...

  @TypeClass.Witness
  public static <A> Show<JavaList<A>> show(Show<A> showA) { ... }

  @TypeClass.Witness
  public static Functor<JavaList.Tag> functor() { ... }

  @TypeClass.Witness
  public static Traversable<JavaList.Tag> traversable() { ... }

  // ...
}

For example:

println(
    Traversable.traverse(
        witness(new Ty<>() {}), // for Traversable<JavaList>
        witness(new Ty<>() {}), // for Applicative<Maybe>
        JavaList.of(1, 2, 3),
        Maybe::just));
// Prints: Just[value=JavaList[toList=[1, 2, 3]]]

Conclusion #

This was a lot of fun to work on. I had never before implemented type classes in a language, and it turned out to be way simpler than I thought. The key insight was figuring out that type unification was all that I needed. Stupidly, I did not check Haskell's spec to figure that out. But it did come somewhat naturally after a bit of thinking, given my experience in implementing type systems.

This project was partially inspired by this recent Brian Goetz JVMLS talk where he presents his early ideas on how to bring type classes to Java. His proposed syntax is Show<Integer>.witness. And, of course, the mechanism is supposed to resolve witnesses at compile time.

You can find the complete implementation in this Gist.

The code also contains several other type class examples like QuickCheck's Arbitrary type class.

Future work #

As of the time of writing, ChatGPT reports the following:

Annex: Context Instances #

Consider this example:

static <A> String example(Show<A> showA, A value) {
  return Show.show(witness(new Ty<>() {}), JavaList.of(value));
}

It is equivalent to the following Haskell code:

example :: Show a => a -> String
example value = show [value]

In the Java code, we try to look up Show<JavaList<A>>, but we don't know what A is!

Sure, at runtime we may know its real type, but we actually would like resolution to be static. (Even though we use reflection for witness resolution.)

In the Haskell code, the available instance of Show a as captured by the function's signature becomes available as a 'context instance' that is used to derive Show [a] for the call of show.

How can we achieve this in Java?

First, we define a type that can capture a witness along with its static type:

abstract class Ctx<T> {
  private final T instance;

  Ctx(T instance) {
    this.instance = instance;
  }

  public T instance() {
    return instance;
  }

  public Type type() {
    return requireNonNull(
        ((ParameterizedType) getClass().getGenericSuperclass())
            .getActualTypeArguments()[0]);
  }
}

It leverages the same mechanism as Ty<T> to capture static types.
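To make the capture concrete, here is a minimal, self-contained sketch of that mechanism. Token is a hypothetical stand-in for Ty<T> and Ctx<T>, not the class from the Gist. It shows that an anonymous subclass records its generic superclass, and that a method's type variable is captured as a type variable rather than erased.

```java
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.List;

class CaptureDemo {
  // Hypothetical stand-in for the post's Ty<T> / Ctx<T>: a "super type token".
  abstract static class Token<T> {
    Type type() {
      // The anonymous subclass stores its generic superclass, Token<...>,
      // in its class file, so the type argument survives erasure.
      return ((ParameterizedType) getClass().getGenericSuperclass())
          .getActualTypeArguments()[0];
    }
  }

  // Captures a fully concrete type.
  static Type concrete() {
    return new Token<List<Integer>>() {}.type();
  }

  // Captures a type that mentions the method's own type variable A.
  static <A> Type generic() {
    return new Token<List<A>>() {}.type();
  }

  public static void main(String[] args) {
    System.out.println(concrete()); // java.util.List<java.lang.Integer>
    System.out.println(generic());  // java.util.List<A>
  }
}
```

Note that generic() reports List<A> no matter how it is called: the captured type is the static one. Ctx<T> relies on the same capture and additionally carries the instance itself.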

Then, we pass it to the witness() method:

static <A> String example(Show<A> showA, A value) {
  return Show.show(
      witness(new Ty<>() {}, new Ctx<>(showA) {}),
      JavaList.of(value));
}

Finally, we update a bit of our type class resolution code:

class TypeClasses {
  // New: second parameter
  public static <T> T witness(Ty<T> ty, Ctx<?>... context) {
    return switch (summon(ParsedType.parse(ty.type()), parseContext(context))) {
      // ...
    };
  }

  // New: parsing Ctx<?> into ContextInstance
  private static List<ContextInstance> parseContext(Ctx<?>[] context) {
    return Arrays.stream(context)
        .map(ctx -> new ContextInstance(ctx.instance(), ParsedType.parse(ctx.type())))
        .toList();
  }

  // ...

  // New: second parameter
  private static Either<SummonError, Object> summon(
      ParsedType target, List<ContextInstance> context) {
    // ...
  }

  // New: second parameter
  private static Either<SummonError, List<Object>> summonAll(
      List<ParsedType> targets, List<ContextInstance> context) {
    // ...
  }

  // New: second parameter
  private static List<Candidate> findCandidates(
      ParsedType target, List<ContextInstance> context) {
    // New: use the context instances along with the discovered witness constructors!
    return Stream.<WitnessRule>concat(
            context.stream(),
            findRules(target).stream())
        .flatMap(...)
        .toList();
  }

  // ...

  // New: a new case class for WitnessRule
  private record ContextInstance(Object instance, ParsedType type) implements WitnessRule {
    @Override
    public Maybe<List<ParsedType>> tryMatch(ParsedType target) {
      // This is a concrete instance, so we only check for type equality
      return target.equals(type) ? Maybe.just(List.of()) : Maybe.nothing();
    }

    @Override
    public Object instantiate(List<Object> dependencies) {
      // Trivial
      return instance;
    }
  }
}

That's all. A bit noisy, but rather simple.

Now this works:

static <A> String example(Show<A> showA, A value) {
  return Show.show(
      witness(new Ty<>() {}, new Ctx<>(showA) {}),
      JavaList.of(value));
}

println(example(witness(new Ty<>() {}), 123));

Here, the type captured by new Ctx<>(showA) {} is Show<A>, where A is the unique type variable belonging to the example method.

The Show<List<A>> witness that we are trying to summon needs a Show<A> whose type could only possibly exist in this static context!
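The scoping claim can be checked with plain reflection. The sketch below is an illustration under assumptions: it compares raw java.lang.reflect.Type values instead of the post's ParsedType, and Token, firstArg, and the method names are hypothetical. It shows that two captures of A inside the same method are equal, while same-named variables from different methods are not.

```java
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.List;

class ContextScopeDemo {
  interface Show<A> { String show(A value); }

  // Hypothetical stand-in for Ty<T> / Ctx<T>.
  abstract static class Token<T> {
    Type type() {
      return ((ParameterizedType) getClass().getGenericSuperclass())
          .getActualTypeArguments()[0];
    }
  }

  // Returns X for a parameterized type such as Show<X> or List<X>.
  static Type firstArg(Type type) {
    return ((ParameterizedType) type).getActualTypeArguments()[0];
  }

  // Both tokens are created inside the same method, so they refer to the same A.
  static <A> boolean sameMethod() {
    Type wanted = new Token<Show<List<A>>>() {}.type(); // Show<List<A>>
    Type offered = new Token<Show<A>>() {}.type();      // Show<A>
    Type needed = firstArg(firstArg(wanted));           // A
    return needed.equals(firstArg(offered));
  }

  static <A> Type varOfOne() {
    return firstArg(new Token<Show<A>>() {}.type());
  }

  static <A> Type varOfOther() {
    return firstArg(new Token<Show<A>>() {}.type());
  }

  // Both variables are named A, but they are declared by different methods.
  static boolean differentMethods() {
    return varOfOne().equals(varOfOther());
  }

  public static void main(String[] args) {
    System.out.println(sameMethod());       // true
    System.out.println(differentMethods()); // false
  }
}
```

This is why an equality check is enough for a context instance: the offered Show<A> can only satisfy a request that mentions the very same A.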

Annex: Overlapping Instances #

In Haskell, the String type is defined as:

type String = [Char]

That is, a String is just a list of Char.

Now, notice the difference here:

show [1, 2, 3]
-- "[1,2,3]"

show [('a', 1), ('b', 2)]
-- "[('a',1),('b',2)]"

show ['a', 'b']
-- "\"ab\"" what?

The generic Show instance for [a] simply intercalates "," between elements.

But the Show instance for [Char] behaves differently.

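Why is that? Standard Haskell achieves this through the showList method of the Show class, but the same behavior can be expressed more directly with a language extension called overlapping instances.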

It allows otherwise ambiguous instances to coexist:

instance Show a => Show [a] where ...

instance {-# OVERLAPPING #-} Show [Char] where ...

The OVERLAPPING pragma tells the compiler that this instance may override another instance iff it is more specific.

The rules for instance specificity are explained in the same link.

Now, consider in Java:

sealed interface FwdList<A> extends TApp<FwdList.Tag, A> {
  record Nil<A>() implements FwdList<A> {}

  record Cons<A>(A head, FwdList<A> tail) implements FwdList<A> {}

  @TypeClass.Witness
  static <A> Show<FwdList<A>> show(Show<A> showA) { ... }

  // Ambiguous!
  @TypeClass.Witness
  static Show<FwdList<Character>> show() { ... }
}

FwdList (name inspired by C++'s std::forward_list) implements a data structure like Haskell's lists.

As it is, witness resolution will fail with an ambiguous witness error.

In order to support overlapping instances, we must apply some changes.

First, let's model Haskell's OVERLAPPING and OVERLAPPABLE pragmas:

@Retention(RetentionPolicy.RUNTIME)
@interface TypeClass {
  @Retention(RetentionPolicy.RUNTIME)
  @interface Witness {
    Overlap overlap() default Overlap.NONE;

    enum Overlap {
      NONE,
      OVERLAPPING,
      OVERLAPPABLE
    }
  }
}

Then, we make it accessible from InstanceConstructor:

private record InstanceConstructor(FuncType func) implements WitnessRule {
  public TypeClass.Witness.Overlap overlap() {
    return func.java().getAnnotation(TypeClass.Witness.class).overlap();
  }

  // ...
}
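As a quick sanity check of the annotation plumbing, the following standalone sketch reads an enum-valued annotation attribute from a method at runtime. The class and method names are hypothetical, and the annotation is a simplified copy of the one above rather than the Gist's.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;

class OverlapAnnotationDemo {
  enum Overlap { NONE, OVERLAPPING, OVERLAPPABLE }

  // RUNTIME retention is required, otherwise getAnnotation() returns null.
  @Retention(RetentionPolicy.RUNTIME)
  @Target(ElementType.METHOD)
  @interface Witness {
    Overlap overlap() default Overlap.NONE;
  }

  @Witness
  static String plain() { return "plain"; }

  @Witness(overlap = Overlap.OVERLAPPING)
  static String specific() { return "specific"; }

  // Looks up a no-argument method by name and reads its overlap mode.
  static Overlap overlapOf(String methodName) {
    try {
      Method method = OverlapAnnotationDemo.class.getDeclaredMethod(methodName);
      return method.getAnnotation(Witness.class).overlap();
    } catch (NoSuchMethodException e) {
      throw new IllegalArgumentException(e);
    }
  }

  public static void main(String[] args) {
    System.out.println(overlapOf("plain"));    // NONE
    System.out.println(overlapOf("specific")); // OVERLAPPING
  }
}
```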

Finally, we implement the overlapping instances deduction algorithm as described in the GHC documentation:

private static List<InstanceConstructor> reduceOverlapping(
    List<InstanceConstructor> candidates) {
  return candidates.stream()
      .filter(
          iX ->
              candidates.stream()
                  .filter(iY -> iX != iY)
                  .noneMatch(iY -> isOverlappedBy(iX, iY)))
      .toList();
}

private static boolean isOverlappedBy(
    InstanceConstructor iX, InstanceConstructor iY) {
  return (iX.overlap() == OVERLAPPABLE || iY.overlap() == OVERLAPPING)
      && isSubstitutionInstance(iX, iY)
      && !isSubstitutionInstance(iY, iX);
}

private static boolean isSubstitutionInstance(
    InstanceConstructor base, InstanceConstructor reference) {
  return Unification.unify(base.func().returnType(), reference.func().returnType())
      .fold(() -> false, map -> !map.isEmpty());
}

And we apply it to our candidate resolution function:

private static List<Candidate> findCandidates(
    ParsedType target, List<ContextInstance> context) {
  return Stream.<WitnessRule>concat(
          context.stream(),
          reduceOverlapping(findRules(target)).stream())
      .flatMap(...)
      .toList();
}

And, of course, we annotate our witness constructor:

sealed interface FwdList<A> extends TApp<FwdList.Tag, A> {
  record Nil<A>() implements FwdList<A> {}

  record Cons<A>(A head, FwdList<A> tail) implements FwdList<A> {}

  @TypeClass.Witness
  static <A> Show<FwdList<A>> show(Show<A> showA) { ... }

  // OK now!
  @TypeClass.Witness(overlap = OVERLAPPING)
  static Show<FwdList<Character>> show() { ... }
}

Now, our witness resolution code will pick Show<FwdList<Character>> because it is more specific AND it declares that it may overlap other instances.
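The "more specific AND opted in" rule can be tried in isolation. The following is a toy sketch, not the Gist's code: Term, Rule, match, and isInstanceOf are hypothetical names, types are modeled as simple terms instead of ParsedType, and one-way matching stands in for the post's Unification.unify. It assumes both candidates have already matched the target Show<FwdList<Character>>.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class OverlapSketch {
  // Toy type terms: a variable such as A, or a constructor applied to arguments.
  interface Term {}
  record Var(String name) implements Term {}
  record App(String name, List<Term> args) implements Term {}

  enum Overlap { NONE, OVERLAPPING, OVERLAPPABLE }
  record Rule(String label, Term head, Overlap overlap) {}

  // One-way matching: only variables of `pattern` may be bound.
  static boolean match(Term pattern, Term term, Map<String, Term> subst) {
    if (pattern instanceof Var v) {
      Term bound = subst.putIfAbsent(v.name(), term);
      return bound == null || bound.equals(term);
    }
    if (pattern instanceof App p && term instanceof App t) {
      if (!p.name().equals(t.name()) || p.args().size() != t.args().size()) {
        return false;
      }
      for (int i = 0; i < p.args().size(); i++) {
        if (!match(p.args().get(i), t.args().get(i), subst)) return false;
      }
      return true;
    }
    return false; // a constructor pattern never matches a bare variable
  }

  // x is a substitution instance of y if y's head can be instantiated to x's head.
  static boolean isInstanceOf(Rule x, Rule y) {
    return match(y.head(), x.head(), new HashMap<>());
  }

  // y overrides x if either side opts in and y is strictly more specific than x.
  static boolean isOverlappedBy(Rule x, Rule y) {
    return (x.overlap() == Overlap.OVERLAPPABLE || y.overlap() == Overlap.OVERLAPPING)
        && isInstanceOf(y, x)
        && !isInstanceOf(x, y);
  }

  // Keeps only the candidates that no other candidate overrides.
  static List<Rule> reduceOverlapping(List<Rule> candidates) {
    return candidates.stream()
        .filter(x -> candidates.stream()
            .filter(y -> x != y)
            .noneMatch(y -> isOverlappedBy(x, y)))
        .toList();
  }

  static Term show(Term arg) { return new App("Show", List.of(arg)); }
  static Term fwdList(Term arg) { return new App("FwdList", List.of(arg)); }

  // Returns the labels of the surviving candidates for Show<FwdList<Character>>.
  static List<String> demo(Overlap specificOverlap) {
    Term character = new App("Character", List.of());
    Rule generic = new Rule("showFwdList", show(fwdList(new Var("A"))), Overlap.NONE);
    Rule specific = new Rule("showFwdListChar", show(fwdList(character)), specificOverlap);
    return reduceOverlapping(List.of(generic, specific)).stream()
        .map(Rule::label)
        .toList();
  }

  public static void main(String[] args) {
    System.out.println(demo(Overlap.OVERLAPPING)); // [showFwdListChar]
    System.out.println(demo(Overlap.NONE));        // [showFwdList, showFwdListChar]
  }
}
```

Without the opt-in, both candidates survive, which corresponds to the ambiguous witness error described earlier.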