使用过SQL的读者应该都知道SQL存在注入的可能,即没有严格检查用户输入数据的合法性。这里不讨论SQL的注入以及防止注入,只谈一下在C++中对将要执行的SQL中的字符串参数进行转义。

最近项目中遇到一个SQL相关的问题,用户在客户端输入了一个字符串数据,这个字符串数据需要保存到数据库,但是恰好有一个用户输入了一个带引号的字符串数据,导致服务器在执行SQL语句进行存储的时候出现语法错误。

我们在实际项目中,应该会封装一个函数来对SQL语句进行格式化,比如:

 1string FormatSQL(const char* format, ...)
 2{
 3	char buf[8192];
 4
 5	va_list args;
 6	va_start(args, format);
 7	vsnprintf(buf, sizeof(buf), format, args);
 8	va_end(args);
 9
10	return buf;
11}
12
13void testSQL()
14{
15	const char* name = "witton";
16	const char* desc = "'hello";
17	string SQL = FormatSQL("insert tmp(name, desc) values('%s', '%s')", name, desc);
18	printf("%s\n", SQL.c_str());
19}

这个SQL语句最后会变成:

1insert tmp(name, desc) values('witton', ''hello')

可以看到由于没有进行转义desc中的单引号,所以导致SQL出现语法错误。

MySQL的C API提供了两个函数来进行字符串转义:

1unsigned long	STDCALL mysql_escape_string(char *to,const char *from,
2					    unsigned long from_length);
3					    
4unsigned long STDCALL mysql_real_escape_string(MYSQL *mysql,
5					       char *to,const char *from,
6					       unsigned long length);

这两个函数推荐使用mysql_real_escape_string来进行转义,因为mysql_real_escape_string会要求传入MYSQL指针,该数据结构中有相应的字符集编码,转义时可以根据设定的字符集编码来进行转义。

回到前面的FormatSQL函数,它是使用的标准C的方式来格式化的,如果有很多字符串参数,就需要对每一个字符串参数写一行代码进行转义,比如前面的testSQL函数,就需要写成如下所示代码:

 1void testSQL()
 2{
 3	const char* name = "witton";
 4	const char* desc = "'hello";
 5
 6	char bufName[256];
 7	char bufDesc[256];
 8	mysql_real_escape_string(mysql, bufName, name, strlen(name));
 9	mysql_real_escape_string(mysql, bufDesc, desc, strlen(desc));
10	
11	string SQL = FormatSQL("insert tmp(name, desc) values('%s', '%s')", bufName, bufDesc);
12	printf("%s\n", SQL.c_str());
13}

如果是新写的项目,在最开始就注意这些细节还好,如果是老项目,原来没有做这些,一个一个挨着去修改,显得非常麻烦,我现在所在的项目就是这样一种情况。

有没有办法不改变现有的对FormatSQL函数的调用方式,又能够在调用FormatSQL的过程中自动对参数中字符串进行转义呢?

答案就是对FormatSQL进行修改,使用C++的变长模板参数,在匹配到字符串时自动调用转义函数。

先来看看C++11引入的变长模板参数:

1template<typename... Args> class VarTemplate;

也可以使用在模板函数上,标准 C 中的 printf 函数, 虽然也能达成变长形参的调用,但并非类型安全。 而 C++11 除了能定义类型安全的变长参数函数外, 还可以使类似 printf 的函数能自然地处理自定义类型的对象。 除了在模板参数中能使用 … 表示变长模板参数外, 函数参数也使用同样的表示法代表变长参数, 例如前面的FormatSQL函数可以写成:

1template<typename... Args>
2void FormatSQL(const char* format, Args&& ... args);

由于在调用FormatSQL时通常会有常量作为参数,所以变长参数在传入时使用的右值引用。

我们先来定义转义函数,该函数在遇到字符串类型时则调用转义函数,否则不作转义:

 1template<typename Arg>
 2Arg& EscapeArg(Arg& arg)
 3{
 4	//非字符串不作转义
 5	return arg;
 6}
 7
 8// const char arg[]参数匹配的字符串常量,比如FormatSQL("%s", "witton")中的"witton"
 9const char* EscapeArg(const char arg[])
10{
11	size_t len = strlen(arg);
12	char* buf = new char[len * 2];
13	mysql_escape_string(buf, arg, (unsigned long)len);
14	return buf;
15}
16
17// string参数匹配的是string变量比如:
18// string str = "witton";
19// FormatSQL("%s", str);
20const char* EscapeArg(string& str)
21{
22	return EscapeArg(str.c_str());
23}

转义函数写好了,转义好了还需要把每个转义后的参数连接起来再格式化出来:

1void Concat(string& str, size_t len, const char* format, ...)
2{
3	str.resize(len);
4	char* buf = (char*)str.c_str();
5	va_list ap;
6	va_start(ap, format);
7	vsnprintf(buf, len, format, ap);
8	va_end(ap);
9}

为了提高性能,这里要求外部传入一个string来存入结果,这个string由外部来决定长度,以避免空间不足。

接下来就是实现FormatSQL函数,由于该函数是一个变长模板函数,在使用中需要对每一个参数调用EscapeArg函数来进行转义,然后再把转义后的结果依次传入Concat进行格式化。

变长模板参数Args … args在使用过程中需要展开,展开方式有两种:

一种是递归方式,比如:

 1#include <iostream>
 2template<typename T0>
 3void printf1(T0 value) {
 4    std::cout << value << std::endl;
 5}
 6template<typename T, typename... Ts>
 7void printf1(T value, Ts... args) {
 8    std::cout << value << std::endl;
 9    printf1(args...);
10}
11int main() {
12    printf1(1, 2, "123", 1.1);
13    return 0;
14}

另外一种是使用逗号表达式的方式,比如:

 1template<typename OS,typename T> void outstr(OS& o,T t)
 2{
 3	o << t;
 4}
 5template<typename... ARG> auto argcat(ARG... arg)->string
 6{
 7	ostringstream os;
 8	int arr[] = { (outstr(os,arg),0)...};
 9	return os.str();
10}
11 
12int main()
13{
14	cout << argcat(1, 2.3, "my name is", '\t',"lc") << endl;
15 
16	return 0;
17}

由于需要把转义后的结果依次传入Concat进行格式化,所以最好的方式是使用类逗号扩展方式:

1template<typename ... Args>
2string FormatSQL(const char* fmt, Args&& ... args)
3{
4	size_t len = 8192;
5	string str;
6	Concat(str, len, fmt, EscapeArg(args)...);
7	return str;
8}

这里就基本上完成了要求。细心的读者可能已经发现了问题,就是在调用EscapeArg进行字符串扩展的时候分配了内存,但是没有释放,会造成内存泄漏。所以我们需要把分配的内存地址保存下来,待Concat格式化完成后进行释放。另外,在Concat格式化时要求输入一个长度len,这里是写的固定长度,如果超过则会出问题,所以我们还需要根据参数计算一个合适长度。计算工作就一同交给EscapeArg来做。

下面就直接把最终代码附上:

 1template<typename Arg>
 2Arg& EscapeArg(size_t& len, vector<char*>& vct, Arg& arg)
 3{
 4	len += sizeof(Arg) * 8;
 5	return arg;
 6}
 7
 8const char* EscapeArg(size_t& Len, vector<char*>& vct, const char arg[])
 9{
10	size_t len = strlen(arg);
11	char* buf = new char[len * 2];
12	Len += mysql_escape_string(buf, arg, (unsigned long)len);
13	vct.push_back(buf);
14	return buf;
15}
16
17const char* EscapeArg(size_t& Len, vector<char*>& vct, string& str)
18{
19	return EscapeArg(Len, vct, str.c_str());
20}
21
22void Concat(string& str, size_t len, const char* format, ...)
23{
24	str.resize(len);
25	char* buf = (char*)str.c_str();
26	va_list ap;
27	va_start(ap, format);
28	vsnprintf(buf, len, format, ap);
29	va_end(ap);
30}
31
32template<typename ... Args>
33string FormatSQL(const char* fmt, Args&& ... args)
34{
35	vector<char*> vct;
36	size_t len = strlen(fmt);
37	string str;
38	Concat(str, len, fmt, EscapeArg(len, vct, args)...);
39	for (auto iter : vct)
40		delete[] iter;
41	return str;
42}

该代码在VS2015、GCC 4.9.3、Clang 9下编译测试通过。GCC以及Clang需要添加参数-std=c++11。