diff --git a/pkg/types/rbtree.go b/pkg/types/rbtree.go index fc1c84b2a..6064d742f 100644 --- a/pkg/types/rbtree.go +++ b/pkg/types/rbtree.go @@ -11,11 +11,9 @@ type RBTree struct { size int } -var neel = &RBNode{color: Black} - func NewRBTree() *RBTree { - var root = neel - root.parent = neel + var root = NewNil() + root.parent = NewNil() return &RBTree{ Root: root, } @@ -35,7 +33,7 @@ func (tree *RBTree) Delete(key fixedpoint.Value) bool { // the deleting node has only one child, it's easy, // we just connect the child the parent of the deleting node - if deleting.left == neel || deleting.right == neel { + if deleting.left.isNil() || deleting.right.isNil() { y = deleting // fmt.Printf("y = deleting = %+v\n", y) } else { @@ -47,16 +45,16 @@ func (tree *RBTree) Delete(key fixedpoint.Value) bool { } // y.left or y.right could be neel - if y.left != neel { - x = y.left - } else { + if y.left.isNil() { x = y.right + } else { + x = y.left } // fmt.Printf("x = %+v\n", y) x.parent = y.parent - if y.parent == neel { + if y.parent.isNil() { tree.Root = x } else if y == y.parent.left { y.parent.left = x @@ -144,18 +142,18 @@ func (tree *RBTree) DeleteFixup(current *RBNode) { } func (tree *RBTree) Upsert(key, val fixedpoint.Value) { - var y = neel + var y = NewNil() var x = tree.Root var node = &RBNode{ key: key, value: val, color: Red, - left: neel, - right: neel, - parent: neel, + left: NewNil(), + right: NewNil(), + parent: NewNil(), } - for x != neel { + for !x.isNil() { y = x if node.key == x.key { @@ -171,7 +169,7 @@ func (tree *RBTree) Upsert(key, val fixedpoint.Value) { node.parent = y - if y == neel { + if y.isNil() { tree.Root = node } else if node.key.Compare(y.key) < 0 { y.left = node @@ -183,18 +181,18 @@ func (tree *RBTree) Upsert(key, val fixedpoint.Value) { } func (tree *RBTree) Insert(key, val fixedpoint.Value) { - var y = neel + var y = NewNil() var x = tree.Root var node = &RBNode{ key: key, value: val, color: Red, - left: neel, - right: neel, - parent: neel, + left: NewNil(), + right: NewNil(), + parent: NewNil(), } - for x != neel { + for !x.isNil() { y = x if node.key.Compare(x.key) < 0 { @@ -206,7 +204,7 @@ func (tree *RBTree) Insert(key, val fixedpoint.Value) { node.parent = y - if y == neel { + if y.isNil() { tree.Root = node } else if node.key.Compare(y.key) < 0 { y.left = node @@ -220,7 +218,7 @@ func (tree *RBTree) Insert(key, val fixedpoint.Value) { func (tree *RBTree) Search(key fixedpoint.Value) *RBNode { var current = tree.Root - for current != neel && key != current.key { + for !current.isNil() && key != current.key { if key.Compare(current.key) < 0 { current = current.left } else { @@ -228,7 +226,7 @@ func (tree *RBTree) Search(key fixedpoint.Value) *RBNode { } } - if current == neel { + if current.isNil() { return nil } @@ -293,13 +291,13 @@ func (tree *RBTree) RotateLeft(x *RBNode) { var y = x.right x.right = y.left - if y.left != neel { + if !y.left.isNil() { y.left.parent = x } y.parent = x.parent - if x.parent == neel { + if x.parent.isNil() { tree.Root = y } else if x == x.parent.left { x.parent.left = y @@ -315,13 +313,13 @@ func (tree *RBTree) RotateRight(y *RBNode) { x := y.left y.left = x.right - if x.right != neel { + if !x.right.isNil() { x.right.parent = y } x.parent = y.parent - if y.parent == neel { + if y.parent.isNil() { tree.Root = x } else if y == y.parent.left { y.parent.left = x @@ -338,11 +336,11 @@ func (tree *RBTree) Rightmost() *RBNode { } func (tree *RBTree) RightmostOf(current *RBNode) *RBNode { - if current == neel || current == nil { + if current.isNil() || current == nil { return nil } - for current.right != neel { + for !current.right.isNil() { current = current.right } @@ -354,11 +352,11 @@ func (tree *RBTree) Leftmost() *RBNode { } func (tree *RBTree) LeftmostOf(current *RBNode) *RBNode { - if current == neel || current == nil { + if current.isNil() || current == nil { return nil } - for current.left != neel { + for !current.left.isNil() { current = current.left } @@ -366,12 +364,12 @@ func (tree *RBTree) LeftmostOf(current *RBNode) *RBNode { } func (tree *RBTree) Successor(current *RBNode) *RBNode { - if current.right != neel { + if !current.right.isNil() { return tree.LeftmostOf(current.right) } var newNode = current.parent - for newNode != neel && current == newNode.right { + for !newNode.isNil() && current == newNode.right { current = newNode newNode = newNode.parent } @@ -384,7 +382,7 @@ func (tree *RBTree) Preorder(cb func(n *RBNode)) { } func (tree *RBTree) PreorderOf(current *RBNode, cb func(n *RBNode)) { - if current != neel && current != nil { + if !current.isNil() && current != nil { cb(current) tree.PreorderOf(current.left, cb) tree.PreorderOf(current.right, cb) @@ -397,7 +395,7 @@ func (tree *RBTree) Inorder(cb func(n *RBNode) bool) { } func (tree *RBTree) InorderOf(current *RBNode, cb func(n *RBNode) bool) { - if current != neel && current != nil { + if !current.isNil() && current != nil { tree.InorderOf(current.left, cb) if !cb(current) { return @@ -412,7 +410,7 @@ func (tree *RBTree) InorderReverse(cb func(n *RBNode) bool) { } func (tree *RBTree) InorderReverseOf(current *RBNode, cb func(n *RBNode) bool) { - if current != neel && current != nil { + if !current.isNil() && current != nil { tree.InorderReverseOf(current.right, cb) if !cb(current) { return @@ -426,7 +424,7 @@ func (tree *RBTree) Postorder(cb func(n *RBNode) bool) { } func (tree *RBTree) PostorderOf(current *RBNode, cb func(n *RBNode) bool) { - if current != neel && current != nil { + if !current.isNil() && current != nil { tree.PostorderOf(current.left, cb) tree.PostorderOf(current.right, cb) if !cb(current) { diff --git a/pkg/types/rbtree_node.go b/pkg/types/rbtree_node.go index b387afe56..76560f6c1 100644 --- a/pkg/types/rbtree_node.go +++ b/pkg/types/rbtree_node.go @@ -20,3 +20,14 @@ type RBNode struct { key, value fixedpoint.Value color Color } + +func NewNil() *RBNode { + return &RBNode{color: Black} +} + +func (node *RBNode) isNil() bool { + if node == nil { + return true + } + return node.color == Black && node.left == nil && node.right == nil +} diff --git a/pkg/types/rbtree_test.go b/pkg/types/rbtree_test.go index 0daddf199..f3bac6aa0 100644 --- a/pkg/types/rbtree_test.go +++ b/pkg/types/rbtree_test.go @@ -2,6 +2,7 @@ package types import ( "math/rand" + "sync" "testing" "github.com/stretchr/testify/assert" @@ -11,6 +12,33 @@ import ( var itov func(int64) fixedpoint.Value = fixedpoint.NewFromInt +func TestRBTree_ConcurrentIndependence(t *testing.T) { + // each RBTree instances must not affect each other in concurrent environment + var wg sync.WaitGroup + for w := 0; w < 10; w++ { + wg.Add(1) + go func() { + defer wg.Done() + tree := NewRBTree() + for stepCnt := 0; stepCnt < 10000; stepCnt++ { + switch opCode := rand.Intn(2); opCode { + case 0: + priceI := rand.Int63n(16) + price := fixedpoint.NewFromInt(priceI) + tree.Delete(price) + case 1: + priceI := rand.Int63n(16) + volumeI := rand.Int63n(8) + tree.Upsert(fixedpoint.NewFromInt(priceI), fixedpoint.NewFromInt(volumeI)) + default: + panic("impossible") + } + } + }() + } + wg.Wait() +} + func TestRBTree_InsertAndDelete(t *testing.T) { tree := NewRBTree() node := tree.Rightmost() @@ -154,12 +182,12 @@ func TestRBTree_bulkInsert(t *testing.T) { pvs[price] = volume } tree.Inorder(func(n *RBNode) bool { - if n.left != neel { + if !n.left.isNil() { if !assert.True(t, n.key.Compare(n.left.key) > 0) { return false } } - if n.right != neel { + if !n.right.isNil() { if !assert.True(t, n.key.Compare(n.right.key) < 0) { return false } @@ -206,12 +234,12 @@ func TestRBTree_bulkInsertAndDelete(t *testing.T) { // validate tree structure tree.Inorder(func(n *RBNode) bool { - if n.left != neel { + if !n.left.isNil() { if !assert.True(t, n.key.Compare(n.left.key) > 0) { return false } } - if n.right != neel { + if !n.right.isNil() { if !assert.True(t, n.key.Compare(n.right.key) < 0) { return false }