From 0a602bc259e0d41dc39a695be37b0adf01186c6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E3=81=AA=E3=82=8B=E3=81=BF?= Date: Thu, 16 Jun 2022 01:46:33 +0800 Subject: [PATCH] rebalance: add ValueMap --- pkg/fixedpoint/value_map.go | 135 +++++++++++++++++++++++++++++++ pkg/fixedpoint/value_map_test.go | 123 ++++++++++++++++++++++++++++ 2 files changed, 258 insertions(+) create mode 100644 pkg/fixedpoint/value_map.go create mode 100644 pkg/fixedpoint/value_map_test.go diff --git a/pkg/fixedpoint/value_map.go b/pkg/fixedpoint/value_map.go new file mode 100644 index 000000000..f65f2a0cf --- /dev/null +++ b/pkg/fixedpoint/value_map.go @@ -0,0 +1,135 @@ +package fixedpoint + +type ValueMap map[string]Value + +func (m ValueMap) Eq(n ValueMap) bool { + if len(m) != len(n) { + return false + } + + for m_k, m_v := range m { + n_v, ok := n[m_k] + if !ok { + return false + } + + if !m_v.Eq(n_v) { + return false + } + } + + return true +} + +func (m ValueMap) Add(n ValueMap) ValueMap { + if len(m) != len(n) { + panic("unequal length") + } + + o := ValueMap{} + + for k, v := range m { + o[k] = v.Add(n[k]) + } + + return o +} + +func (m ValueMap) Sub(n ValueMap) ValueMap { + if len(m) != len(n) { + panic("unequal length") + } + + o := ValueMap{} + + for k, v := range m { + o[k] = v.Sub(n[k]) + } + + return o +} + +func (m ValueMap) Mul(n ValueMap) ValueMap { + if len(m) != len(n) { + panic("unequal length") + } + + o := ValueMap{} + + for k, v := range m { + o[k] = v.Mul(n[k]) + } + + return o +} + +func (m ValueMap) Div(n ValueMap) ValueMap { + if len(m) != len(n) { + panic("unequal length") + } + + o := ValueMap{} + + for k, v := range m { + o[k] = v.Div(n[k]) + } + + return o +} + +func (m ValueMap) AddScalar(x Value) ValueMap { + o := ValueMap{} + + for k, v := range m { + o[k] = v.Add(x) + } + + return o +} + +func (m ValueMap) SubScalar(x Value) ValueMap { + o := ValueMap{} + + for k, v := range m { + o[k] = v.Sub(x) + } + + return o +} + +func (m ValueMap) MulScalar(x Value) ValueMap { + o := ValueMap{} + + for k, v := range m { + o[k] = v.Mul(x) + } + + return o +} + +func (m ValueMap) DivScalar(x Value) ValueMap { + o := ValueMap{} + + for k, v := range m { + o[k] = v.Div(x) + } + + return o +} + +func (m ValueMap) Sum() Value { + var sum Value + for _, v := range m { + sum = sum.Add(v) + } + return sum +} + +func (m ValueMap) Normalize() ValueMap { + sum := m.Sum() + if sum.Eq(Zero) { + panic("zero sum") + } + + return m.DivScalar(sum) +} diff --git a/pkg/fixedpoint/value_map_test.go b/pkg/fixedpoint/value_map_test.go new file mode 100644 index 000000000..8535caaba --- /dev/null +++ b/pkg/fixedpoint/value_map_test.go @@ -0,0 +1,123 @@ +package fixedpoint + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_ValueMap_Eq(t *testing.T) { + m1 := ValueMap{ + "A": NewFromFloat(3.0), + "B": NewFromFloat(4.0), + } + + m2 := ValueMap{} + + m3 := ValueMap{"A": NewFromFloat(5.0)} + + m4 := ValueMap{ + "A": NewFromFloat(6.0), + "B": NewFromFloat(7.0), + } + + m5 := ValueMap{ + "A": NewFromFloat(3.0), + "B": NewFromFloat(4.0), + } + + assert.True(t, m1.Eq(m1)) + assert.False(t, m1.Eq(m2)) + assert.False(t, m1.Eq(m3)) + assert.False(t, m1.Eq(m4)) + assert.True(t, m1.Eq(m5)) +} + +func Test_ValueMap_Add(t *testing.T) { + m1 := ValueMap{ + "A": NewFromFloat(3.0), + "B": NewFromFloat(4.0), + } + + m2 := ValueMap{ + "A": NewFromFloat(5.0), + "B": NewFromFloat(6.0), + } + + m3 := ValueMap{ + "A": NewFromFloat(8.0), + "B": NewFromFloat(10.0), + } + + m4 := ValueMap{"A": NewFromFloat(8.0)} + + assert.Equal(t, m3, m1.Add(m2)) + assert.Panics(t, func() { m1.Add(m4) }) +} + +func Test_ValueMap_AddScalar(t *testing.T) { + x := NewFromFloat(5.0) + + m1 := ValueMap{ + "A": NewFromFloat(3.0), + "B": NewFromFloat(4.0), + } + + m2 := ValueMap{ + "A": NewFromFloat(3.0).Add(x), + "B": NewFromFloat(4.0).Add(x), + } + + assert.Equal(t, m2, m1.AddScalar(x)) +} + +func Test_ValueMap_DivScalar(t *testing.T) { + x := NewFromFloat(5.0) + + m1 := ValueMap{ + "A": NewFromFloat(3.0), + "B": NewFromFloat(4.0), + } + + m2 := ValueMap{ + "A": NewFromFloat(3.0).Div(x), + "B": NewFromFloat(4.0).Div(x), + } + + assert.Equal(t, m2, m1.DivScalar(x)) +} + +func Test_ValueMap_Sum(t *testing.T) { + m := ValueMap{ + "A": NewFromFloat(3.0), + "B": NewFromFloat(4.0), + } + + assert.Equal(t, NewFromFloat(7.0), m.Sum()) +} + +func Test_ValueMap_Normalize(t *testing.T) { + a := NewFromFloat(3.0) + b := NewFromFloat(4.0) + + m := ValueMap{ + "A": a, + "B": b, + } + + n := ValueMap{ + "A": a.Div(a.Add(b)), + "B": b.Div(a.Add(b)), + } + + assert.True(t, m.Normalize().Eq(n)) +} + +func Test_ValueMap_Normalize_zero_sum(t *testing.T) { + m := ValueMap{ + "A": Zero, + "B": Zero, + } + + assert.Panics(t, func() { m.Normalize() }) +}