近期在公司遇到一个应用和数据库间查询时的性能优化问题,在跟同事讨论解决方案时,最终选定了线性回归模型的办法。这篇短文旨在探讨利用决策树和线性回归模型来优化深度优先搜索(DFS)算法性能的Demo并且执行性能评估用于来用于其他选型参考比较。

树的构建与数据生成

尝试定义了一个简单的树结构和一个生成比较大的树的方法

class TreeNode:
    def __init__(self, value):
        self.value = value
        self.children = []

def createTreeBesar(depth, breadth):
    def addChildren(node, currentDepth):
        if currentDepth < depth:
            for _ in range(breadth):
                child = TreeNode(random.randint(1, 100))
                node.children.append(child)
                addChildren(child, currentDepth + 1)
    root = TreeNode(random.randint(1, 100))
    addChildren(root, 1)
    return root

产生样本数据

def generateSampleData():
    data = []
    for _ in range(10000):
        value = random.randint(1, 1000)
        priority = random.random()
        data.append([value, priority])
    data = np.array(data)
    X = data[:, :-1]
    y = data[:, -1]
    return X, y

模型的训练与加载

为避免浪费每次的运行时间和适合性能评估,将保存模型

def trainPriorityModel():
    X, y = generateSampleData()
    model = DecisionTreeRegressor()
    model.fit(X, y)
    joblib.dump(model, 'priorityModel.pkl')
    return model

def trainIndexModel(data):
    values = [node.value for node in data]
    positions = list(range(len(data)))
    model = LinearRegression()
    model.fit(np.array(values).reshape(-1, 1), positions)
    joblib.dump(model, 'indexModel.pkl')
    return model

def loadModel(filePath, trainFunc):
    if os.path.exists(filePath):
        return joblib.load(filePath)
    else:
        return trainFunc()

深度优先搜索和性能评估

def standardDfs(node, visited):
    if node is None or node in visited:
        return
    visited.add(node)
    for child in node.children:
        standardDfs(child, visited)

def indexedDfs(node, visited, indexModel, data):
    if node is None or node in visited:
        return
    visited.add(node)
    for child in node.children:
        locatedNode = locateNode(indexModel, child.value, data)
        indexedDfs(locatedNode, visited, indexModel, data)

def evaluatePerformance(treeRoot, priorityModel, indexModel, data):
    startTime = time.time()
    visitedStandard = set()
    standardDfs(treeRoot, visitedStandard)
    standardTime = time.time() - startTime
    print(f"Standard DFS Run Time: {standardTime:.6f} SEC")

    startTime = time.time()
    visitedIndexed = set()
    indexedDfs(treeRoot, visitedIndexed, indexModel, data)
    indexedTime = time.time() - startTime
    print(f"Indexed DFS Run Time: {indexedTime:.6f} SEC")

结果

结果好像很好的样子?
indexed_dfs_opt

评论