From aae33eb0d12ea3d65976db6c029798b5a69352ca Mon Sep 17 00:00:00 2001
From: c9s
Date: Thu, 31 Oct 2024 15:16:26 +0800
Subject: [PATCH] fixedpoint: implement different mul mode

---
 pkg/fixedpoint/convert.go | 76 +++++++++++++++++++++++++++++++++++++--
 1 file changed, 74 insertions(+), 2 deletions(-)

diff --git a/pkg/fixedpoint/convert.go b/pkg/fixedpoint/convert.go
index d000603a8..f61c657fa 100644
--- a/pkg/fixedpoint/convert.go
+++ b/pkg/fixedpoint/convert.go
@@ -8,6 +8,8 @@ import (
 	"errors"
 	"fmt"
 	"math"
+	"math/big"
+	"os"
 	"strconv"
 	"strings"
 	"sync/atomic"
@@ -227,11 +229,11 @@ func (v Value) IsZero() bool {
 }
 
 func Mul(x, y Value) Value {
-	return NewFromFloat(x.Float64() * y.Float64())
+	return mulOp(x, y)
 }
 
 func (v Value) Mul(v2 Value) Value {
-	return NewFromFloat(v.Float64() * v2.Float64())
+	return mulOp(v, v2)
 }
 
 func Div(x, y Value) Value {
@@ -610,3 +612,73 @@ func (x Value) Clamp(min, max Value) Value {
 	}
 	return x
 }
+
+// Operator is a binary operation on two fixed-point values.
+type Operator func(a, b Value) Value
+
+// mulOp is the implementation behind Mul and Value.Mul. It defaults to the
+// float64 implementation, currently the fastest one, and can be switched at
+// start-up through the FP_MODE environment variable.
+var mulOp Operator = multiplyValueWithFloat64
+
+// init selects the multiplication implementation from FP_MODE: "int64",
+// "float64" or "mathbig". An empty or unrecognized value keeps the default.
+func init() {
+	switch os.Getenv("FP_MODE") {
+	case "int64":
+		mulOp = multiplyWithInt64
+	case "float64":
+		mulOp = multiplyValueWithFloat64
+	case "mathbig":
+		mulOp = multiplyBigInt
+	}
+}
+
+// multiplyBigInt multiplies a and b exactly with math/big and scales the
+// product back down by 10^8, truncating toward zero.
+//
+// The conversion back to Value is only meaningful when the scaled product
+// fits in an int64.
+func multiplyBigInt(a, b Value) Value {
+	product := new(big.Int).Mul(big.NewInt(int64(a)), big.NewInt(int64(b)))
+
+	// Restore the precision (divide by 10^8). Quo truncates toward zero, so
+	// (-a)*b == -(a*b); Div is Euclidean and would instead round inexact
+	// negative products away from zero.
+	product.Quo(product, big.NewInt(1e8))
+
+	return Value(product.Int64())
+}
+
+// multiplyValueWithFloat64 multiplies x and y as float64 values and converts
+// the product back with NewFromFloat.
+func multiplyValueWithFloat64(x, y Value) Value {
+	return NewFromFloat(x.Float64() * y.Float64())
+}
+
+// multiplyWithInt64 multiplies a and b with int64 arithmetic only and rounds
+// the result to the nearest unit, halves away from zero.
+//
+// Each operand is split into its whole part and its 8-digit fraction, so every
+// partial term is bounded by the final result and the computation overflows
+// only when the result itself does not fit in an int64.
+func multiplyWithInt64(a, b Value) Value {
+	const scaleFactor = 100000000
+
+	intA, fracA := a/scaleFactor, a%scaleFactor
+	intB, fracB := b/scaleFactor, b%scaleFactor
+
+	// |fracA| and |fracB| are below 10^8, so this product cannot overflow. It
+	// is the only term that needs rounding when scaled back down.
+	low := fracA * fracB
+	if low < 0 {
+		low -= scaleFactor / 2
+	} else {
+		low += scaleFactor / 2
+	}
+	low /= scaleFactor
+
+	// a*b/10^8 = intA*intB*10^8 + intA*fracB + fracA*intB + fracA*fracB/10^8,
+	// and all four terms carry the sign of the result.
+	return intA*intB*scaleFactor + intA*fracB + fracA*intB + low
+}