首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >在numpy数组中找到平衡点

在numpy数组中找到平衡点
EN

Stack Overflow用户
提问于 2015-05-19 04:38:51
回答 5查看 1.3K关注 0票数 3

考虑一下这个数组:

代码语言:javascript
运行
复制
a = np.array([1,2,3,4,3,2,1])

我想要得到均分数组的元素,即元素之前的数组之和等于数组后面数组的和。在本例中,第4个元素a[3]均匀地划分数组。有没有一种更快的方法来做到这一点?还是必须对所有元素进行迭代?

期望的功能:

代码语言:javascript
运行
复制
f(a) =  3
EN

回答 5

Stack Overflow用户

回答已采纳

发布于 2015-05-19 08:30:38

如果所有的输入值都是非负的,那么最有效的方法之一似乎是构建一个累积和数组,然后二进制搜索它的位置,两边之和的一半。然而,也很容易搞错这样的二进制搜索。在尝试让二进制搜索在所有边缘情况下工作时,我最后进行了以下测试:

代码语言:javascript
运行
复制
class SplitpointTest(unittest.TestCase):
    def testFloatRounding(self):
        # Due to rounding error, the cumulative sums for these inputs are
        # [1.1, 3.3000000000000003, 3.3000000000000003, 5.5, 6.6]
        # and [0.1, 0.7999999999999999, 0.7999999999999999, 1.5, 1.6]
        # Note that under default settings, numpy won't display
        # enough precision to see that.
        self.assertEquals(2, splitpoint([1.1, 2.2, 1e-20, 2.2, 1.1]))
        self.assertEquals(2, splitpoint([0.1, 0.7, 1e-20, 0.7, 0.1]))

    def testIntRounding(self):
        self.assertEquals(1, splitpoint([1, 1, 1]))
    def testIntPrecision(self):
        self.assertEquals(2, splitpoint([2**60, 1, 1, 1, 2**60]))
    def testIntMax(self):
        self.assertEquals(
            2,
            splitpoint(numpy.array([40, 23, 1, 63], dtype=numpy.int8))
        )

    def testIntZeros(self):
        self.assertEquals(
            4,
            splitpoint(numpy.array([0, 1, 0, 2, 0, 2, 0, 1], dtype=int))
        )
    def testFloatZeros(self):
        self.assertEquals(
            4,
            splitpoint(numpy.array([0, 1, 0, 2, 0, 2, 0, 1], dtype=float))
        )

在决定不值得之前,我看了以下几个版本:

代码语言:javascript
运行
复制
def splitpoint(a):
    c = numpy.cumsum(a)
    return numpy.searchsorted(c, c[-1]/2)
    # Fails on [1, 1, 1]

def splitpoint(a):
    c = numpy.cumsum(a)
    return numpy.searchsorted(c, c[-1]/2.0)
    # Fails on [2**60, 1, 1, 1, 2**60]

def splitpoint(a):
    c = numpy.cumsum(a)
    if c.dtype.kind == 'f':
        # Floating-point input.
        return numpy.searchsorted(c, c[-1]/2.0)
    elif c.dtype.kind in ('i', 'u'):
        # Integer input.
        return numpy.searchsorted(c, (c[-1]+1)//2)
    else:
        # Probably an object dtype. No great options.
        return numpy.searchsorted(c, c[-1]/2.0)
    # Fails on numpy.array([63, 1, 63], dtype=int8)

def splitpoint(a):
    c = numpy.cumsum(a)
    if c.dtype.kind == 'f':
        # Floating-point input.
        return numpy.searchsorted(c, c[-1]/2.0)
    elif c.dtype.kind in ('i', 'u'):
        # Integer input.
        return numpy.searchsorted(c, c[-1]//2 + c[-1]%2)
    else:
        # Probably an object dtype. No great options.
        return numpy.searchsorted(c, c[-1]/2.0)
    # Still fails the floating-point rounding and zeros tests.

如果我继续努力的话,我也许能让它发挥作用,但这是不值得的。the 21的第二个解决方案是基于显式最小化左和和绝对差的解,它更容易推理,也更适用。添加了a = numpy.asarray(a)后,它通过了上述所有测试用例和以下测试,这些测试扩展了该算法需要处理的输入类型:

代码语言:javascript
运行
复制
class SplitpointGeneralizedTest(unittest.TestCase):
    def testNegatives(self):
        self.assertEquals(2, splitpoint([-1, 5, 2, 4]))
    def testComplex(self):
        self.assertEquals(2, splitpoint([1+1j, -5+2j, 43, -4+3j]))
    def testObjectDtype(self):
        from fractions import Fraction
        from decimal import Decimal
        self.assertEquals(2, splitpoint(map(Fraction, [1.5, 2.5, 3.5, 4])))
        self.assertEquals(2, splitpoint(map(Decimal, [1.5, 2.5, 3.5, 4])))

除非特别发现它太慢,否则我会使用chw21 21的第二个解决方案。在我测试它的稍微修改的形式中,这将是如下所示:

代码语言:javascript
运行
复制
def splitpoint(a):
    a = np.asarray(a)
    c1 = a.cumsum()
    c2 = a[::-1].cumsum()[::-1]
    return np.argmin(np.abs(c1-c2))

我能看到的唯一缺陷是,如果输入有一个无符号的dtype,并且没有精确分割输入的索引,则该算法可能不会返回与拆分输入最接近的索引,因为np.abs(c1-c2)没有对无符号数据类型执行正确的操作。在没有拆分索引的情况下,从来没有指定算法应该做什么,所以这种行为是可以接受的,尽管它可能值得在注释中记录一下np.abs(c1-c2)和无符号的dtype。如果我们想要最接近分割输入的索引,我们可以以额外的运行时为代价获得它:

代码语言:javascript
运行
复制
def splitpoint(a):
    a = np.asarray(a)
    c1 = a.cumsum()
    c2 = a[::-1].cumsum()[::-1]
    if a.dtype.kind == 'u':
        # np.abs(c1-c2) doesn't work on unsigned ints
        absdiffs = np.where(c1>c2, c1-c2, c2-c1)
    else:
        # c1>c2 doesn't work on complex input.
        # We also use this case for other dtypes, since it's
        # probably faster.
        absdiffs = np.abs(c1-c2)
    return np.argmin(absdiffs)

当然,下面是对此行为的测试,修改后的表单通过,而未修改的表单失败:

代码语言:javascript
运行
复制
class SplitpointUnsignedTest(unittest.TestCase):
    def testBestApproximation(self):
        self.assertEquals(1, splitpoint(numpy.array([5, 5, 4, 5], dtype=numpy.uint32)))
票数 4
EN

Stack Overflow用户

发布于 2015-05-19 04:46:06

我会去做这样的事:

代码语言:javascript
运行
复制
def equib(a):
    c = a.cumsum()
    return np.argmin(np.abs(c-(c[-1]/2)))

首先,我们建立了a的累积和。Cumsum意思是c[i] = sum(a[:i])。而不是我们所看到的绝对值,数值和总重量之间的差异变得最小。

Update @DSM注意到我的第一个版本有一点偏移,下面是另一个版本:

代码语言:javascript
运行
复制
def equib(a):
    c1 = a.cumsum()
    c2 = a[::-1].cumsum()[::-1]
    return np.argmin(np.abs(c1-c2))
票数 4
EN

Stack Overflow用户

发布于 2018-06-20 16:02:09

您可以从下面的代码中找到平衡点。

代码语言:javascript
运行
复制
    a = [1,3,5,2,2]
b = equilibirum(a)
n = len(a)
first_sum = 0
last_sum = 0
if n==1:
    print (1)
    return 0
for i in range(n):
    first_sum=first_sum+a[i]
    for j in range(i+2,n):
        last_sum=last_sum+a[j]
    if first_sum ==last_sum:
        s=i+2
        print (s)
        return 0
    last_sum=0
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/30316867

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档