首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在gmpy2和numba等优化方法中防止大整数溢出

在gmpy2和numba等优化方法中防止大整数溢出
EN

Stack Overflow用户
提问于 2022-01-11 09:11:58
回答 2查看 255关注 0票数 1

我试图在使用gmpy2的JIT修饰(优化)例程中检查一个大整数是否是一个完美的平方。这里的例子只是为了说明性的目的(从理论的角度来看,这样的方程或椭圆曲线可以被不同地对待/更好地处理)。我的代码似乎溢出,因为它产生的解决方案不是真正的:

代码语言:javascript
复制
import numpy as np
from numba import jit
import gmpy2
from gmpy2 import mpz, xmpz

import time
import sys

@jit('void(uint64)')
def findIntegerSolutionsGmpy2(limit: np.uint64):
    for x in np.arange(0, limit+1, dtype=np.uint64):
        y = mpz(x**6-4*x**2+4)
        if gmpy2.is_square(y):
            print([x,gmpy2.sqrt(y),y])

def main() -> int:
    limit = 100000000
    start = time.time()
    findIntegerSolutionsGmpy2(limit)
    end = time.time()
    print("Time elapsed: {0}".format(end - start))
    return 0

if __name__ == '__main__':
    sys.exit(main())

使用limit = 1000000000,例程在大约范围内完成。4秒。这个限制,我把它交给修饰函数,不会超过64位的无符号整数(这里似乎不是问题)。

我读到大整数不能与numba的JIT优化结合使用(例如,请参阅这里)。

我的问题:是否有可能在(GPU)优化代码中使用大整数?

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2022-01-22 07:18:08

错误结果的真正原因很简单,您忘记了将x转换为mpz,因此语句x ** 6 - 4 * x ** 2 + 4被提升为np.uint64类型并通过溢出计算(因为语句中的xnp.uint64)。修复很简单,只需添加x = mpz(x)

代码语言:javascript
复制
@jit('void(uint64)', forceobj = True)
def findIntegerSolutionsGmpy2(limit: np.uint64):
    for x in np.arange(0, limit+1, dtype=np.uint64):
        x = mpz(x)
        y = mpz(x**6-4*x**2+4)
        if gmpy2.is_square(y):
            print([x,gmpy2.sqrt(y),y])

另外,您可能会注意到,我添加了forceobj = True,这是为了在开始时抑制Numba编译警告。

在这个修复之后,一切都很好,你不会看到错误的结果。

如果您的任务是检查表达式是否给出了严格的平方,那么我决定为您发明并实现另一个解决方案,代码如下。

它的工作原理如下。您可能会注意到,如果一个数字是平方的,那么它也是平方模任意数(取模数是x % N运算)。

我们可以取任何数字,例如一些素数的乘积,K = 2 * 2 * 3 * 5 * 7 * 11 * 13 * 17 * 19。现在我们可以做一个简单的滤波器,计算所有的平方模K,在位向量内标记这个平方,然后检查在这个滤波器位向量中模K有哪些数。

过滤器K(素数的乘积),上面提到的,只留下1%的候选方。我们也可以做第二阶段,应用相同的过滤器与其他素数,例如K2 = 23 * 29 * 31 * 37 * 41。这将过滤他们,即使是3%。总之,我们将有剩余的1% * 3% = 0.03%数量的初始候选人。

经过两次过滤后,只剩下几个号码需要检查。它们可以很容易地用gmpy2.is_square()快速检查.

过滤阶段可以很容易地封装到Numba函数中,就像我下面所做的那样,这个函数可以有额外的Numba parallel = True,这将告诉Numba自动在所有CPU核上并行运行所有Numpy操作。

在我使用limit = 1 << 30的代码中,这表示要检查的所有x的限制,我使用block = 1 << 26,这意味着一次要检查多少个数字,并行Numba函数。如果您有足够的内存,您可以将block设置为更大,以便更有效地占用所有CPU核心。大小的块1 << 26大约使用1GB的内存。

在使用我的过滤思想和使用多核CPU之后,我的代码解决了与您相同的任务,速度比您快一百倍。

在网上试试!

代码语言:javascript
复制
import numpy as np, numba

@numba.njit('u8[:](u8[:], u8, u8, u1[:])', cache = True, parallel = True)
def do_filt(x, i, K, filt):
    x += i; x %= K
    x2 = x
    x2 *= x2;     x2 %= K
    x6 = x2 * x2; x6 %= K
    x6 *= x2;     x6 %= K
    x6 += np.uint64(4 * K + 4)
    x2 <<= np.uint64(2)
    x6 -= x2; x6 %= K
    y = x6
    #del x2
    filt_y = filt[y]
    filt_y_i = np.flatnonzero(filt_y).astype(np.uint64)
    return filt_y_i

def main():
    import math
    gmpy2 = None
    import gmpy2
    
    Int = lambda x: (int(x) if gmpy2 is None else gmpy2.mpz(x))
    IsSquare = lambda x: gmpy2.is_square(x)
    Sqrt = lambda x: Int(gmpy2.sqrt(x))
    
    Ks = [2 * 2 * 3 * 5 * 7 * 11 * 13 * 17 * 19,    23 * 29 * 31 * 37 * 41]
    filts = []
    for i, K in enumerate(Ks):
        a = np.arange(K, dtype = np.uint64)
        a *= a
        a %= K
        filts.append((K, np.zeros((K,), dtype = np.uint8)))
        filts[-1][1][a] = 1
        print(f'filter {i} ratio', round(len(np.flatnonzero(filts[-1][1])) / K, 4))
    
    limit = 1 << 30
    block = 1 << 26
    
    for i in range(0, limit, block):
        print(f'i block {i // block:>3} (2^{math.log2(i + 1):>6.03f})')
        x = np.arange(0, min(block, limit - i), dtype = np.uint64)
        
        for ifilt, (K, filt) in enumerate(filts):
            len_before = len(x)
            x = do_filt(x, i, K, filt)
            print(f'squares filtered by filter {ifilt}:', round(len(x) / len_before, 4))
        
        x_to_check = x
        print(f'remain to check {len(x_to_check)}')
        
        sq_x = []
        for x0 in x_to_check:
            x = Int(i + x0)
            y = x ** 6 - 4 * x ** 2 + 4
            if not IsSquare(y):
                continue
            yr = Sqrt(y)
            assert yr * yr == y
            sq_x.append((int(x), int(yr)))
        print('squares found', len(sq_x))
        print(sq_x)
        
        del x

if __name__ == '__main__':
    main()

输出:

代码语言:javascript
复制
filter 0 ratio 0.0094
filter 1 ratio 0.0366
i block   0 (2^ 0.000)
squares filtered by filter 0: 0.0211
squares filtered by filter 1: 0.039
remain to check 13803
squares found 2
[(0, 2), (1, 1)]
i block   1 (2^24.000)
squares filtered by filter 0: 0.0211
squares filtered by filter 1: 0.0392
remain to check 13880
squares found 0
[]
i block   2 (2^25.000)
squares filtered by filter 0: 0.0211
squares filtered by filter 1: 0.0391
remain to check 13835
squares found 0
[]
i block   3 (2^25.585)
squares filtered by filter 0: 0.0211
squares filtered by filter 1: 0.0393
remain to check 13907
squares found 0
[]

...............................
票数 0
EN

Stack Overflow用户

发布于 2022-01-14 11:04:51

现在,我可以通过以下代码设法避免精度的丢失:

代码语言:javascript
复制
@jit('void(uint64)')
def findIntegerSolutionsGmpy2(limit: np.uint64):
    for x in np.arange(0, limit+1, dtype=np.uint64):
        x_ = mpz(int(x))**2
        y = x_**3-mpz(4)*x_+mpz(4)
        if gmpy2.is_square(y):
            print([x,gmpy2.sqrt(y),y])

但是通过使用limit = 100000000,这个修正/固定的例程不再在4秒内完成。现在用了912秒。很可能我们在精确性和速度之间有一个无法克服的差距。

使用CUDA变得更快,即5分钟(拥有128 of内存的机器、英特尔Xeon E5-2630 v4、2.20GHz处理器和两张特斯拉V100类型的图形卡,每张都有16 of内存),但我得到的结果也是错误的。

代码语言:javascript
复制
%%time
from numba import jit, cuda
import numpy as np
from math import sqrt

@cuda.jit
def findIntegerSolutionsCuda(arr):
    i=0
    for x in range(0, 1000000000+1):
        y = float(x**6-4*x**2+4)
        sqr = int(sqrt(y))
        if sqr*sqr == int(y):
            arr[i][0]=x
            arr[i][1]=sqr
            arr[i][2]=y
            i+=1

arr=np.zeros((10,3))
findIntegerSolutionsCuda[128, 255](arr)

print(arr)
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70664185

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档