牛骨文教育服务平台(让学习变的简单)
博文笔记

剖析3-sum问题(Three sum)

创建时间:2014-11-03 投稿人: 浏览次数:3064

日常生活中经常遇到求解数组中是否存在和为0的三个数,即3-sum问题,为此本文介绍一些比较实用的方法,并在暴力计算的基础上对算法逐步改进,以达到最优的算法。

首先介绍一下2-sum问题,3-sum问题其实就是2-sum问题的一个扩展而已。对于2-sum问题,最一般的想法就是使用两层循环直接枚举数组中的所有元素对,直到找到和为0的元素对为止。

int sum_2()
{
	int res = 0;
	int n = data.size();
	for(int i=0; i<n; i++)
	{
		for(int j=i+1; j<n; j++)
		{
			if(data[i] + data[j] == 0)
			{
				res ++;
			}
		}
	}
	return res;
}

上述算法的由于包含了两层循环,因此时间复杂度为O(N^2)。

观察发现,上述算法时间主要花费在数据比对,为此可以考虑使用二分查找来减少数据比对时间,要想使用二分查找,首先应该对数据进行排序,在此使用归并排序对数组进行升序排列。排序所花时间为O(NlogN),排序之后数据查找只需要O(logN)的时间,但是总共需要查找N此,为此改进后算法的时间复杂度为O(NlogN)。

int cal_sum_2()
{
	int res = 0;
	for(int i=0; i<data.size(); i++)
	{
		int j = binary_search(-data[i]);
		if(j > i)	
		res++;
	}
	return res;
}
观察上述算法发现,我们在比对的过程中还是存在了一些冗余。因为排列后的数据是从最小的数开始匹配的,我们只需计算其与最后的数据的和是否为0即可,如果大于0,则说明不存在与最小数匹配的数,此时将用较小的数来替代最大的数,反之则选用较大的数替代最小的数,如此反复,只需要扫描一遍数组即可得到所有符合条件的元素对。此算法所用的时间主要还是数组排序的时间,即O(NlogN)。

int cal_sum_2_update()
{
	int res = 0;
	for(int i=0,j=data.size()-1; i<j; )
	{
		if(data[i] + data[j] > 0)
			j--;
		else if(data[i] + data[j] < 0)
			i++;
		else
		{
			res++;
			j--;
			i++;
		}
	}
	return res;
}

上述2-sum的解题思路适用于3-sum及4-sum问题,如求解a+b+c=0,可将其转换为求解a+b=-c,此就为2-sum问题。为此将2-sum,3-sum,4-sum的求解方法以及相应的优化方法实现在如下所示的sum类中。

sum类定义
#ifndef SUM_H
#define SUM_H
#include <vector>
using std::vector;
class sum
{
private:
	vector<int> data;
public:
	sum(){};
	sum(const vector<int>& a);
	~sum(){};
	int cal_sum_2() const;
	int cal_sum_3() const;
	int cal_sum_4() const;
	int cal_sum_2_update() const;
	int cal_sum_3_update() const;
	int cal_sum_3_update2() const;
	int cal_sum_4_update() const;
	void sort(int low, int high);
	void print() const;
	friend int find(const sum& s, int target); 
};
#endif

sum类实现
#include "Sum.h"
#include <iostream>
using namespace std;
sum::sum(const vector<int>& a)
{
	data = a;
}
void sum::sort(int low, int high)
{
	if(low >= high)
		return;
	int mid = (low+high)/2;
	sort(low,mid);
	sort(mid+1,high);
	vector<int> temp;
	int l = low;
	int h = mid+1;
	while(l<=mid && h <=high)
	{
		if(data[l] > data[h])
			temp.push_back(data[h++]);
		else
			temp.push_back(data[l++]);
	}
	while(l<=mid)
		temp.push_back(data[l++]);
	while(h<=high)
		temp.push_back(data[h++]);
	for(int i=low; i<=high; i++)
	{
		data[i] = temp[i-low];
	}
}
void sum::print() const
{
	for(int i=0; i<data.size(); i++)
	{
		cout<<data[i]<<" ";
	}
	cout<<endl;
}
int find(const sum& s, int target)
{
	int low = 0;
	int high = s.data.size()-1;
	while(low <= high)
	{
		int mid = (low + high)/2;
		if(s.data[mid] < target)
		{
			low = mid+1;
		}
		else if(s.data[mid] > target)
		{
			high = mid - 1;
		}
		else
		{
			return mid;
		}
	}
	return -1;
}
int sum::cal_sum_2() const
{
	int res = 0;
	for(int i=0; i<data.size(); i++)
	{
		int j = find(*this, -data[i]);
		if(j > i)	
			res++;
	}
	return res;
}
int sum::cal_sum_3() const
{
	int res = 0;
	for(int i=0; i<data.size(); i++)
	{
		for(int j=i+1; j<data.size(); j++)
		{
			for(int p=j+1;p<data.size();p++)
			{
				if(data[i] + data[j] + data[p] == 0)
					res++;
			}
		}
	}
	return res;
}
int sum::cal_sum_4() const
{
	int res = 0;
	for(int i=0; i<data.size(); i++)
	{
		for(int j=i+1; j<data.size(); j++)
		{
			for(int p=j+1; p<data.size(); p++)
			{
				for(int q=p+1; q<data.size(); q++)
				{
					if(data[i]+data[j]+data[p]+data[q] == 0)
						res++;
				}
			}
		}
	}
	return res;
}
int sum::cal_sum_2_update() const
{
	int res = 0;
	for(int i=0,j=data.size()-1; i<j; )
	{
		if(data[i] + data[j] > 0)
			j--;
		else if(data[i] + data[j] < 0)
			i++;
		else
		{
			res++;
			j--;
			i++;
		}
	}
	return res;
}
int sum::cal_sum_3_update() const
{
	int res = 0;
	for(int i=0; i<data.size(); i++)
	{
		for(int j=i+1; j<data.size(); j++)
		{
			if(find(*this, -data[i] - data[j]) > j)
				res ++;
		}
	}
	return res;
}
int sum::cal_sum_3_update2() const
{
	int res = 0;
	for(int i=0; i<data.size(); i++)
	{
		int j=i+1;
		int p=data.size()-1;
		while(j<p)
		{
			if (data[j] + data[p] < -data[i])
				j++;
			else if(data[j] + data[p] > -data[i])
				p--;
			else
			{
				res++;
				j++;
				p--;
			}
		}
	}
	return res;
}
int sum::cal_sum_4_update() const
{
	int res = 0;
	for(int i=0; i<data.size(); i++)
	{
		for(int j=i+1; j<data.size(); j++)
		{
			for(int p=j+1; p<data.size(); p++)
			{
				if(find(*this, -data[i]-data[j]-data[p])>p)
					res++;
			}
		}
	}
	return res;
}

测试代码

#include "Sum.h"
#include <iostream>
#include <fstream>
#include <vector>
using namespace std;
void main()
{
	ifstream in("1Kints.txt");
	vector<int> a;
	while(!in.eof())
	{
		int temp;
		in>>temp;
		a.push_back(temp);
	}
	sum s(a);
	s.sort(0,a.size()-1);
	s.print();
	cout<<"s.cal_sum_2() = "<<s.cal_sum_2()<<endl;
	cout<<"s.cal_sum_2_update() = "<<s.cal_sum_2_update()<<endl;
	cout<<"s.cal_sum_3() = "<<s.cal_sum_3()<<endl;
	cout<<"s.cal_sum_3_update() = "<<s.cal_sum_3_update()<<endl;
	cout<<"s.cal_sum_3_update()2 = "<<s.cal_sum_3_update2()<<endl;
	cout<<"s.cal_sum_4() = "<<s.cal_sum_4()<<endl;
	cout<<"s.cal_sum_4_update() = "<<s.cal_sum_4_update()<<endl;
}

上述算法设计思路希望对你在今后学习算法的过程中有所帮助。





声明:该文观点仅代表作者本人,牛骨文系教育信息发布平台,牛骨文仅提供信息存储空间服务。