基于决策树和线性回归模型优化深度优先搜索(DFS)性能
近期在公司遇到一个应用与数据库之间查询的性能优化问题。与同事讨论解决方案后,最终选定了基于线性回归模型的方案。本文通过一个 Demo,探讨如何利用决策树和线性回归模型优化深度优先搜索(DFS)算法的性能,并进行性能评估,以供其他技术选型参考比较。
树的构建与数据生成
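首先定义一个简单的树结构,以及一个用于生成较大规模树的方法。(文中代码省略了 import 语句,运行前需要引入 random、os、time、numpy(np)、joblib,以及 scikit-learn 的 DecisionTreeRegressor 和 LinearRegression。)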
class TreeNode:
    """A node of an n-ary tree: a payload ``value`` plus an ordered list of children."""

    def __init__(self, value):
        # Every node starts as a leaf; builders append child TreeNodes here.
        self.children = list()
        self.value = value
def createTreeBesar(depth, breadth, rng=None):
    """Build a complete tree whose nodes carry random integer values in [1, 100].

    ("besar" is Indonesian for "big".) The root counts as level 1; every node
    on a level below ``depth`` receives exactly ``breadth`` children, so a
    ``depth`` of 1 or less yields a lone root.

    Args:
        depth: Number of levels, counting the root as level 1.
        breadth: Number of children given to each non-leaf node.
        rng: Optional object with a ``randint(a, b)`` method, e.g. a seeded
            ``random.Random``, so the same benchmark tree can be rebuilt.
            Defaults to the module-level ``random`` functions (the original
            behaviour).

    Returns:
        The root ``TreeNode``.

    Note:
        Construction recurses once per level, so a very large ``depth`` can
        exceed Python's recursion limit.
    """
    if rng is None:
        rng = random

    def addChildren(node, currentDepth):
        # Each child's whole subtree is built before its next sibling, so
        # random values are drawn in pre-order (same order as before).
        if currentDepth < depth:
            for _ in range(breadth):
                child = TreeNode(rng.randint(1, 100))
                node.children.append(child)
                addChildren(child, currentDepth + 1)

    root = TreeNode(rng.randint(1, 100))
    addChildren(root, 1)
    return root
生成样本数据(value 为 1–1000 的随机整数,priority 为 [0, 1) 区间的随机数,两者相互独立)
def generateSampleData(nSamples=10000, rng=None):
    """Generate synthetic (value, priority) rows for training the priority model.

    Each row pairs a random integer ``value`` in [1, 1000] with a random
    ``priority`` in [0, 1). The two are drawn independently.

    NOTE(review): because priority does not depend on value, any regressor
    trained on this data can only fit noise; replace with a real signal if
    the priority model is meant to be useful.

    Args:
        nSamples: Number of rows to generate (default 10000, the previously
            hard-coded size). Zero yields empty arrays with the usual shapes.
        rng: Optional object with ``randint`` and ``random`` methods, e.g. a
            seeded ``random.Random``. Defaults to the module-level ``random``
            functions (the original behaviour).

    Returns:
        ``(X, y)`` where ``X`` is a float array of shape ``(nSamples, 1)``
        holding the values and ``y`` a float array of shape ``(nSamples,)``
        holding the priorities.
    """
    if rng is None:
        rng = random
    data = []
    for _ in range(nSamples):
        value = rng.randint(1, 1000)
        priority = rng.random()
        data.append([value, priority])
    # reshape keeps the array 2-D even when nSamples == 0 (np.array([]) is
    # 1-D and would make the column slicing below raise IndexError).
    data = np.array(data, dtype=float).reshape(-1, 2)
    X = data[:, :-1]
    y = data[:, -1]
    return X, y
模型的训练与加载
为避免每次运行都重新训练、浪费时间,也便于反复进行性能评估,训练好的模型会保存到文件中,之后直接加载。
def trainPriorityModel(filePath="priorityModel.pkl"):
    """Train a decision-tree regressor mapping value -> priority and save it.

    Training data comes from ``generateSampleData()``.

    Args:
        filePath: Where the fitted model is written with ``joblib.dump``.
            Defaults to ``'priorityModel.pkl'`` (the previously hard-coded
            name). Pass the same path given to ``loadModel`` so the next run
            finds the saved model instead of retraining.

    Returns:
        The fitted ``DecisionTreeRegressor``.
    """
    X, y = generateSampleData()
    # NOTE(review): the tree is unconstrained (no max_depth / min_samples_leaf)
    # and the targets are independent of X, so it will likely just memorise
    # noise; consider constraining it once real training data exists.
    model = DecisionTreeRegressor()
    model.fit(X, y)
    joblib.dump(model, filePath)
    return model
def trainIndexModel(data, filePath="indexModel.pkl"):
    """Fit a linear "learned index" mapping a node's value to its position in ``data``.

    Args:
        data: Sequence of nodes exposing ``.value``. Position ``i`` of each node
            in this sequence is the regression target for its value.
            NOTE(review): one straight line value -> position is only a usable
            index if ``data`` is ordered by value (presumably sorted by the
            caller; confirm). Nodes sharing a value map to several positions
            and cannot be told apart by the model.
        filePath: Where the fitted model is written with ``joblib.dump``.
            Defaults to ``'indexModel.pkl'`` (the previously hard-coded name).

    Returns:
        The fitted ``LinearRegression``.

    Raises:
        ValueError: If ``data`` is empty.
    """
    if len(data) == 0:
        raise ValueError("trainIndexModel requires at least one node in data")
    values = np.array([node.value for node in data]).reshape(-1, 1)
    positions = np.arange(len(data))
    model = LinearRegression()
    model.fit(values, positions)
    joblib.dump(model, filePath)
    return model
def loadModel(filePath, trainFunc, *trainArgs, **trainKwargs):
    """Load a saved model from ``filePath``, or train a new one if the file is absent.

    Args:
        filePath: Path to a joblib-serialised model.
        trainFunc: Called only when ``filePath`` does not exist; its return
            value is returned. It is responsible for saving the model itself
            (ideally to ``filePath``, so the next call can load it).
        *trainArgs, **trainKwargs: Forwarded to ``trainFunc``. This lets
            trainers with required arguments be used, e.g.
            ``loadModel('indexModel.pkl', trainIndexModel, data)``. Previously
            ``trainFunc`` was always called with no arguments, which raised
            TypeError for such trainers.

    Returns:
        The loaded or freshly trained model.
    """
    if os.path.exists(filePath):
        # joblib.load unpickles the file and can execute arbitrary code:
        # only load model files from a trusted location.
        return joblib.load(filePath)
    return trainFunc(*trainArgs, **trainKwargs)
深度优先搜索和性能评估
def standardDfs(node, visited):
    """Recursively visit every node reachable from ``node``, depth-first and pre-order.

    Each node reached is added to the caller-owned ``visited`` set, which also
    stops a node from being expanded twice. ``None`` nodes are ignored.
    Returns None; the result is the mutated ``visited`` set.
    """
    if node is not None and node not in visited:
        visited.add(node)
        for nxt in node.children:
            standardDfs(nxt, visited)
def indexedDfs(node, visited, indexModel, data):
    """Depth-first traversal that looks every child up through the learned index.

    For each child, ``locateNode(indexModel, child.value, data)`` is called.
    Traversal then continues from the located node if it is that child, and
    from the child itself otherwise.

    Args:
        node: Node to start from; ``None`` or an already-visited node is a no-op.
        visited: Caller-owned set that collects every node reached.
        indexModel: Model passed through to ``locateNode``.
        data: Node collection passed through to ``locateNode``.

    Returns:
        None; the result is the mutated ``visited`` set.
    """
    if node is None or node in visited:
        return
    visited.add(node)
    for child in node.children:
        locatedNode = locateNode(indexModel, child.value, data)
        # The lookup is keyed only by value, and node values repeat (the tree
        # builder draws them from 1..100). It can therefore return a different
        # node with the same value, or nothing. Following that node walks the
        # wrong subtree, so the visited set no longer matches the tree and the
        # timing compares unequal work. Fall back to the real child; the
        # lookup still runs, so its cost is still measured.
        if locatedNode is not child:
            locatedNode = child
        indexedDfs(locatedNode, visited, indexModel, data)
def evaluatePerformance(treeRoot, priorityModel, indexModel, data):
    """Time ``standardDfs`` against ``indexedDfs`` on the same tree and print the results.

    Args:
        treeRoot: Root node both traversals start from.
        priorityModel: Currently unused by this function; kept for interface
            compatibility.
        indexModel: Forwarded to ``indexedDfs``.
        data: Forwarded to ``indexedDfs``.

    Returns:
        A dict with ``'standardTime'`` and ``'indexedTime'`` (seconds) and
        ``'standardVisited'`` and ``'indexedVisited'`` (node counts).
        Earlier versions returned None, so callers ignoring the result are
        unaffected.

    Note:
        Each traversal runs once, so timings are noisy; repeat the runs (e.g.
        with ``timeit``) before drawing conclusions.
    """
    # perf_counter is monotonic and high-resolution; time.time() is neither
    # guaranteed, and both properties matter for sub-millisecond intervals.
    startTime = time.perf_counter()
    visitedStandard = set()
    standardDfs(treeRoot, visitedStandard)
    standardTime = time.perf_counter() - startTime
    print(f"Standard DFS Run Time: {standardTime:.6f} SEC")

    startTime = time.perf_counter()
    visitedIndexed = set()
    indexedDfs(treeRoot, visitedIndexed, indexModel, data)
    indexedTime = time.perf_counter() - startTime
    print(f"Indexed DFS Run Time: {indexedTime:.6f} SEC")

    # A faster time is meaningless unless both traversals did the same work.
    if visitedIndexed != visitedStandard:
        print(
            f"WARNING: traversals differ - standard visited "
            f"{len(visitedStandard)} nodes, indexed visited {len(visitedIndexed)}"
        )

    return {
        "standardTime": standardTime,
        "indexedTime": indexedTime,
        "standardVisited": len(visitedStandard),
        "indexedVisited": len(visitedIndexed),
    }
结果
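结果看起来似乎不错?但需要谨慎解读。indexedDfs 是沿着 locateNode 返回的节点继续遍历的,而树中节点的值在 1–100 之间大量重复,返回的节点未必就是真正的子节点。因此两次遍历访问的节点可能并不相同,比较运行时间之前应先核对 visitedStandard 与 visitedIndexed 是否一致。另外,文中没有给出 locateNode 的实现、调用 evaluatePerformance 的代码,以及具体的运行时间数据。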