# Saturday, November 08, 2008
« Nemerle Macros - InheritConstructors | Main | Nemerle Macros - ExecuteReaderLoop »

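Correctly implementing structural equality takes a lot of boilerplate code: an Equals override, a matching GetHashCode override, and the == and != operators. My ImplementEquality macro generates all of it for you.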

Watch the Screencast

Before (equality members written by hand)

using System;
using System.Console;
using Nemerle.Utility;
 
module Program
{
    // Location with hand-written structural equality: Equals, GetHashCode and
    // the == / != operators all compare Latitude and Longitude.
    [Record]
    public class Location
    {
        public Latitude : double;
        public Longitude : double;

        // True when `other` is a non-null Location with equal coordinates.
        // double.Equals is used instead of == because it treats NaN as equal to
        // NaN, which keeps Equals reflexive as the Object.Equals contract requires
        // (with ==, a Location holding a NaN coordinate would not equal itself).
        public override Equals(other : object) : bool
        {
            match (other)
            {
                | null => false
                | l is Location => l.Latitude.Equals(Latitude) && l.Longitude.Equals(Longitude)
                | _ => false
            }
        }

        // Combines the coordinate hash codes with XOR.
        public override GetHashCode() : int
        {
            Latitude.GetHashCode() ^ Longitude.GetHashCode()
        }

        // Null-aware equality: two nulls are equal, a null and a non-null are
        // not, otherwise Equals decides. The null checks use ReferenceEquals so
        // that they do not re-enter this operator.
        public static @==(l1 : Location, l2 : Location) : bool
        {
            def l1IsNull = object.ReferenceEquals(l1, null);
            def l2IsNull = object.ReferenceEquals(l2, null);
            match (l1IsNull, l2IsNull)
            {
                | (true, true) => true
                | (false, false) => l1.Equals(l2)
                | _ => false
            }
        }

        // Exact negation of ==.
        public static @!=(l1 : Location, l2 : Location) : bool
        {
            !(l1 == l2)
        }
    }

    module Program
    {
        // Prints the results of the equality members: l1 and l2 hold the same
        // coordinates, l3 holds different ones.
        Main() : void
        {
            def l1 = Location(30, 70);
            def l2 = Location(30, 70);
            def l3 = Location(40, 90);

            WriteLine("l1.Equals(l2) -> {0}", l1.Equals(l2));
            WriteLine("l1 == l2 -> {0}", l1 == l2);
            WriteLine("l1 != l2 -> {0}", l1 != l2);
            WriteLine("l1.GetHashCode() == l2.GetHashCode() -> {0}", l1.GetHashCode() == l2.GetHashCode());
            WriteLine("l1 != l3 -> {0}", l1 != l3);

            _ = ReadLine();
        }
    }
}

After (equality members generated by the ImplementEquality macro)

using System;
using System.Console;
using Nemerle.Utility;
using SampleMacros;
 
// Demo program: the equality members of Location are generated by the
// ImplementEquality macro instead of being written by hand.
module Program
{
    // Location declares no constructor, yet Main builds it from two values;
    // [Record] supplies that constructor. ImplementEquality(Latitude, Longitude)
    // generates Equals, GetHashCode, == and != based on those two fields.
    [Record]
    [ImplementEquality(Latitude, Longitude)]
    public class Location
    {
        public Latitude : double;
        public Longitude : double;
    }
 
    module Program
    {
        // Prints the results of the generated equality members: l1 and l2 hold
        // the same coordinates, l3 holds different ones.
        Main() : void
        {
            def l1 = Location(30, 70);
            def l2 = Location(30, 70);
            def l3 = Location(40, 90);
 
            WriteLine("l1.Equals(l2) -> {0}", l1.Equals(l2));
            WriteLine("l1 == l2 -> {0}", l1 == l2);
            WriteLine("l1 != l2 -> {0}", l1 != l2);
            WriteLine("l1.GetHashCode() == l2.GetHashCode() -> {0}", l1.GetHashCode() == l2.GetHashCode());
            WriteLine("l1 != l3 -> {0}", l1 != l3);
 
            // Waits for input so the console window stays open.
            _ = ReadLine();
        }
    }
}

ImplementEquality.n

using Nemerle;
using Nemerle.Compiler;
using Nemerle.Compiler.Parsetree;
using Nemerle.Macros;
 
namespace SampleMacros
{
    // Generates structural equality for a class. Given the member names listed
    // in the attribute (e.g. [ImplementEquality(Latitude, Longitude)]), it adds:
    //   - an Equals(obj : object) override comparing those members,
    //   - a matching GetHashCode() override,
    //   - static == and != operators built on Equals.
    // It runs in the BeforeInheritance phase and may only target classes.
    [MacroUsage(MacroPhase.BeforeInheritance, MacroTargets.Class)]
    public macro ImplementEquality(typeBuilder : TypeBuilder, params members : list[PExpr])
    {
        ImplementEqualityModule.AddEqualsMethod(typeBuilder, members);
        ImplementEqualityModule.AddEqualityOperator(typeBuilder);
        ImplementEqualityModule.AddInequalityOperator(typeBuilder);
    }


    // Code generators behind the ImplementEquality macro. Each Add* method
    // defines new members on the class currently being built.
    module ImplementEqualityModule
    {
        // Defines `Equals(obj : object) : bool` and `GetHashCode() : int`
        // overrides on the type, based on the given members.
        //   typeBuilder - the class being compiled.
        //   members     - the member names passed to the attribute; each one is
        //                 accessed as `this.<name>` in the generated code.
        // The generated Equals is true when `obj` is a non-null instance of the
        // type and every listed member compares equal. With no members listed,
        // any two non-null instances compare equal and the hash code is 0.
        public AddEqualsMethod(typeBuilder : TypeBuilder, members : list[PExpr]) : void
        {
            def other = Macros.NewSymbol();

            // object.Equals(a, b) is null-safe (two nulls are equal, null vs.
            // non-null is not) and otherwise calls a.Equals(b). This means
            // reference-typed members that are null do not throw
            // NullReferenceException.
            def comparisons = members.Map(m => {
                def s = MemberName(m);
                <[ object.Equals(this.$s, $(other : name).$s) ]>
            });

            // Chain the comparisons with &&. Matching on the list avoids
            // calling Head/Tail on an empty member list.
            def allEqual = match (comparisons)
            {
                | [] => <[ true ]>
                | first :: rest => rest.FoldLeft(first, (c, acc) => <[ $acc && $c ]>)
            };

            typeBuilder.Define(<[ decl:
                public override Equals(obj : object) : bool
                {
                    match (obj) {
                        | null => false
                        | $(other : name) is $(typeBuilder.ParsedTypeName) => $allEqual
                        | _ => false
                    }
                }
            ]>);

            // Null members contribute 0, so GetHashCode never throws. Members
            // that compare equal under object.Equals still hash equally.
            def hashCodes = members.Map(m => {
                def s = MemberName(m);
                <[ if (object.ReferenceEquals(this.$s, null)) 0 else (this.$s).GetHashCode() ]>
            });
            def hashCode = match (hashCodes)
            {
                | [] => <[ 0 ]>
                | first :: rest => rest.FoldLeft(first, (c, acc) => <[ $acc ^ $c ]>)
            };

            typeBuilder.Define(<[ decl:
                public override GetHashCode() : int
                {
                    $hashCode
                }
            ]>);
        }


        // Defines the static == operator for the type.
        public AddEqualityOperator(typeBuilder : TypeBuilder) : void
        {
            typeBuilder.Define(CreateEqualityOperator(typeBuilder.ParsedTypeName, false));
        }


        // Defines the static != operator for the type.
        public AddInequalityOperator(typeBuilder : TypeBuilder) : void
        {
            typeBuilder.Define(CreateEqualityOperator(typeBuilder.ParsedTypeName, true));
        }


        // Turns a member reference from the attribute's argument list
        // (e.g. `Latitude`) into a name that can be spliced as `this.$name`.
        MemberName(member : PExpr) : Splicable
        {
            Splicable.Name(Name(member.ToString()))
        }


        // Builds `public static ==(p1, p2) : bool` (invert = false) or the
        // matching != operator (invert = true) for operands of type `typeRef`.
        // Equality is null-aware: two nulls are equal, null and non-null are
        // not, otherwise p1.Equals(p2) decides. != is the negation of that whole
        // result. Negating only the Equals branch would make `null != null`
        // true and `x != null` false.
        CreateEqualityOperator(typeRef : PExpr, invert : bool) : ClassMember.Function
        {
            def op = Splicable.Name(Name(if (invert) "!=" else "=="));

            def equal = <[
                match (object.ReferenceEquals(null, p1), object.ReferenceEquals(null, p2)) {
                    | (true, true) => true
                    | (false, false) => p1.Equals(p2)
                    | _ => false
                }
            ]>;
            def body = if (invert) <[ !($equal) ]> else equal;

            <[ decl:
                public static $op(p1 : $(typeRef), p2 : $(typeRef)) : bool
                {
                    $body
                }
            ]>
        }
    }
}