types: rbtree: resolve neel reusing problem

This commit is contained in:
Raphanus Lo 2022-08-08 00:33:52 +08:00
parent 0d82f32769
commit 318590f41b
3 changed files with 78 additions and 41 deletions

View File

@ -11,11 +11,9 @@ type RBTree struct {
size int size int
} }
var neel = &RBNode{color: Black}
func NewRBTree() *RBTree { func NewRBTree() *RBTree {
var root = neel var root = NewNil()
root.parent = neel root.parent = NewNil()
return &RBTree{ return &RBTree{
Root: root, Root: root,
} }
@ -35,7 +33,7 @@ func (tree *RBTree) Delete(key fixedpoint.Value) bool {
// the deleting node has only one child, it's easy, // the deleting node has only one child, it's easy,
// we just connect the child the parent of the deleting node // we just connect the child the parent of the deleting node
if deleting.left == neel || deleting.right == neel { if deleting.left.isNil() || deleting.right.isNil() {
y = deleting y = deleting
// fmt.Printf("y = deleting = %+v\n", y) // fmt.Printf("y = deleting = %+v\n", y)
} else { } else {
@ -47,16 +45,16 @@ func (tree *RBTree) Delete(key fixedpoint.Value) bool {
} }
// y.left or y.right could be neel // y.left or y.right could be neel
if y.left != neel { if y.left.isNil() {
x = y.left
} else {
x = y.right x = y.right
} else {
x = y.left
} }
// fmt.Printf("x = %+v\n", y) // fmt.Printf("x = %+v\n", y)
x.parent = y.parent x.parent = y.parent
if y.parent == neel { if y.parent.isNil() {
tree.Root = x tree.Root = x
} else if y == y.parent.left { } else if y == y.parent.left {
y.parent.left = x y.parent.left = x
@ -144,18 +142,18 @@ func (tree *RBTree) DeleteFixup(current *RBNode) {
} }
func (tree *RBTree) Upsert(key, val fixedpoint.Value) { func (tree *RBTree) Upsert(key, val fixedpoint.Value) {
var y = neel var y = NewNil()
var x = tree.Root var x = tree.Root
var node = &RBNode{ var node = &RBNode{
key: key, key: key,
value: val, value: val,
color: Red, color: Red,
left: neel, left: NewNil(),
right: neel, right: NewNil(),
parent: neel, parent: NewNil(),
} }
for x != neel { for !x.isNil() {
y = x y = x
if node.key == x.key { if node.key == x.key {
@ -171,7 +169,7 @@ func (tree *RBTree) Upsert(key, val fixedpoint.Value) {
node.parent = y node.parent = y
if y == neel { if y.isNil() {
tree.Root = node tree.Root = node
} else if node.key.Compare(y.key) < 0 { } else if node.key.Compare(y.key) < 0 {
y.left = node y.left = node
@ -183,18 +181,18 @@ func (tree *RBTree) Upsert(key, val fixedpoint.Value) {
} }
func (tree *RBTree) Insert(key, val fixedpoint.Value) { func (tree *RBTree) Insert(key, val fixedpoint.Value) {
var y = neel var y = NewNil()
var x = tree.Root var x = tree.Root
var node = &RBNode{ var node = &RBNode{
key: key, key: key,
value: val, value: val,
color: Red, color: Red,
left: neel, left: NewNil(),
right: neel, right: NewNil(),
parent: neel, parent: NewNil(),
} }
for x != neel { for !x.isNil() {
y = x y = x
if node.key.Compare(x.key) < 0 { if node.key.Compare(x.key) < 0 {
@ -206,7 +204,7 @@ func (tree *RBTree) Insert(key, val fixedpoint.Value) {
node.parent = y node.parent = y
if y == neel { if y.isNil() {
tree.Root = node tree.Root = node
} else if node.key.Compare(y.key) < 0 { } else if node.key.Compare(y.key) < 0 {
y.left = node y.left = node
@ -220,7 +218,7 @@ func (tree *RBTree) Insert(key, val fixedpoint.Value) {
func (tree *RBTree) Search(key fixedpoint.Value) *RBNode { func (tree *RBTree) Search(key fixedpoint.Value) *RBNode {
var current = tree.Root var current = tree.Root
for current != neel && key != current.key { for !current.isNil() && key != current.key {
if key.Compare(current.key) < 0 { if key.Compare(current.key) < 0 {
current = current.left current = current.left
} else { } else {
@ -228,7 +226,7 @@ func (tree *RBTree) Search(key fixedpoint.Value) *RBNode {
} }
} }
if current == neel { if current.isNil() {
return nil return nil
} }
@ -293,13 +291,13 @@ func (tree *RBTree) RotateLeft(x *RBNode) {
var y = x.right var y = x.right
x.right = y.left x.right = y.left
if y.left != neel { if !y.left.isNil() {
y.left.parent = x y.left.parent = x
} }
y.parent = x.parent y.parent = x.parent
if x.parent == neel { if x.parent.isNil() {
tree.Root = y tree.Root = y
} else if x == x.parent.left { } else if x == x.parent.left {
x.parent.left = y x.parent.left = y
@ -315,13 +313,13 @@ func (tree *RBTree) RotateRight(y *RBNode) {
x := y.left x := y.left
y.left = x.right y.left = x.right
if x.right != neel { if !x.right.isNil() {
x.right.parent = y x.right.parent = y
} }
x.parent = y.parent x.parent = y.parent
if y.parent == neel { if y.parent.isNil() {
tree.Root = x tree.Root = x
} else if y == y.parent.left { } else if y == y.parent.left {
y.parent.left = x y.parent.left = x
@ -338,11 +336,11 @@ func (tree *RBTree) Rightmost() *RBNode {
} }
func (tree *RBTree) RightmostOf(current *RBNode) *RBNode { func (tree *RBTree) RightmostOf(current *RBNode) *RBNode {
if current == neel || current == nil { if current.isNil() || current == nil {
return nil return nil
} }
for current.right != neel { for !current.right.isNil() {
current = current.right current = current.right
} }
@ -354,11 +352,11 @@ func (tree *RBTree) Leftmost() *RBNode {
} }
func (tree *RBTree) LeftmostOf(current *RBNode) *RBNode { func (tree *RBTree) LeftmostOf(current *RBNode) *RBNode {
if current == neel || current == nil { if current.isNil() || current == nil {
return nil return nil
} }
for current.left != neel { for !current.left.isNil() {
current = current.left current = current.left
} }
@ -366,12 +364,12 @@ func (tree *RBTree) LeftmostOf(current *RBNode) *RBNode {
} }
func (tree *RBTree) Successor(current *RBNode) *RBNode { func (tree *RBTree) Successor(current *RBNode) *RBNode {
if current.right != neel { if !current.right.isNil() {
return tree.LeftmostOf(current.right) return tree.LeftmostOf(current.right)
} }
var newNode = current.parent var newNode = current.parent
for newNode != neel && current == newNode.right { for !newNode.isNil() && current == newNode.right {
current = newNode current = newNode
newNode = newNode.parent newNode = newNode.parent
} }
@ -384,7 +382,7 @@ func (tree *RBTree) Preorder(cb func(n *RBNode)) {
} }
func (tree *RBTree) PreorderOf(current *RBNode, cb func(n *RBNode)) { func (tree *RBTree) PreorderOf(current *RBNode, cb func(n *RBNode)) {
if current != neel && current != nil { if !current.isNil() && current != nil {
cb(current) cb(current)
tree.PreorderOf(current.left, cb) tree.PreorderOf(current.left, cb)
tree.PreorderOf(current.right, cb) tree.PreorderOf(current.right, cb)
@ -397,7 +395,7 @@ func (tree *RBTree) Inorder(cb func(n *RBNode) bool) {
} }
func (tree *RBTree) InorderOf(current *RBNode, cb func(n *RBNode) bool) { func (tree *RBTree) InorderOf(current *RBNode, cb func(n *RBNode) bool) {
if current != neel && current != nil { if !current.isNil() && current != nil {
tree.InorderOf(current.left, cb) tree.InorderOf(current.left, cb)
if !cb(current) { if !cb(current) {
return return
@ -412,7 +410,7 @@ func (tree *RBTree) InorderReverse(cb func(n *RBNode) bool) {
} }
func (tree *RBTree) InorderReverseOf(current *RBNode, cb func(n *RBNode) bool) { func (tree *RBTree) InorderReverseOf(current *RBNode, cb func(n *RBNode) bool) {
if current != neel && current != nil { if !current.isNil() && current != nil {
tree.InorderReverseOf(current.right, cb) tree.InorderReverseOf(current.right, cb)
if !cb(current) { if !cb(current) {
return return
@ -426,7 +424,7 @@ func (tree *RBTree) Postorder(cb func(n *RBNode) bool) {
} }
func (tree *RBTree) PostorderOf(current *RBNode, cb func(n *RBNode) bool) { func (tree *RBTree) PostorderOf(current *RBNode, cb func(n *RBNode) bool) {
if current != neel && current != nil { if !current.isNil() && current != nil {
tree.PostorderOf(current.left, cb) tree.PostorderOf(current.left, cb)
tree.PostorderOf(current.right, cb) tree.PostorderOf(current.right, cb)
if !cb(current) { if !cb(current) {

View File

@ -20,3 +20,14 @@ type RBNode struct {
key, value fixedpoint.Value key, value fixedpoint.Value
color Color color Color
} }
func NewNil() *RBNode {
return &RBNode{color: Black}
}
func (node *RBNode) isNil() bool {
if node == nil {
return true
}
return node.color == Black && node.left == nil && node.right == nil
}

View File

@ -2,6 +2,7 @@ package types
import ( import (
"math/rand" "math/rand"
"sync"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@ -11,6 +12,33 @@ import (
var itov func(int64) fixedpoint.Value = fixedpoint.NewFromInt var itov func(int64) fixedpoint.Value = fixedpoint.NewFromInt
func TestRBTree_ConcurrentIndependence(t *testing.T) {
// each RBTree instances must not affect each other in concurrent environment
var wg sync.WaitGroup
for w := 0; w < 10; w++ {
wg.Add(1)
go func() {
defer wg.Done()
tree := NewRBTree()
for stepCnt := 0; stepCnt < 10000; stepCnt++ {
switch opCode := rand.Intn(2); opCode {
case 0:
priceI := rand.Int63n(16)
price := fixedpoint.NewFromInt(priceI)
tree.Delete(price)
case 1:
priceI := rand.Int63n(16)
volumeI := rand.Int63n(8)
tree.Upsert(fixedpoint.NewFromInt(priceI), fixedpoint.NewFromInt(volumeI))
default:
panic("impossible")
}
}
}()
}
wg.Wait()
}
func TestRBTree_InsertAndDelete(t *testing.T) { func TestRBTree_InsertAndDelete(t *testing.T) {
tree := NewRBTree() tree := NewRBTree()
node := tree.Rightmost() node := tree.Rightmost()
@ -154,12 +182,12 @@ func TestRBTree_bulkInsert(t *testing.T) {
pvs[price] = volume pvs[price] = volume
} }
tree.Inorder(func(n *RBNode) bool { tree.Inorder(func(n *RBNode) bool {
if n.left != neel { if !n.left.isNil() {
if !assert.True(t, n.key.Compare(n.left.key) > 0) { if !assert.True(t, n.key.Compare(n.left.key) > 0) {
return false return false
} }
} }
if n.right != neel { if !n.right.isNil() {
if !assert.True(t, n.key.Compare(n.right.key) < 0) { if !assert.True(t, n.key.Compare(n.right.key) < 0) {
return false return false
} }
@ -206,12 +234,12 @@ func TestRBTree_bulkInsertAndDelete(t *testing.T) {
// validate tree structure // validate tree structure
tree.Inorder(func(n *RBNode) bool { tree.Inorder(func(n *RBNode) bool {
if n.left != neel { if !n.left.isNil() {
if !assert.True(t, n.key.Compare(n.left.key) > 0) { if !assert.True(t, n.key.Compare(n.left.key) > 0) {
return false return false
} }
} }
if n.right != neel { if !n.right.isNil() {
if !assert.True(t, n.key.Compare(n.right.key) < 0) { if !assert.True(t, n.key.Compare(n.right.key) < 0) {
return false return false
} }