Nadohs/Cast-Free-Arithmetic-in-Swift

Improvement: removing extra cast to PreferredType and adding overflow arithmetic

Opened this issue · 6 comments

Hello there,

I found this repository through _iOS Dev Weekly issue 220_ and it was fun to dig into this kind of topic. Before watching your presentation, I wanted to see whether I could solve the problem by myself without looking at your code. It took me a few hours to work through some of the problems, but I finally had a working solution. I also ran into the `+ 0` compiler issue during my research. At least I could remove the double-casting overhead (casting to the PreferredType) from your code with a simple protocol abstraction:

lhs, rhs --> Double --> combine(lhs operator rhs) --> convert to T

is now

T(lhs) operator T(rhs)

I solved the comparison part by implementing a rule that casts one of the operands to match the other operand's type. By the way, you can remove `Equatable` from your code because `Comparable` already inherits from that protocol.

Here is my implementation:

import CoreGraphics

/// Conversion-direction switch for mixed-type operations.
/// `true`: the right-hand operand is converted to the left-hand operand's type.
/// `false`: the left-hand operand is converted to the right-hand operand's type.
private let rightTypeEqualsLeftType: Bool = true

/// A type that can be constructed from any of the standard library's integer
/// types, signed or unsigned, of any width.
protocol ConvertableIntegerType {

    // Word-sized integers.
    init(_ value: Int)
    init(_ value: UInt)

    // Signed fixed-width integers.
    init(_ value: Int8)
    init(_ value: Int16)
    init(_ value: Int32)
    init(_ value: Int64)

    // Unsigned fixed-width integers.
    init(_ value: UInt8)
    init(_ value: UInt16)
    init(_ value: UInt32)
    init(_ value: UInt64)
}

extension ConvertableIntegerType {

    /// Builds `Self` from an integer whose concrete type is only known at run time.
    ///
    /// The existential is unwrapped by testing each concrete integer type in turn.
    /// It then forwards to the matching `init(_:)` requirement, so the normal
    /// conversion rules of that requirement apply.
    ///
    /// - Note: Any other dynamic type falls through to `fatalError`.
    ///   NOTE(review): `Float`, `Double`, `CGFloat` and `Float80` are also
    ///   `ConvertableIntegerType`, via `ArithmeticNumberType` -> `ConvertableNumberType`.
    ///   Passing one of them here therefore type-checks but terminates at run time.
    private init(_ value: ConvertableIntegerType) {

        switch value {
        // Signed.
        case let integer as Int:
            self.init(integer)
        case let integer as Int8:
            self.init(integer)
        case let integer as Int16:
            self.init(integer)
        case let integer as Int32:
            self.init(integer)
        case let integer as Int64:
            self.init(integer)
        // Unsigned.
        case let integer as UInt:
            self.init(integer)
        case let integer as UInt8:
            self.init(integer)
        case let integer as UInt16:
            self.init(integer)
        case let integer as UInt32:
            self.init(integer)
        case let integer as UInt64:
            self.init(integer)
        default:
            fatalError("not convertable integer type")
        }
    }
}

/// A type that can be constructed from any of the supported floating-point types.
///
/// NOTE(review): `Float80` is presumably only available on x86 targets.
/// Confirm that this requirement is acceptable for every platform the file must build on.
protocol ConvertableFloatingPointType {

    init(_ value: CGFloat)
    init(_ value: Float)
    init(_ value: Double)
    init(_ value: Float80)
}

/// A type constructible from every supported integer *and* floating-point type.
protocol ConvertableNumberType: ConvertableFloatingPointType, ConvertableIntegerType {
}

extension ConvertableNumberType {

    /// Builds `Self` from a number whose concrete type is only known at run time.
    ///
    /// Each supported concrete type is tried in turn. The unwrapped value is then
    /// passed to the matching `init(_:)` requirement, whose conversion rules apply
    /// (for example, truncation when `Self` is an integer type).
    ///
    /// - Note: An unsupported dynamic type terminates the program via `fatalError`.
    private init(_ value: ConvertableNumberType) {

        switch value {
        // Floating point.
        case let real as Float:
            self.init(real)
        case let real as Double:
            self.init(real)
        case let real as Float80:
            self.init(real)
        case let real as CGFloat:
            self.init(real)
        // Signed integers.
        case let integer as Int:
            self.init(integer)
        case let integer as Int8:
            self.init(integer)
        case let integer as Int16:
            self.init(integer)
        case let integer as Int32:
            self.init(integer)
        case let integer as Int64:
            self.init(integer)
        // Unsigned integers.
        case let integer as UInt:
            self.init(integer)
        case let integer as UInt8:
            self.init(integer)
        case let integer as UInt16:
            self.init(integer)
        case let integer as UInt32:
            self.init(integer)
        case let integer as UInt64:
            self.init(integer)
        default:
            fatalError("not convertable number type")
        }
    }
}

/// A comparable type that supports the five basic binary arithmetic operators
/// on two values of the same type.
///
/// `Equatable` is not listed because `Comparable` already refines it.
protocol ArithmeticType: Comparable {

    // Multiplicative operators.

    @warn_unused_result
    func * (lhs: Self, rhs: Self) -> Self

    @warn_unused_result
    func / (lhs: Self, rhs: Self) -> Self

    @warn_unused_result
    func % (lhs: Self, rhs: Self) -> Self

    // Additive operators.

    @warn_unused_result
    func + (lhs: Self, rhs: Self) -> Self

    @warn_unused_result
    func - (lhs: Self, rhs: Self) -> Self
}

// MARK: - ArithmeticNumberType

/// A number that can be built from any other supported number and supports
/// same-type arithmetic. This is the constraint used by the mixed-type operators below.
protocol ArithmeticNumberType: ArithmeticType, ConvertableNumberType {
}

/// Adds two numbers of possibly different types, producing the contextual type `T`.
///
/// Each operand is converted to `T` *before* the addition. For an integer `T`,
/// fractional parts are therefore dropped up front. A value that does not fit in `T`
/// traps during conversion.
@warn_unused_result
func + <T: ArithmeticNumberType>(lhs: ConvertableNumberType, rhs: ConvertableNumberType) -> T {

    let augend = T(lhs)
    let addend = T(rhs)
    return augend + addend
}

/// Subtracts `rhs` from `lhs`, both first converted to the contextual result type `T`.
///
/// Conversion happens per operand, before the subtraction. An integer `T` drops
/// fractional parts first, and a value that does not fit in `T` traps.
@warn_unused_result
func - <T: ArithmeticNumberType>(lhs: ConvertableNumberType, rhs: ConvertableNumberType) -> T {

    let minuend = T(lhs)
    let subtrahend = T(rhs)
    return minuend - subtrahend
}

/// Multiplies two numbers of possibly different types into the contextual type `T`.
///
/// Both operands are converted to `T` first, so for an integer `T` each factor is
/// truncated before multiplying.
@warn_unused_result
func * <T: ArithmeticNumberType>(lhs: ConvertableNumberType, rhs: ConvertableNumberType) -> T {

    let multiplicand = T(lhs)
    let multiplier = T(rhs)
    return multiplicand * multiplier
}

/// Divides `lhs` by `rhs` after converting both to the contextual result type `T`.
///
/// For an integer `T`, the operands are truncated before dividing, and the division
/// itself is integer division.
@warn_unused_result
func / <T: ArithmeticNumberType>(lhs: ConvertableNumberType, rhs: ConvertableNumberType) -> T {

    let dividend = T(lhs)
    let divisor = T(rhs)
    return dividend / divisor
}

/// Remainder of `lhs / rhs` after converting both to the contextual result type `T`.
///
/// The remainder is computed by `T`'s own `%`, applied to the already-converted operands.
@warn_unused_result
func % <T: ArithmeticNumberType>(lhs: ConvertableNumberType, rhs: ConvertableNumberType) -> T {

    let dividend = T(lhs)
    let divisor = T(rhs)
    return dividend % divisor
}

// MARK: - Comparable
//
// Mixed-type comparisons. Both operands are widened to `Float80` before comparing.
// Every value of every type that conforms to `ArithmeticNumberType` in this file is
// exactly representable in `Float80`: all integers up to 64 bits, plus Float, Double
// and CGFloat. The widening therefore neither rounds nor traps.
//
// Converting one operand into the *other* operand's type would be wrong in both
// directions. It truncates: an Int 1 compared with 1.5 becomes 1 == 1. It traps when
// the value does not fit: 300 converted to UInt8.
//
// NaN compares unequal to everything and unordered, exactly as IEEE 754 prescribes.
// NOTE(review): `Float80` is presumably x86-only; this file already requires it via
// `ConvertableFloatingPointType.init(Float80)`, so no new platform limit is added.

/// Returns `true` if `lhs` and `rhs` denote the same numeric value.
@warn_unused_result
func == <T1: ArithmeticNumberType, T2: ArithmeticNumberType>(lhs: T1, rhs: T2) -> Bool {

    return Float80(lhs) == Float80(rhs)
}

/// Returns `true` if `lhs` and `rhs` denote different numeric values.
///
/// The standard library's `!=` only covers two operands of the *same* `Equatable`
/// type, so the mixed-type variant must be declared here.
@warn_unused_result
func != <T1: ArithmeticNumberType, T2: ArithmeticNumberType>(lhs: T1, rhs: T2) -> Bool {

    return Float80(lhs) != Float80(rhs)
}

/// Returns `true` if the value of `lhs` is strictly less than the value of `rhs`.
@warn_unused_result
func < <T1: ArithmeticNumberType, T2: ArithmeticNumberType>(lhs: T1, rhs: T2) -> Bool {

    return Float80(lhs) < Float80(rhs)
}

/// Returns `true` if the value of `lhs` is strictly greater than the value of `rhs`.
@warn_unused_result
func > <T1: ArithmeticNumberType, T2: ArithmeticNumberType>(lhs: T1, rhs: T2) -> Bool {

    return Float80(lhs) > Float80(rhs)
}

/// Returns `true` if the value of `lhs` is less than or equal to the value of `rhs`.
@warn_unused_result
func <= <T1: ArithmeticNumberType, T2: ArithmeticNumberType>(lhs: T1, rhs: T2) -> Bool {

    return Float80(lhs) <= Float80(rhs)
}

/// Returns `true` if the value of `lhs` is greater than or equal to the value of `rhs`.
@warn_unused_result
func >= <T1: ArithmeticNumberType, T2: ArithmeticNumberType>(lhs: T1, rhs: T2) -> Bool {

    return Float80(lhs) >= Float80(rhs)
}

// MARK: - OverflowArithmeticIntegerType
// Wrapping division and remainder operators. The file only declares these two and
// relies on the standard library for `&+`, `&-` and `&*`. Precedence 150 with left
// associativity was presumably chosen to match the multiplicative operators
// (`*`, `/`, `%`) in Swift 2 — confirm.
infix operator &/ { associativity left precedence 150 }

infix operator &% { associativity left precedence 150 }

/// An integer that can be built from any supported integer and that offers the
/// `…WithOverflow` static methods (from `_IntegerArithmeticType`). The wrapping `&`
/// operators below depend on those methods.
protocol OverflowArithmeticIntegerType: _IntegerArithmeticType, ConvertableIntegerType {
}

/// Wrapping addition of two integers of possibly different types into `T`.
///
/// Only the addition wraps; the overflow flag is discarded. Converting an operand to
/// `T` is *not* wrapping and traps if the value does not fit.
/// NOTE(review): floating-point values also satisfy `ConvertableIntegerType`, so they
/// type-check here but hit `fatalError` during conversion.
@warn_unused_result
func &+ <T: OverflowArithmeticIntegerType>(lhs: ConvertableIntegerType, rhs: ConvertableIntegerType) -> T {

    let (sum, _) = T.addWithOverflow(T(lhs), T(rhs))
    return sum
}

/// Wrapping subtraction of two integers of possibly different types into `T`.
///
/// Only the subtraction wraps; the overflow flag is discarded. Operand conversion to
/// `T` still traps when a value does not fit.
/// NOTE(review): floating-point arguments type-check but hit `fatalError` at run time.
@warn_unused_result
func &- <T: OverflowArithmeticIntegerType>(lhs: ConvertableIntegerType, rhs: ConvertableIntegerType) -> T {

    let (difference, _) = T.subtractWithOverflow(T(lhs), T(rhs))
    return difference
}

/// Wrapping multiplication of two integers of possibly different types into `T`.
///
/// Only the multiplication wraps; the overflow flag is discarded. Operand conversion
/// to `T` still traps when a value does not fit.
/// NOTE(review): floating-point arguments type-check but hit `fatalError` at run time.
@warn_unused_result
func &* <T: OverflowArithmeticIntegerType>(lhs: ConvertableIntegerType, rhs: ConvertableIntegerType) -> T {

    let (product, _) = T.multiplyWithOverflow(T(lhs), T(rhs))
    return product
}

/// Division of two integers of possibly different types into `T`, ignoring overflow.
///
/// Returns the result component of `T.divideWithOverflow` and discards its flag.
/// What that result is for a zero divisor is defined by `divideWithOverflow` — TODO confirm.
/// Operand conversion to `T` still traps when a value does not fit.
/// NOTE(review): floating-point arguments type-check but hit `fatalError` at run time.
@warn_unused_result
func &/ <T: OverflowArithmeticIntegerType>(lhs: ConvertableIntegerType, rhs: ConvertableIntegerType) -> T {

    let (quotient, _) = T.divideWithOverflow(T(lhs), T(rhs))
    return quotient
}

/// Remainder of two integers of possibly different types into `T`, ignoring overflow.
///
/// Returns the result component of `T.remainderWithOverflow` and discards its flag.
/// Operand conversion to `T` still traps when a value does not fit.
/// NOTE(review): floating-point arguments type-check but hit `fatalError` at run time.
@warn_unused_result
func &% <T: OverflowArithmeticIntegerType>(lhs: ConvertableIntegerType, rhs: ConvertableIntegerType) -> T {

    let (remainder, _) = T.remainderWithOverflow(T(lhs), T(rhs))
    return remainder
}

// MARK: - Extension

// Mixed-type arithmetic and comparison for every supported number type.
extension Int: ArithmeticNumberType {}
extension Int8: ArithmeticNumberType {}
extension Int16: ArithmeticNumberType {}
extension Int32: ArithmeticNumberType {}
extension Int64: ArithmeticNumberType {}
extension UInt: ArithmeticNumberType {}
extension UInt8: ArithmeticNumberType {}
extension UInt16: ArithmeticNumberType {}
extension UInt32: ArithmeticNumberType {}
extension UInt64: ArithmeticNumberType {}
extension Float: ArithmeticNumberType {}
extension Double: ArithmeticNumberType {}

// Wrapping (&) operators: integer types only.
extension Int: OverflowArithmeticIntegerType {}
extension Int8: OverflowArithmeticIntegerType {}
extension Int16: OverflowArithmeticIntegerType {}
extension Int32: OverflowArithmeticIntegerType {}
extension Int64: OverflowArithmeticIntegerType {}
extension UInt: OverflowArithmeticIntegerType {}
extension UInt8: OverflowArithmeticIntegerType {}
extension UInt16: OverflowArithmeticIntegerType {}
extension UInt32: OverflowArithmeticIntegerType {}
extension UInt64: OverflowArithmeticIntegerType {}

/// `CGFloat` conformance. The two initialisers below are supplied here, presumably
/// because CoreGraphics does not provide them.
extension CGFloat: ArithmeticNumberType {

    /// Narrows a `Float80` by going through `Double`; precision beyond `Double` is lost.
    init(_ value: Float80) {

        let viaDouble = Double(value)
        self.init(viaDouble)
    }

    /// Identity conversion, required by `ConvertableFloatingPointType`.
    init(_ value: CGFloat) {

        self = value
    }
}

/// `Float80` conformance; only the `CGFloat` initialiser has to be supplied here.
extension Float80: ArithmeticNumberType {

    /// Widens a `CGFloat` by going through `Double`. This is presumably lossless,
    /// since CGFloat is at most Double-width — confirm for the target platforms.
    init(_ value: CGFloat) {

        let viaDouble = Double(value)
        self.init(viaDouble)
    }
}

I'd love to see, in a later Swift version, some kind of dynamic cast like:

let newValue = value as! value.dynamicType

so we could remove the switch overhead here.

Regards Adrian Zubarev

I like this solution!

I updated my solution to support the overflow operators. I had to add more protocol abstraction and rename the protocols because of that. Any feedback is welcome.

Regards Adrian Zubarev

Here is the solution to the problem from your mail. I accidentally solved the `+ 0` issue in `let g`, but only for this example: we still need the `+ 0` if we want to cast three different types to a fourth one.

let a: Float = 1.5
let b: Int = 1

// Correct: with an Int result, each operand is converted to Int first,
// so the fractional part is cut away: Int(1.5) == 1
let c: Int = a + b + b // result is 3

// Like the `+ 0` issue: the compiler does not pick the intended overload
// on its own here, so we have to help it.
let d: Double = a + b + b + 0

// Presumably the compiler groups it like this; the result is still a Double
// because the last + uses the generic mixed-type operator.
let e: Double = ((a + b) + b) + 0 // == 3.0

// To fix it, tell the compiler explicitly what to add first.
let f: Double = a + b + (b + 0) // == 3.5

// Even better: this solves both issues, because b + b is Int + Int.
let g: Double = a + (b + b) // == 3.5

Here is more:

let a: CGFloat = 1.5
let aa: Double = Double(a) // == 1.5

let b: Float = 1.4
let bb: Double = Double(b) // == 1.399999976158142 <-- Float's rounding error of 1.4 carried into Double
let bbb: Double = 1.4

let c: Int = 1
let cc: Double = Double(c)

let d: Int = a + b + c // == 3 (Int(1.5) + Int(1.4) + 1; each operand is truncated to Int)

let e: Double = 1.5 + 1.4 + 1 // == 3.9

let f: Double = a + b + c + 0 // == 3 (observed)

let g: Double = Double(a) + Double(b) + Double(c) // == 3.899999976158142

let h: Double = a + b + (c + 0) // == 3.899999976158142

let i: Double = a + bbb + (c + 0) // == 3.9

let j: Double = (a + c) + bbb // == 3.9; the reordering is forced, otherwise we need `+ 0` again

I just tested an idea and found that casting from Float (32-bit) to Double (64-bit) only gives the expected value for numbers of the form $n + x \cdot \frac{1}{2^{y}}$, such as 1.5, 1.25, 1.75, 1.125, 2.0, -2.375 and so on, where

$n \in \mathbb{Z}$ (any integer)
$x \in \{0, 1, 2, \dots\}$
$y \in \{1, 2, 3, \dots\}$

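So it depends on the bit pattern. A decimal literal such as 1.4 has no exact binary representation, so it is first rounded to the nearest Float. Converting that Float to Double is itself exact, but it preserves the Float's rounding error. That is why `Double(Float(1.4))` is 1.399999976158142, while the Double literal 1.4 is much closer to 1.4. Values that do have an exact binary representation in Float carry over to Double without any precision loss.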

I just solved everything :) I'll update this last comment once everything is cleaned up and pushed to a repo. The `+ 0` issue is gone and I can use the arithmetic everywhere. It has some overhead, but it does its job.

Hey there, this looks like an interesting solution. I hope we can find something better than just adding brackets, though: it feels a bit like cheating, since it takes away the automatic part of the casting. It does seem to be the best we can do for now. Thanks!