Using abstract classes to simulate tagged unions (aka sum types)

Most functional languages offer support for tagged unions (also called sum types), a kind of data structure capable of holding a value of any one of several fixed types. This article shows how to use abstract classes to emulate this behaviour in high-level object-oriented languages such as C#, Java, or VB.Net. (.Net languages have the [StructLayout(LayoutKind.Explicit)] attribute, which makes it possible to create structs that behave a lot like C++ unions, but that only works with primitive types.)

This article features a rather long introduction to sum types in functional languages. If you already know about this, you can skip to the core of this article.

An introduction to sum types and pattern matching

A preliminary example: linked lists

Starting with a simple example, imagine you want to create a linked list structure. Most C#/Java programmers will write something like

  class MyList<T> {
    public T value;
    public MyList<T> next;

    public MyList(T value, MyList<T> next) {
      this.value = value;
      this.next = next;
    }
  }

  MyList<int> example = new MyList<int>(1, new MyList<int>(2, null));

This relies on a rather ugly trick: null is used to represent the empty list. Functional languages prefer to clearly separate two cases: the empty list (usually called nil), and the concatenation of a value (called the head of the list) and another list (called the tail).

  (* OCaml: 't is a type variable, just like T above.
     OCaml type names must start with a lowercase letter. *)
  type 't my_list = Nil | Cons of 't * 't my_list;;
  let example = Cons (1, Cons (2, Nil))

  {- Haskell: t is a type variable, just like T above -}
  data MyList t = Nil | Cons t (MyList t)
  example = Cons 1 (Cons 2 Nil)

This introduces the concept of a sum type: a type whose values belong to one of a number of sets (here, either to the singleton set {Nil}, or to the set of Cons cells, each pairing a value of type t, for example an integer if t is int, with another list).

Here is another example, where a value of type Value can be either an integer, or a floating-point number, or a string, or a boolean.

  (* OCaml: type names are lowercase *)
  type value =
    | IntegerValue of int
    | FloatValue of float
    | StringValue of string
    | BooleanValue of bool

  -- Haskell
  data Value =
      IntegerValue Int
    | FloatValue Float
    | StringValue String
    | BooleanValue Bool

Quite naturally, functional languages include a special construct to handle values whose type is a sum type. This construct is called pattern matching; it is a bit like a switch statement, only the switch is on the object’s inner type (the subset of the sum type which the object belongs to). For example, to identify the contents of a variable whose type is Value, functional programmers write

  (* OCaml: the payload is not used, so it is matched with _ *)
  let identify = function
    | IntegerValue _ -> "int!"
    | FloatValue _ -> "float!"
    | StringValue _ -> "string!"
    | BooleanValue _ -> "bool!"

  -- Haskell
  identify (IntegerValue _) = "int!"
  identify (FloatValue _) = "float!"
  identify (StringValue _) = "string!"
  identify (BooleanValue _) = "bool!"

Here’s another example, on lists this time: if functional programmers want to count elements in a list, they write

  (* OCaml *)
  let rec count = function
    | Nil -> 0
    | Cons (head, tail) -> 1 + (count tail)
  ;;

  -- Haskell
  count Nil = 0
  count (Cons hd tl) = 1 + count tl

Let’s move on to our last example.

XML node trees

In this section, we use a sum type to represent the structure of an XML document (some node types have been omitted). We first define an attribute to be a pair of two strings, and a node to be one of DocumentFragment (a list of XML nodes), Element (a name, attributes, and child nodes), Comment (some comment text), or Text (plain text).

  type attr = 
    Attribute of string * string;;

  type xmlnode = 
    | DocumentFragment of xmlnode list
    | Element of string 
                  * (attr list)
                  * (xmlnode list)
    | Comment of string
    | Text of string
  ;;

  -- Haskell: attributes get their own type, so that [Attribute] is valid
  data Attribute = Attribute String String

  data XmlNode =
      DocumentFragment [XmlNode]
    | Element String [Attribute]
                     [XmlNode]
    | Comment String
    | Text String

With such a type declaration, writing concise tree-traversal functions is easy: as an example, we define a serialize function, which generates XML code from an xmlnode value.

  let serialize_attribute (Attribute(name, value)) =
    Printf.printf " %s=\"%s\"" name value
  ;;

  let rec serialize = function 
    | DocumentFragment nodes ->
        List.iter serialize nodes (* applies serialize to every xmlnode in nodes *)
    | Element (label, attributes, nodes) ->
        Printf.printf "<%s" label;
          List.iter serialize_attribute attributes;
        Printf.printf ">\n";
          List.iter serialize nodes;
        Printf.printf "</%s>\n" label;
    | Comment (str) ->
        Printf.printf "<!-- %s -->" str;
    | Text (str) ->
        Printf.printf "%s\n" str;
  ;;

Here is a small example to illustrate the previous function:

  serialize (
    Element("a", [
      Attribute("href", "http://pit-claudel.fr/clement/blog");
      Attribute("title", "Code crumbs")
    ], [
      Text("Code crumbs -- Personal thoughts about programming")
    ])
  );;

This prints

<a href="http://pit-claudel.fr/clement/blog" title="Code crumbs">
Code crumbs -- Personal thoughts about programming
</a>

We can now move on to the core of this article:

Simulating sum types using abstract classes

The problem

Returning to a simpler example, suppose we need to represent a generic binary tree (a tree is defined as being either a branch, containing two sub-trees, or a leaf, containing a value). Functional programmers will write something like

  (* OCaml *)
  type 't tree = Branch of ('t tree * 't tree) | Leaf of 't
  let example = Branch (Branch (Leaf 1, Leaf 2), Leaf 3)

  -- Haskell
  data Tree t = Branch (Tree t) (Tree t) | Leaf t
  example = Branch (Branch (Leaf 1) (Leaf 2)) (Leaf 3)

On the other hand, (hurried) C#/Java programmers will often implement this as a product type — that is, they’ll pack both cases in a single class, and use an extra field to differentiate branches and leaves:

  enum TreeType { BRANCH, LEAF }

  class BadTree<T> {
    public TreeType type;

    public T leaf_value;
    public BadTree<T> left_branch, right_branch;

    public BadTree(BadTree<T> left_branch, BadTree<T> right_branch) {
         this.type = TreeType.BRANCH;
         this.left_branch = left_branch;
         this.right_branch = right_branch;
         // Memory waste: leaf_value is unused (!)
    }

    public BadTree(T leaf_value) {
         this.type = TreeType.LEAF;
         this.leaf_value = leaf_value;
         // Memory waste: left_branch and right_branch are unused (!)
    }
  }

  BadTree<int> example = new BadTree<int>(new BadTree<int>(new BadTree<int>(1),
                                                           new BadTree<int>(2)),
                                          new BadTree<int>(3));

This representation wastes memory (every node carries the fields of both cases, whichever case it actually represents) and, though rather concise, quickly degenerates when more cases are added to the type definition (for example in the xmlnode case; for another real-world example, see the end of this post).

The solution

The idea is to encode the type disjunction (Branch or Leaf, DocumentFragment or Element or Comment or Text) in the inheritance hierarchy; practically speaking, the trick is to define an abstract Tree class, and make two new classes, Leaf and Branch, which both inherit from Tree:

  abstract class Tree<T> {
    // Nested classes see the enclosing T; they must not be declared static in C#.
    public class Leaf : Tree<T> {
      public T value;

      public Leaf(T value) {
        this.value = value;
      }
    }

    public class Branch : Tree<T> {
      public Tree<T> left, right;

      public Branch(Tree<T> left, Tree<T> right) {
        this.left = left;
        this.right = right;
      }
    }
  }

  Tree<int> example = new Tree<int>.Branch(new Tree<int>.Branch(new Tree<int>.Leaf(1),
                                                                new Tree<int>.Leaf(2)),
                                           new Tree<int>.Leaf(3));

This allows for much cleaner code, and the memory waste is gone!
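For Java readers, the same encoding might look roughly like the sketch below. It is an illustration written for this purpose rather than a translation the article relies on; the private constructor and the TreeDemo class are additions, the former being one common way to keep the set of cases closed.

```java
// Java sketch of the encoding: an abstract base class with one nested
// subclass per case of the sum type.
abstract class Tree<T> {
    // Private constructor (an addition): only the nested classes below can
    // extend Tree, so no third case can be added from outside.
    private Tree() {}

    static final class Leaf<T> extends Tree<T> {
        final T value;

        Leaf(T value) {
            this.value = value;
        }
    }

    static final class Branch<T> extends Tree<T> {
        final Tree<T> left, right;

        Branch(Tree<T> left, Tree<T> right) {
            this.left = left;
            this.right = right;
        }
    }
}

// Hypothetical demo class: builds Branch(Branch(Leaf 1, Leaf 2), Leaf 3).
class TreeDemo {
    static Tree<Integer> example() {
        return new Tree.Branch<>(
            new Tree.Branch<>(new Tree.Leaf<>(1), new Tree.Leaf<>(2)),
            new Tree.Leaf<>(3));
    }
}
```

A Leaf only stores a value and a Branch only stores two sub-trees, so no field is left unused.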

Pattern matching on abstract classes

The BadTree class, which uses an enum to distinguish branches from leaves, makes it possible to perform pattern matching by switching on the BadTree.type field. Though this approach translates directly to the abstract Tree class using the is keyword (Java equivalent: instanceof), it’s not ideal.

  // With BadTree: switch on the type field
  BadTree<int> tree = (...);

  switch (tree.type) {
    case TreeType.BRANCH:
      (...); break;
    case TreeType.LEAF:
      (...); break;
  }

  // With the abstract Tree class: test the runtime type, then cast
  Tree<int> tree = (...);

  if (tree is Tree<int>.Branch) {
    Tree<int>.Branch branch =
      (Tree<int>.Branch) tree;
    (...);
  } else if (tree is Tree<int>.Leaf) {
    Tree<int>.Leaf leaf =
      (Tree<int>.Leaf) tree;
    (...);
  }
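To make the type-test approach concrete, here is a small Java sketch using instanceof. The countLeaves function and the InstanceofMatching class are hypothetical names chosen for illustration, and the Tree classes are repeated in a minimal form so that the block stands alone.

```java
abstract class Tree<T> {
    static final class Leaf<T> extends Tree<T> {
        final T value;

        Leaf(T value) {
            this.value = value;
        }
    }

    static final class Branch<T> extends Tree<T> {
        final Tree<T> left, right;

        Branch(Tree<T> left, Tree<T> right) {
            this.left = left;
            this.right = right;
        }
    }
}

class InstanceofMatching {
    // Counts leaves by testing the runtime type of the node, then casting.
    static <T> int countLeaves(Tree<T> tree) {
        if (tree instanceof Tree.Branch) {
            // The test above does not change the static type of tree,
            // so an explicit cast is still needed to reach left and right.
            Tree.Branch<T> branch = (Tree.Branch<T>) tree;
            return countLeaves(branch.left) + countLeaves(branch.right);
        } else if (tree instanceof Tree.Leaf) {
            return 1;
        }
        // Nothing forces this chain to cover every case, hence the fallback.
        throw new IllegalStateException("unknown Tree case");
    }
}
```

The cast and the fallback exception are the two pieces of noise that this style brings with it.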

Though it works, this new version of the pattern-matching code is not really pretty. A much cleaner approach is to implement the pattern matching directly in the derived classes (Branch and Leaf), by declaring an abstract method in the parent class, Tree. That is, instead of having one big function doing explicit pattern matching on Tree values, we’ll have multiple specialized functions, one in each class inheriting from Tree. Philosophically speaking, just as we implemented type disjunctions as different classes inheriting from a common parent, we’ll implement logical disjunctions as different methods overriding a common abstract method.

For example, here is how we can write a Fork function, which returns a new Tree with every leaf split in two identical leaves:

  abstract class Tree<T> {
    public abstract Tree<T> Fork();

    public class Leaf : Tree<T> {
      public T value;

      // Constructor omitted

      public override Tree<T> Fork() {
        return new Tree<T>.Branch(new Tree<T>.Leaf(value), new Tree<T>.Leaf(value));
      }
    }

    public class Branch : Tree<T> {
      public Tree<T> left, right;

      // Constructor omitted

      public override Tree<T> Fork() {
        return new Tree<T>.Branch(left.Fork(), right.Fork());
      }
    }
  }
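A Java rendition of this idea could look like the following sketch. The fork method mirrors Fork above; show is a hypothetical helper added only so that the result can be inspected.

```java
abstract class Tree<T> {
    // Each case of the sum type provides its own implementation.
    abstract Tree<T> fork();

    // Hypothetical helper, not part of the article: renders the tree as text.
    abstract String show();

    static final class Leaf<T> extends Tree<T> {
        final T value;

        Leaf(T value) {
            this.value = value;
        }

        @Override
        Tree<T> fork() {
            // A leaf becomes a branch holding two copies of itself.
            return new Branch<>(new Leaf<>(value), new Leaf<>(value));
        }

        @Override
        String show() {
            return String.valueOf(value);
        }
    }

    static final class Branch<T> extends Tree<T> {
        final Tree<T> left, right;

        Branch(Tree<T> left, Tree<T> right) {
            this.left = left;
            this.right = right;
        }

        @Override
        Tree<T> fork() {
            // A branch forks both of its sub-trees.
            return new Branch<>(left.fork(), right.fork());
        }

        @Override
        String show() {
            return "(" + left.show() + " " + right.show() + ")";
        }
    }
}
```

With these definitions, forking the tree built as new Tree.Branch<>(new Tree.Leaf<>(1), new Tree.Leaf<>(2)) and calling show() on the result yields "((1 1) (2 2))". No type test or cast appears anywhere: dynamic dispatch selects the right case.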

Going further: a real-life example

The benefits of using this abstract class approach are particularly visible when the data types involved get more complex. Below is an example implementation of a data structure used to describe arithmetic expressions.

  abstract class ArithmeticExpression {
    // Evaluates an arithmetic expression. context is a dictionary mapping variable
    //  names to their values.
    public abstract float? Eval(Dictionary<String, float?> context);

    public class Variable : ArithmeticExpression {
      string label;

      public Variable(string label) {
        this.label = label;
      }

      public override float? Eval(Dictionary<String, float?> context) {
        float? value;
        context.TryGetValue(label, out value);
        return value;
      }
    }
    
    public abstract class BinaryOperation : ArithmeticExpression {
      protected ArithmeticExpression left_operand, right_operand;

      public BinaryOperation(ArithmeticExpression left_operand, ArithmeticExpression right_operand) {
        this.left_operand = left_operand;
        this.right_operand = right_operand;
      }

      protected abstract float? Apply(float left_value, float right_value);

      public override float? Eval(Dictionary<String, float?> context) {
        float? left_result = left_operand.Eval(context), right_result = right_operand.Eval(context);

        if (left_result.HasValue && right_result.HasValue)
          return Apply(left_result.Value, right_result.Value);
        else
          return null;
      }

      public class Sum : BinaryOperation {
        public Sum(ArithmeticExpression left_operand, ArithmeticExpression right_operand) : base(left_operand, right_operand) { }

        protected override float? Apply(float left_value, float right_value) {
          return left_value + right_value;
        }
      }

      public class Subtraction : BinaryOperation {
        public Subtraction(ArithmeticExpression left_operand, ArithmeticExpression right_operand) : base(left_operand, right_operand) { }

        protected override float? Apply(float left_value, float right_value) {
          return left_value - right_value;
        }
      }

      public class Product : BinaryOperation {
        public Product(ArithmeticExpression left_operand, ArithmeticExpression right_operand) : base(left_operand, right_operand) { }

        protected override float? Apply(float left_value, float right_value) {
          return left_value * right_value;
        }
      }

      public class Division : BinaryOperation {
        public Division(ArithmeticExpression left_operand, ArithmeticExpression right_operand) : base(left_operand, right_operand) { }

        protected override float? Apply(float left_value, float right_value) {
          return left_value / right_value;
        }
      }
    }
  }
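To show how such a structure might be used, here is a condensed Java sketch of the same design, limited to variables, sums and products. It is an illustrative port under assumptions, not the article's code: a null Float plays the role of C#'s float?, the operations are nested directly in ArithmeticExpression for brevity, and ArithmeticDemo is a hypothetical name.

```java
import java.util.Map;

abstract class ArithmeticExpression {
    // Returns null when a variable is missing from the context.
    abstract Float eval(Map<String, Float> context);

    static final class Variable extends ArithmeticExpression {
        private final String label;

        Variable(String label) {
            this.label = label;
        }

        @Override
        Float eval(Map<String, Float> context) {
            return context.get(label); // null if the variable is unbound
        }
    }

    abstract static class BinaryOperation extends ArithmeticExpression {
        private final ArithmeticExpression left, right;

        BinaryOperation(ArithmeticExpression left, ArithmeticExpression right) {
            this.left = left;
            this.right = right;
        }

        abstract float apply(float leftValue, float rightValue);

        @Override
        Float eval(Map<String, Float> context) {
            Float l = left.eval(context), r = right.eval(context);
            if (l == null || r == null) {
                return null; // propagate "no value" upwards
            }
            return apply(l, r);
        }
    }

    static final class Sum extends BinaryOperation {
        Sum(ArithmeticExpression left, ArithmeticExpression right) {
            super(left, right);
        }

        @Override
        float apply(float leftValue, float rightValue) {
            return leftValue + rightValue;
        }
    }

    static final class Product extends BinaryOperation {
        Product(ArithmeticExpression left, ArithmeticExpression right) {
            super(left, right);
        }

        @Override
        float apply(float leftValue, float rightValue) {
            return leftValue * rightValue;
        }
    }
}

class ArithmeticDemo {
    // Builds the expression x * (x + y).
    static ArithmeticExpression example() {
        ArithmeticExpression x = new ArithmeticExpression.Variable("x");
        ArithmeticExpression y = new ArithmeticExpression.Variable("y");
        return new ArithmeticExpression.Product(x, new ArithmeticExpression.Sum(x, y));
    }
}
```

Evaluating ArithmeticDemo.example() in a context where x is 3 and y is 4 gives 21; if y is missing from the context, the result is null.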

2 thoughts on “Using abstract classes to simulate tagged unions (aka sum types)”

  1. Valentin Waeselynck

    In a number of situations, this is not entirely satisfying because the client can’t extend the matching behavior. GoF’s Visitor Pattern solves this and deserves to be mentioned here.

    1. Clément Post author

      Indeed, that’s a good point.

      In fact, the problem in the “extensible” matching presented as unsatisfactory in this article (the one requiring the if (obj is type1) ... if (obj is type2) ... if (obj is type3) ...) is that the type information gained by explicitly testing for the type is not propagated into the corresponding branch; that is, even if (obj is type1) evaluates to true, an explicit cast to type1 is still required before the obj can be used as a type1 object in the branch’s body.

      So yes, although slightly more verbose in simple cases, the visitor pattern is another nice workaround to the lack of proper discriminated unions; thanks for your comment!
